Compare commits

..

73 Commits

Author SHA1 Message Date
Aidan
2acd9d0f0e uri fix 2023-11-28 11:12:18 -05:00
Aidan
49b6e77e4b cleaned 2023-11-28 11:04:54 -05:00
Lance Release
5372843281 Updating package-lock.json 2023-11-15 03:15:10 +00:00
Lance Release
54677b8f0b Updating package-lock.json 2023-11-15 02:42:38 +00:00
Lance Release
ebcf9bf6ae Bump version: 0.3.6 → 0.3.7 2023-11-15 02:42:25 +00:00
Bert
797514bcbf fix: node remote implement table.countRows (#648) 2023-11-13 17:43:20 -05:00
Rok Mihevc
1c872ce501 feat: add RemoteTable.version in Python (#644)
Please note: this is not tested as we don't have a server here and
testing against a mock object wouldn't be that interesting.
2023-11-13 21:43:48 +01:00
Bert
479f471c14 fix: node send db header for GET requests (#646) 2023-11-11 16:33:25 -05:00
Ayush Chaurasia
ae0d2f2599 fix: Pydantic 1.x compat for weak_lru caching in embeddings API (#643)
Colab has pydantic 1.x by default and pydantic 1.x BaseModel objects
don't support weakref creation by default that we use to cache embedding
models
https://github.com/lancedb/lancedb/blob/main/python/lancedb/embeddings/utils.py#L206
. It needs to be added to slot.
2023-11-10 15:02:38 +05:30
Ayush Chaurasia
1e8678f11a Multi-task instructor model with quantization support & weak_lru cache for embedding function models (#612)
resolves #608
2023-11-09 12:34:18 +05:30
QianZhu
662968559d fix saas open_table and table_names issues (#640)
- added check whether a table exists in SaaS open_table
- remove prefilter not supported warning in SaaS search
- fixed issues for SaaS table_names
2023-11-07 17:34:38 -08:00
Rob Meng
9d895801f2 upgrade lance to 0.8.14 (#636)
upgrade lance
2023-11-07 19:01:29 -05:00
Rob Meng
80613a40fd skip missing file on mirrored dir when deleting (#635)
mirrored store is not garueeteed to have all the files. Ignore the ones
that doesn't exist.
2023-11-07 12:33:32 -05:00
Lei Xu
d43ef7f11e chore: apple silicon runner (#633)
Close #632
2023-11-06 21:04:32 -08:00
Lei Xu
554e068917 chore: improve create_table API consistency between local and remote SDK (#627) 2023-11-03 13:15:11 -07:00
Bert
567734dd6e fix: node remote connection handles non http errors (#624)
https://github.com/lancedb/lancedb/issues/623

Fixes issue trying to print response status when using remote client. If
the error is not an HTTP error (e.g. dns/network failure), there won't
be a response.
2023-11-03 10:24:56 -04:00
Ayush Chaurasia
1589499f89 Exponential standoff retry support for handling rate limited embedding functions (#614)
Users ingesting data using rate limited apis don't need to manually make
the process sleep for counter rate limits
resolves #579
2023-11-02 19:20:10 +05:30
Lance Release
682e95fa83 Updating package-lock.json 2023-11-01 22:20:49 +00:00
Lance Release
1ad5e7f2f0 Updating package-lock.json 2023-11-01 21:16:20 +00:00
Lance Release
ddb3ef4ce5 Bump version: 0.3.5 → 0.3.6 2023-11-01 21:16:06 +00:00
Lance Release
ef20b2a138 [python] Bump version: 0.3.2 → 0.3.3 2023-11-01 21:15:55 +00:00
Lei Xu
2e0f251bfd chore: bump lance to 8.10 (#622) 2023-11-01 14:14:38 -07:00
Ayush Chaurasia
2cb91e818d Disable posthog on docs & reduce sentry trace factor (#607)
- posthog charges per event and docs events are registered very
frequently. We can keep tracking them on GA
- Reduced sentry trace factor
2023-11-02 01:13:16 +05:30
Chang She
2835c76336 doc: node sdk now supports windows (#616) 2023-11-01 10:04:18 -07:00
Bert
8068a2bbc3 ci: cancel in progress runs on new push (#620) 2023-11-01 11:33:48 -04:00
Bert
24111d543a fix!: sort table names (#619)
https://github.com/lancedb/lance/issues/1385
2023-11-01 10:50:09 -04:00
QianZhu
7eec2b8f9a Qian/query option doc (#615)
- API documentation improvement for queries (table.search)
- a small bug fix for the remote API on create_table

![image](https://github.com/lancedb/lancedb/assets/1305083/712e9bd3-deb8-4d81-8cd0-d8e98ef68f4e)

![image](https://github.com/lancedb/lancedb/assets/1305083/ba22125a-8c36-4e34-a07f-e39f0136e62c)
2023-10-31 19:50:05 -07:00
Will Jones
b2b70ea399 increment pylance (#618) 2023-10-31 18:07:03 -07:00
Bert
e50a3c1783 added api docs for prefilter flag (#617)
Added the prefilter flag argument to the `LanceQueryBuilder.where`.

This should make it display here:

https://lancedb.github.io/lancedb/python/python/#lancedb.query.LanceQueryBuilder.select

And also in intellisense like this:
<img width="848" alt="image"
src="https://github.com/lancedb/lancedb/assets/5846846/e0c53f4f-96bc-411b-9159-680a6c4d0070">

Also adds some improved documentation about the `where` argument to this
method.

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
2023-10-31 16:39:32 -04:00
Weston Pace
b517134309 feat: allow prefiltering with index (#610)
Support for prefiltering with an index was added in lance version 0.8.7.
We can remove the lancedb check that prevents this. Closes #261
2023-10-31 13:11:03 -07:00
Lei Xu
6fb539b5bf doc: add doc to use GPU for indexing (#611) 2023-10-30 15:25:00 -07:00
Lance Release
f37fe120fd Updating package-lock.json 2023-10-26 22:30:16 +00:00
Lance Release
2e115acb9a Updating package-lock.json 2023-10-26 21:48:01 +00:00
Lance Release
27a638362d Bump version: 0.3.4 → 0.3.5 2023-10-26 21:47:44 +00:00
Bert
22a6695d7a fix conv version (#605) 2023-10-26 17:44:11 -04:00
Lance Release
57eff82ee7 Updating package-lock.json 2023-10-26 21:03:07 +00:00
Lance Release
7732f7d41c Bump version: 0.3.3 → 0.3.4 2023-10-26 21:02:52 +00:00
Bert
5ca98c326f feat: added dataset stats api to node (#604) 2023-10-26 17:00:48 -04:00
Bert
b55db397eb feat: added data stats apis (#596) 2023-10-26 13:10:17 -04:00
Rob Meng
c04d72ac8a expose remap index api (#603)
expose index remap options in `compact_files`
2023-10-25 22:10:37 -04:00
Rob Meng
28b02fb72a feat: expose optimize index api (#602)
expose `optimize_index` api.
2023-10-25 19:40:23 -04:00
Lance Release
f3cf986777 [python] Bump version: 0.3.1 → 0.3.2 2023-10-24 19:06:38 +00:00
Bert
c73fcc8898 update lance to 0.8.7 (#598) 2023-10-24 14:49:36 -04:00
Chang She
cd9debc3b7 fix(python): fix multiple embedding functions bug (#597)
Closes #594

The embedding functions are pydantic models so multiple instances with
the same parameters are considered ==, which means that if you have
multiple embedding columns it's possible for the embeddings to get
overwritten. Instead we use `is` instead of == to avoid this problem.

testing: modified unit test to include this case
2023-10-24 13:05:05 -04:00
Rob Meng
26a97ba997 feat: add checkout method to table to reuse existing store and connections (#593)
Prior to this PR, to get a new version of a table, we need to re-open
the table. This has a few downsides w.r.t. performance:
* Object store is recreated, which takes time and throws away existing
warm connections
* Commit handler is thrown aways as well, which also may contain warm
connections
2023-10-23 12:06:13 -04:00
Rob Meng
ce19fedb08 feat: include manifest files in mirrow store (#589) 2023-10-21 12:21:41 -04:00
Will Jones
14e8e48de2 Revert "[python] Bump version: 0.3.2 → 0.3.3"
This reverts commit c30faf6083.
2023-10-20 17:52:49 -07:00
Will Jones
c30faf6083 [python] Bump version: 0.3.2 → 0.3.3 2023-10-20 17:30:00 -07:00
Ayush Chaurasia
64a4f025bb [Docs]: Minor Fixes (#587)
* Filename typo
* Remove rick_morty csv as users won't really be able to use it.. We can
create a an executable colab and download it from a bucket or smth.
2023-10-20 16:14:35 +02:00
Ayush Chaurasia
6dc968e7d3 [Docs] Embeddings API: Add multi-lingual semantic search example (#582) 2023-10-20 18:40:49 +05:30
Ayush Chaurasia
06b5b69f1e [Docs]Versioning docs (#586)
closes #564

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-10-20 18:40:16 +05:30
Lance Release
6bd3a838fc Updating package-lock.json 2023-10-19 20:45:39 +00:00
Lance Release
f36fea8f20 Updating package-lock.json 2023-10-19 20:06:10 +00:00
Lance Release
0a30591729 Bump version: 0.3.2 → 0.3.3 2023-10-19 20:05:57 +00:00
Chang She
0ed39b6146 chore: bump lance version in python/rust lancedb (#584)
To include latest v0.8.6

Co-authored-by: Chang She <chang@lancedb.com>
2023-10-19 13:05:12 -07:00
Ayush Chaurasia
a8c7f80073 [Docs] Update embedding function docs (#581) 2023-10-18 13:04:42 +05:30
Ayush Chaurasia
0293bbe142 [Python]Embeddings API refactor (#580)
Sets things up for this -> https://github.com/lancedb/lancedb/issues/579
- Just separates out the registry/ingestion code from the function
implementation code
- adds a `get_registry` util
- package name "open-clip" -> "open-clip-torch"
2023-10-17 22:32:19 -07:00
Ayush Chaurasia
7372656369 [Docs] Add posthog telemetry to docs (#577)
Allows creation of funnels and user journeys
2023-10-17 21:11:59 -07:00
QianZhu
d46bc5dd6e list table pagination draft (#574) 2023-10-16 21:09:20 -07:00
Prashanth Rao
86efb11572 Add pyarrow date and timestamp type conversion from pydantic (#576) 2023-10-16 19:42:24 -07:00
Chang She
bb01ad5290 doc: fix broken link and add README (#573)
Fix broken link to embedding functions

testing: broken link was verified after local docs build to have been
repaired

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-10-16 16:13:07 -07:00
Lance Release
1b8cda0941 Updating package-lock.json 2023-10-16 16:10:07 +00:00
Lance Release
bc85a749a3 Updating package-lock.json 2023-10-16 15:12:15 +00:00
Lance Release
02c35d3457 Bump version: 0.3.1 → 0.3.2 2023-10-16 15:11:57 +00:00
Rob Meng
345c136cfb implement remote api calls for table mutation (#567)
Add more APIs to remote table for Node SDK
* `add` rows
* `overwrite` table with rows
* `create` table

This has been tested against dev stack
2023-10-16 11:07:58 -04:00
Rok Mihevc
043e388254 docs: show source of documented functions (#569) 2023-10-15 09:05:36 -07:00
Lei Xu
fe64fc4671 feat(python,js): deletion operation on remote tables (#568) 2023-10-14 15:47:19 -07:00
Rok Mihevc
6d66404506 docs: switch python examples to be row based (#554) 2023-10-14 14:07:43 -07:00
Lei Xu
eff94ecea8 chore: bump lance to 0.8.5 (#561)
Bump lance to 0.5.8
2023-10-14 12:38:43 -07:00
Ayush Chaurasia
7dfb555fea [DOCS][PYTHON] Update embeddings API docs & Example (#516)
This PR adds an overview of embeddings docs:
- 2 ways to vectorize your data using lancedb - explicit & implicit
- explicit - manually vectorize your data using `wit_embedding` function
- Implicit - automatically vectorize your data as it comes by ingesting
your embedding function details as table metadata
- Multi-modal example w/ disappearing embedding function
2023-10-14 07:56:07 +05:30
Lance Release
f762a669e7 Updating package-lock.json 2023-10-13 22:27:48 +00:00
Lance Release
0bdc7140dd Updating package-lock.json 2023-10-13 21:24:05 +00:00
Lance Release
8f6e955b24 Bump version: 0.3.0 → 0.3.1 2023-10-13 21:23:54 +00:00
77 changed files with 5248 additions and 925 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.0
current_version = 0.3.7
commit = True
message = Bump version: {current_version} → {new_version}
tag = True

View File

@@ -11,6 +11,10 @@ on:
- .github/workflows/node.yml
- docker-compose.yml
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
# Disable full debug symbol generation to speed up CI build and keep memory down
# "1" means line tables only, which is useful for panic tracebacks.

View File

@@ -38,7 +38,7 @@ jobs:
node/vectordb-*.tgz
node-macos:
runs-on: macos-12
runs-on: macos-13
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
strategy:

View File

@@ -8,6 +8,11 @@ on:
paths:
- python/**
- .github/workflows/python.yml
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
linux:
timeout-minutes: 30
@@ -32,18 +37,19 @@ jobs:
run: |
pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock black isort
- name: Black
run: black --check --diff --no-color --quiet .
- name: isort
run: isort --check --diff --quiet .
pip install pytest pytest-mock ruff
- name: Lint
run: ruff format --check .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest
run: pytest --doctest-modules lancedb
mac:
timeout-minutes: 30
runs-on: "macos-12"
strategy:
matrix:
mac-runner: [ "macos-13", "macos-13-xlarge" ]
runs-on: "${{ matrix.mac-runner }}"
defaults:
run:
shell: bash
@@ -62,8 +68,6 @@ jobs:
pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock black
- name: Black
run: black --check --diff --no-color --quiet .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
pydantic1x:
@@ -95,4 +99,4 @@ jobs:
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest
run: pytest --doctest-modules lancedb
run: pytest --doctest-modules lancedb

View File

@@ -10,6 +10,10 @@ on:
- rust/**
- .github/workflows/rust.yml
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
# This env var is used by Swatinem/rust-cache@v2 for the cache
# key, so we set it to make sure it is always consistent.
@@ -44,8 +48,11 @@ jobs:
- name: Run tests
run: cargo test --all-features
macos:
runs-on: macos-12
timeout-minutes: 30
strategy:
matrix:
mac-runner: [ "macos-13", "macos-13-xlarge" ]
runs-on: "${{ matrix.mac-runner }}"
defaults:
run:
shell: bash

View File

@@ -5,23 +5,23 @@ exclude = ["python"]
resolver = "2"
[workspace.dependencies]
lance = { "version" = "=0.8.3", "features" = ["dynamodb"] }
lance-linalg = { "version" = "=0.8.3" }
lance-testing = { "version" = "=0.8.3" }
lance = { "version" = "=0.8.14", "features" = ["dynamodb"] }
lance-linalg = { "version" = "=0.8.14" }
lance-testing = { "version" = "=0.8.14" }
# Note that this one does not include pyarrow
arrow = { version = "43.0.0", optional = false }
arrow-array = "43.0"
arrow-data = "43.0"
arrow-ipc = "43.0"
arrow-ord = "43.0"
arrow-schema = "43.0"
arrow-arith = "43.0"
arrow-cast = "43.0"
arrow = { version = "47.0.0", optional = false }
arrow-array = "47.0"
arrow-data = "47.0"
arrow-ipc = "47.0"
arrow-ord = "47.0"
arrow-schema = "47.0"
arrow-arith = "47.0"
arrow-cast = "47.0"
chrono = "0.4.23"
half = { "version" = "=2.2.1", default-features = false, features = [
"num-traits"
half = { "version" = "=2.3.1", default-features = false, features = [
"num-traits",
] }
log = "0.4"
object_store = "0.6.1"
object_store = "0.7.1"
snafu = "0.7.4"
url = "2"

26
docs/README.md Normal file
View File

@@ -0,0 +1,26 @@
# LanceDB Documentation
LanceDB docs are deployed to https://lancedb.github.io/lancedb/.
Docs is built and deployed automatically by [Github Actions](.github/workflows/docs.yml)
whenever a commit is pushed to the `main` branch. So it is possible for the docs to show
unreleased features.
## Building the docs
### Setup
1. Install LanceDB. From LanceDB repo root: `pip install -e python`
2. Install dependencies. From LanceDB repo root: `pip install -r docs/requirements.txt`
3. Make sure you have node and npm setup
4. Make sure protobuf and libssl are installed
### Building node module and create markdown files
See [Javascript docs README](docs/src/javascript/README.md)
### Build docs
From LanceDB repo root:
Run: `PYTHONPATH=. mkdocs build -f docs/mkdocs.yml`
If successful, you should see a `docs/site` directory that you can verify locally.

View File

@@ -37,7 +37,7 @@ plugins:
docstring_style: numpy
rendering:
heading_level: 4
show_source: false
show_source: true
show_symbol_type_in_heading: true
show_signature_annotations: true
show_root_heading: true
@@ -73,7 +73,14 @@ nav:
- Vector Search: search.md
- SQL filters: sql.md
- Indexing: ann_indexes.md
- 🧬 Embeddings: embedding.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
- 🧬 Embeddings:
- embeddings/index.md
- Ingest Embedding Functions: embeddings/embedding_functions.md
- Available Functions: embeddings/default_embedding_functions.md
- Create Custom Embedding Functions: embeddings/api.md
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
- 🔍 Python full-text search: fts.md
- 🔌 Integrations:
- integrations/index.md
@@ -105,7 +112,14 @@ nav:
- Vector Search: search.md
- SQL filters: sql.md
- Indexing: ann_indexes.md
- Embeddings: embedding.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Embeddings:
- embeddings/index.md
- Ingest Embedding Functions: embeddings/embedding_functions.md
- Available Functions: embeddings/default_embedding_functions.md
- Create Custom Embedding Functions: embeddings/api.md
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
- Python full-text search: fts.md
- Integrations:
- integrations/index.md

View File

@@ -71,9 +71,41 @@ a single PQ code.
### Use GPU to build vector index
Lance Python SDK has experimental GPU support for creating IVF index.
Using GPU for index creation requires [PyTorch>2.0](https://pytorch.org/) being installed.
You can specify the GPU device to train IVF partitions via
- **accelerator**: Specify to `"cuda"`` to enable GPU training.
- **accelerator**: Specify to ``cuda`` or ``mps`` (on Apple Silicon) to enable GPU training.
=== "Linux"
<!-- skip-test -->
``` { .python .copy }
# Create index using CUDA on Nvidia GPUs.
tbl.create_index(
num_partitions=256,
num_sub_vectors=96,
accelerator="cuda"
)
```
=== "Macos"
<!-- skip-test -->
```python
# Create index using MPS on Apple Silicon.
tbl.create_index(
num_partitions=256,
num_sub_vectors=96,
accelerator="mps"
)
```
Trouble shootings:
If you see ``AssertionError: Torch not compiled with CUDA enabled``, you need to [install
PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
## Querying an ANN Index

Binary file not shown.

After

Width:  |  Height:  |  Size: 342 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 245 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 83 KiB

213
docs/src/embeddings/api.md Normal file
View File

@@ -0,0 +1,213 @@
To use your own custom embedding function, you need to follow these 2 simple steps.
1. Create your embedding function by implementing the `EmbeddingFunction` interface
2. Register your embedding function in the global `EmbeddingFunctionRegistry`.
Let us see how this looks like in action.
![](../assets/embeddings_api.png)
`EmbeddingFunction` & `EmbeddingFunctionRegistry` handle low-level details for serializing schema and model information as metadata. To build a custom embdding function, you don't need to worry about those details and simply focus on setting up the model.
## `TextEmbeddingFunction` Interface
There is another optional layer of abstraction provided in form of `TextEmbeddingFunction`. You can use this if your model isn't multi-modal in nature and only operates on text. In such case both source and vector fields will have the same pathway for vectorization, so you simply just need to setup the model and rest is handled by `TextEmbeddingFunction`. You can read more about the class and its attributes in the class reference.
Let's implement `SentenceTransformerEmbeddings` class. All you need to do is implement the `generate_embeddings()` and `ndims` function to handle the input types you expect and register the class in the global `EmbeddingFunctionRegistry`
```python
from lancedb.embeddings import register
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
name: str = "all-MiniLM-L6-v2"
# set more default instance vars like device, etc.
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
def generate_embeddings(self, texts):
return self._embedding_model().encode(list(texts), ...).tolist()
def ndims(self):
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
@cached(cache={})
def _embedding_model(self):
return sentence_transformers.SentenceTransformer(name)
```
This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and defaul settings.
Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs.
```python
from lancedb.pydantic import LanceModel, Vector
registry = EmbeddingFunctionRegistry.get_instance()
stransformer = registry.get("sentence-transformers").create()
class TextModelSchema(LanceModel):
vector: Vector(stransformer.ndims) = stransformer.VectorField()
text: str = stransformer.SourceField()
tbl = db.create_table("table", schema=TextModelSchema)
tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
result = tbl.search("world").limit(5)
```
NOTE:
You can always implement the `EmbeddingFunction` interface directly if you want or need to, `TextEmbeddingFunction` just makes it much simpler and faster for you to do so, by setting up the boiler plat for text-specific use case
## Multi-modal embedding function example
You can also use the `EmbeddingFunction` interface to implement more complex workflows such as multi-modal embedding function support. LanceDB implements `OpenClipEmeddingFunction` class that suppports multi-modal seach. Here's the implementation that you can use as a reference to build your own multi-modal embedding functions.
```python
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
name: str = "ViT-B-32"
pretrained: str = "laion2b_s34b_b79k"
device: str = "cpu"
batch_size: int = 64
normalize: bool = True
_model = PrivateAttr()
_preprocess = PrivateAttr()
_tokenizer = PrivateAttr()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
model.to(self.device)
self._model, self._preprocess = model, preprocess
self._tokenizer = open_clip.get_tokenizer(self.name)
self._ndims = None
def ndims(self):
if self._ndims is None:
self._ndims = self.generate_text_embeddings("foo").shape[0]
return self._ndims
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
with torch.no_grad():
text_features = self._model.encode_text(text.to(self.device))
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given images
"""
images = self.sanitize_input(images)
embeddings = []
for i in range(0, len(images), self.batch_size):
j = min(i + self.batch_size, len(images))
batch = images[i:j]
embeddings.extend(self._parallel_get(batch))
return embeddings
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
"""
Issue concurrent requests to retrieve the image data
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)
for image in images
]
return [future.result() for future in futures]
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
with torch.no_grad():
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
"""
encode a single image tensor and optionally normalize the output
"""
image_features = self._model.encode_image(image_tensor)
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()
```

View File

@@ -0,0 +1,156 @@
There are various Embedding functions available out of the box with lancedb. We're working on supporting other popular embedding APIs.
## Text Embedding Functions
Here are the text embedding functions registered by default
### Sentence Transformers
Here are the parameters that you can set when registering a `sentence-transformers` object, and their default values:
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"all-MiniLM-L6-v2"` | The name of the model. |
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
| `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model. |
```python
db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("sentence-transformers").create(device="cpu")
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words)
table.add(
[
{"text": "hello world"}
{"text": "goodbye world"}
]
)
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```
### OpenAIEmbeddings
LanceDB has OpenAI embeddings function in the registry by default. It is registered as `openai` and here are the parameters that you can customize when creating the instances
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"text-embedding-ada-002"` | The name of the model. |
```python
db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("openai").create()
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words)
table.add(
[
{"text": "hello world"}
{"text": "goodbye world"}
]
)
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```
## Multi-modal embedding functions
Multi-modal embedding functions allow you query your table using both images and text.
### OpenClipEmbeddings
We support CLIP model embeddings using the open souce alternbative, open-clip which support various customizations. It is registered as `open-clip` and supports following customizations.
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"ViT-B-32"` | The name of the model. |
| `pretrained` | `str` | `"laion2b_s34b_b79k"` | The name of the pretrained model to load. |
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
| `batch_size` | `int` | `64` | The number of images to process in a batch. |
| `normalize` | `bool` | `True` | Whether to normalize the input images before feeding them to the model. |
This embedding function supports ingesting images as both bytes and urls. You can query them using both test and other images.
NOTE:
LanceDB supports ingesting images directly from accessible links.
```python
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("open-clip").create()
class Images(LanceModel):
label: str
image_uri: str = func.SourceField() # image uri as the source
image_bytes: bytes = func.SourceField() # image bytes as the source
vector: Vector(func.ndims()) = func.VectorField() # vector column
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
[{"label": labels, "image_uri": uris, "image_bytes": image_bytes}]
)
```
Now we can search using text from both the default vector column and the custom vector column
```python
# text search
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
print(actual.label) # prints "dog"
frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
print(frombytes.label)
```
Because we're using a multi-modal embedding function, we can also search using images
```python
# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(io.BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")
# image search using a custom vector column
other = (
table.search(query_image, vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
print(actual.label)
```
If you have any questions about the embeddings API, supported models, or see a relevant model missing, please raise an issue.

View File

@@ -0,0 +1,82 @@
Representing multi-modal data as vector embeddings is becoming a standard practice. Embedding functions themselves be thought of as a part of the processing pipeline that each request(input) has to be passed through. After initial setup these components are not expected to change for a particular project.
This is main motivation behind our new embedding functions API, that allow you simply set it up once and the table remembers it, effectively making the **embedding functions disappear in the background** so you don't have to worry about modelling and simply focus on the DB aspects of VectorDB.
You can simply follow these steps and forget about the details of your embedding functions as long as you don't intend to change it.
### Step 1 - Define the embedding function
We have some pre-defined embedding functions in the global registry with more coming soon. Here's let's an implementation of CLIP as example.
```
registry = EmbeddingFunctionRegistry.get_instance()
clip = registry.get("open-clip").create()
```
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses PyDantic Model which can be utilized to write complex schemas simply as we'll see next!
### Step 2 - Define the Data Model or Schema
Our embedding function from the previous section abstracts away all the details about the models and dimensions required to define the schema. You can simply set a feild as **source** or **vector** column. Here's how
```python
class Pets(LanceModel):
vector: Vector(clip.ndims) = clip.VectorField()
image_uri: str = clip.SourceField()
```
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for `vector` column & `SourceField` tells that when adding data, automatically use the embedding function to encode `image_uri`.
### Step 3 - Create LanceDB Table
Now that we have chosen/defined our embedding function and the schema, we can create the table
```python
db = lancedb.connect("~/lancedb")
table = db.create_table("pets", schema=Pets)
```
That's it! We have ingested all the information needed to embed source and query inputs. We can now forget about the model and dimension details and start to build or VectorDB
### Step 4 - Ingest lots of data and run vector search!
Now you can just add the data and it'll be vectorized automatically
```python
table.add([{"image_uri": u} for u in uris])
```
Our OpenCLIP query embedding function support querying via both text and images.
```python
result = table.search("dog")
```
Let's query an image
```python
p = Path("path/to/images/samoyed_100.jpg")
query_image = Image.open(p)
table.search(query_image)
```
### A little fun with PyDantic
LanceDB is integrated with PyDantic. Infact we've used the integration in the above example to define the schema. It is also being used behing the scene by the embdding function API to ingest useful information as table metadata.
You can also use it for adding utility operations in the schema. For example, in our multi-modal example, you can search images using text or another image. Let us define a utility function to plot the image.
```python
class Pets(LanceModel):
vector: Vector(clip.ndims) = clip.VectorField()
image_uri: str = clip.SourceField()
@property
def image(self):
return Image.open(self.image_uri)
```
Now, you can covert your search results to pydantic model and use this property.
```python
rs = table.search(query_image).limit(3).to_pydantic(Pets)
rs[2].image
```
![](../assets/dog_clip_output.png)
Now that you've the basic idea about LanceDB embedding function, let us now dive deeper into the API that you can use to implement your own embedding functions!

View File

@@ -1,13 +1,20 @@
# Embedding Functions
# Embedding
Embeddings are high dimensional floating-point vector representations of your data or query.
Anything can be embedded using some embedding model or function.
For a given embedding function, the output will always have the same number of dimensions.
Embeddings are high dimensional floating-point vector representations of your data or query. Anything can be embedded using some embedding model or function. Position of embedding in a high dimensional vector space has semantic significance to a degree that depends on the type of modal and training. These embeddings when projected in a 2-D space generally group similar entities close-by forming groups.
## Creating an embedding function
![](../assets/embedding_intro.png)
Any function that takes as input a batch (list) of data and outputs a batch (list) of embeddings
can be used by LanceDB as an embedding function. The input and output batch sizes should be the same.
# Creating an embedding function
LanceDB supports 2 major ways of vectorizing your data, explicit and implicit.
1. By manually embedding the data before ingesting in the table
2. By automatically embedding the data and query as they come, by ingesting embedding function information in the table itself! Covered in [Next Section](embedding_functions.md)
Whatever workflow you prefer, we have the tools to support you.
## Explicit Vectorization
In this workflow, you can create your embedding function and vectorize your data using lancedb's `with_embedding` function. Let's look at some examples.
### HuggingFace example
@@ -134,9 +141,9 @@ belong in the same latent space and your results will be nonsensical.
The above snippet returns an array of records with the 10 closest vectors to the query.
## Roadmap
## Implicit vectorization / Ingesting embedding functions
Representing multi-modal data as vector embeddings is becoming a standard practice. Embedding functions themselves be thought of as a part of the processing pipeline that each request(input) has to be passed through. After initial setup these components are not expected to change for a particular project.
In the near future, we'll be integrating the embedding functions deeper into LanceDB<br/>.
The goal is that you just have to configure the function once when you create the table,
and then you'll never have to deal with embeddings / vectors after that unless you want to.
We'll also integrate more popular models and APIs.
This is main motivation behind our new embedding functions API, that allow you simply set it up once and the table remembers it, effectively making the **embedding functions disappear in the background** so you don't have to worry about modelling and simply focus on the DB aspects of VectorDB.
Learn more in the Next Section

View File

@@ -251,8 +251,9 @@ After a table has been created, you can always add more data to it using
### Adding Pandas DataFrame
```python
df = pd.DataFrame([{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
{"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}])
df = pd.DataFrame({
"vector": [[1.3, 1.4], [9.5, 56.2]], "item": ["fizz", "buzz"], "price": [100.0, 200.0]
})
tbl.add(df)
```
@@ -261,17 +262,12 @@ After a table has been created, you can always add more data to it using
### Adding to table using Iterator
```python
import pandas as pd
def make_batches():
for i in range(5):
yield pd.DataFrame(
{
"vector": [[3.1, 4.1], [1, 1]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
})
yield [
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
]
tbl.add(make_batches())
```
@@ -306,9 +302,10 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
```python
import lancedb
import pandas as pd
data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
data = [{"x": 1, "vector": [1, 2]},
{"x": 2, "vector": [3, 4]},
{"x": 3, "vector": [5, 6]}]
db = lancedb.connect("./.lancedb")
table = db.create_table("my_table", data)
table.to_pandas()

View File

@@ -67,7 +67,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
## Documentation Quick Links
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
* [`Embedding Functions`](embeddings/index.md) - functions for working with embeddings.
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -114,13 +114,10 @@
}
],
"source": [
"import pandas as pd\n",
"\n",
"data = pd.DataFrame({\n",
" \"vector\": [[1.1, 1.2], [0.2, 1.8]],\n",
" \"lat\": [45.5, 40.1],\n",
" \"long\": [-122.7, -74.1]\n",
"})\n",
"data = [\n",
" {\"vector\": [1.1, 1.2], \"lat\": 45.5, \"long\": -122.7},\n",
" {\"vector\": [0.2, 1.8], \"lat\": 40.1, \"long\": -74.1},\n",
"]\n",
"\n",
"db.create_table(\"table2\", data)\n",
"\n",
@@ -366,11 +363,11 @@
"def make_batches():\n",
" for i in range(5):\n",
" yield pd.DataFrame(\n",
" {\n",
" \"vector\": [[3.1, 4.1], [1, 1]],\n",
" \"item\": [\"foo\", \"bar\"],\n",
" \"price\": [10.0, 20.0],\n",
" })\n",
" {\n",
" \"vector\": [[3.1, 4.1], [1, 1]],\n",
" \"item\": [\"foo\", \"bar\"],\n",
" \"price\": [10.0, 20.0],\n",
" })\n",
"\n",
"tbl = db.create_table(\"table5\", make_batches(), schema=PydanticSchema)\n",
"tbl.schema"
@@ -572,9 +569,11 @@
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame([{\"vector\": [1.3, 1.4], \"item\": \"fizz\", \"price\": 100.0},\n",
" {\"vector\": [9.5, 56.2], \"item\": \"buzz\", \"price\": 200.0}])\n",
"tbl.add(df)"
"data = [\n",
" {\"vector\": [1.3, 1.4], \"item\": \"fizz\", \"price\": 100.0},\n",
" {\"vector\": [9.5, 56.2], \"item\": \"buzz\", \"price\": 200.0}\n",
"]\n",
"tbl.add(data)"
]
},
{
@@ -596,17 +595,12 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"import pandas as pd\n",
"\n",
"def make_batches():\n",
" for i in range(5):\n",
" yield pd.DataFrame(\n",
" {\n",
" \"vector\": [[3.1, 4.1], [1, 1]],\n",
" \"item\": [\"foo\", \"bar\"],\n",
" \"price\": [10.0, 20.0],\n",
" })\n",
" yield [\n",
" {\"vector\": [3.1, 4.1], \"item\": \"foo\", \"price\": 10.0},\n",
" {\"vector\": [1, 1], \"item\": \"bar\", \"price\": 20.0},\n",
" ]\n",
"tbl.add(make_batches())"
]
},

View File

@@ -39,7 +39,6 @@ to lazily generate data:
from typing import Iterable
import pyarrow as pa
import lancedb
def make_batches() -> Iterable[pa.RecordBatch]:
for i in range(5):

View File

@@ -11,15 +11,13 @@ pip install duckdb lancedb
We will re-use [the dataset created previously](./arrow.md):
```python
import pandas as pd
import lancedb
db = lancedb.connect("data/sample-lancedb")
data = pd.DataFrame({
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0]
})
data = [
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
]
table = db.create_table("pd_table", data=data)
arrow_table = table.to_arrow()
```

View File

@@ -22,21 +22,19 @@ pip install lancedb
::: lancedb.query.LanceQueryBuilder
::: lancedb.query.LanceFtsQueryBuilder
## Embeddings
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
::: lancedb.embeddings.functions.EmbeddingFunction
::: lancedb.embeddings.base.EmbeddingFunction
::: lancedb.embeddings.functions.TextEmbeddingFunction
::: lancedb.embeddings.base.TextEmbeddingFunction
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
::: lancedb.embeddings.sentence_transformers.SentenceTransformerEmbeddings
::: lancedb.embeddings.functions.OpenAIEmbeddings
::: lancedb.embeddings.openai.OpenAIEmbeddings
::: lancedb.embeddings.functions.OpenClipEmbeddings
::: lancedb.embeddings.open_clip.OpenClipEmbeddings
::: lancedb.embeddings.with_embeddings
@@ -56,7 +54,7 @@ pip install lancedb
## Utilities
::: lancedb.vector
::: lancedb.schema.vector
## Integrations

View File

@@ -0,0 +1,4 @@
window.addEventListener("DOMContentLoaded", (event) => {
!function(t,e){var o,n,p,r;e.__SV||(window.posthog=e,e._i=[],e.init=function(i,s,a){function g(t,e){var o=e.split(".");2==o.length&&(t=t[o[0]],e=o[1]),t[e]=function(){t.push([e].concat(Array.prototype.slice.call(arguments,0)))}}(p=t.createElement("script")).type="text/javascript",p.async=!0,p.src=s.api_host+"/static/array.js",(r=t.getElementsByTagName("script")[0]).parentNode.insertBefore(p,r);var u=e;for(void 0!==a?u=e[a]=[]:a="posthog",u.people=u.people||[],u.toString=function(t){var e="posthog";return"posthog"!==a&&(e+="."+a),t||(e+=" (stub)"),e},u.people.toString=function(){return u.toString(1)+".people (stub)"},o="capture identify alias people.set people.set_once set_config register register_once unregister opt_out_capturing has_opted_out_capturing opt_in_capturing reset isFeatureEnabled onFeatureFlags getFeatureFlag getFeatureFlagPayload reloadFeatureFlags group updateEarlyAccessFeatureEnrollment getEarlyAccessFeatures getActiveMatchingSurveys getSurveys".split(" "),n=0;n<o.length;n++)g(u,o[n]);e._i.push([i,s,a])},e.__SV=1)}(document,window.posthog||[]);
posthog.init('phc_oENDjGgHtmIDrV6puUiFem2RB4JA8gGWulfdulmMdZP',{api_host:'https://app.posthog.com'})
});

View File

@@ -4,7 +4,7 @@
In a recommendation system or search engine, you can find similar products from
the one you searched.
In LLM and other AI applications,
each data point can be [presented by the embeddings generated from some models](embedding.md),
each data point can be [presented by the embeddings generated from some models](embeddings/index.md),
it returns the most relevant features.
A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.

View File

@@ -8,6 +8,7 @@ const excludedGlobs = [
"../src/embedding.md",
"../src/examples/*.md",
"../src/guides/tables.md",
"../src/embeddings/*.md",
];
const nodePrefix = "javascript";

View File

@@ -10,6 +10,7 @@ excluded_globs = [
"../src/integrations/voxel51.md",
"../src/guides/tables.md",
"../src/python/duckdb.md",
"../src/embeddings/*.md",
]
python_prefix = "py"
@@ -17,29 +18,45 @@ python_file = ".py"
python_folder = "python"
files = glob.glob(glob_string, recursive=True)
excluded_files = [f for excluded_glob in excluded_globs for f in glob.glob(excluded_glob, recursive=True)]
excluded_files = [
f
for excluded_glob in excluded_globs
for f in glob.glob(excluded_glob, recursive=True)
]
def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
in_code_block = False
# Python code has strict indentation
strip_length = 0
skip_test = False
for line in lines:
if "skip-test" in line:
skip_test = True
if line.strip().startswith(prefix + python_prefix):
in_code_block = True
strip_length = len(line) - len(line.lstrip())
elif in_code_block and line.strip().startswith(suffix):
in_code_block = False
yield "\n"
if not skip_test:
yield "\n"
skip_test = False
elif in_code_block:
yield line[strip_length:]
if not skip_test:
yield line[strip_length:]
for file in filter(lambda file: file not in excluded_files, files):
with open(file, "r") as f:
lines = list(yield_lines(iter(f), "```", "```"))
if len(lines) > 0:
out_path = Path(python_folder) / Path(file).name.strip(".md") / (Path(file).name.strip(".md") + python_file)
print(lines)
out_path = (
Path(python_folder)
/ Path(file).name.strip(".md")
/ (Path(file).name.strip(".md") + python_file)
)
print(out_path)
out_path.parent.mkdir(exist_ok=True, parents=True)
with open(out_path, "w") as out:
out.writelines(lines)
out.writelines(lines)

View File

@@ -10,7 +10,7 @@ npm install vectordb
This will download the appropriate native library for your platform. We currently
support x86_64 Linux, aarch64 Linux, Intel MacOS, and ARM (M1/M2) MacOS. We do not
yet support Windows or musl-based Linux (such as Alpine Linux).
yet support musl-based Linux (such as Alpine Linux).
## Usage

74
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.3.0",
"version": "0.3.7",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.3.0",
"version": "0.3.7",
"cpu": [
"x64",
"arm64"
@@ -53,11 +53,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.3.0",
"@lancedb/vectordb-darwin-x64": "0.3.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.3.0",
"@lancedb/vectordb-linux-x64-gnu": "0.3.0",
"@lancedb/vectordb-win32-x64-msvc": "0.3.0"
"@lancedb/vectordb-darwin-arm64": "0.3.7",
"@lancedb/vectordb-darwin-x64": "0.3.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.3.7",
"@lancedb/vectordb-linux-x64-gnu": "0.3.7",
"@lancedb/vectordb-win32-x64-msvc": "0.3.7"
}
},
"node_modules/@apache-arrow/ts": {
@@ -317,9 +317,9 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.0.tgz",
"integrity": "sha512-Fg+k/cSnqmNQlSWyDp0PpaAJ67kAISfZAD+zZ3mcE8/3ml2I/wM/GVjPy2zeiQX9aR93lG1mZXFSNTDUc74tWQ==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.7.tgz",
"integrity": "sha512-QsDxcbhrumJg+Cyflpnj8EY+bZojbco5K7VSeKvguqeXUGb62ksyOZuUTCn2sqJaCgy1KZ1qC5U8jBqfgZHc2w==",
"cpu": [
"arm64"
],
@@ -329,9 +329,9 @@
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.0.tgz",
"integrity": "sha512-CXp4b/brMbnBPZuGzKIOskd9uD90R73rWubaJ0du/Kt6fcyQX1dM1wEhWTLxI6eKf8IDL/R9QLL2cIahm1J86w==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.7.tgz",
"integrity": "sha512-fgv10kI04UycgpmhJLUcCswgvSdgsGuj65o+W5usmVdxYZiWpoXBBXRkWYMjUX5RNe3mY1Ff6QPBbToR0WkSUA==",
"cpu": [
"x64"
],
@@ -341,9 +341,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.0.tgz",
"integrity": "sha512-1bjaRzYcDsWIRUbO2K/f+ohNmNvCgKcrrOhmiXSHVlYY8kH1LUMFZj+BhqBC0Ea0Stt7/1rsRLMRXRtaeVOEHw==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.7.tgz",
"integrity": "sha512-pvw+31+VKEH3YmS/GLKzEGt/Y2+c/IaE6JL6tIjXi2KY+ZcWuyyXpYnYiHHDw2EP7ubKj6+fKIG1P9tlxMcGMQ==",
"cpu": [
"arm64"
],
@@ -353,9 +353,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.0.tgz",
"integrity": "sha512-BEDIJ6ReGAi+tLTS/RzxIw621yo1UUUiVNTzPGV2didyiJCr1chIGbES+39d/wiFQM43Xs3CBZLNzp+jKkv0/w==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.7.tgz",
"integrity": "sha512-kHFURhfhJRqw4k1auseqQgOzAHB4oYpyzLCX3TCR3uTxqRQ7gFxxlO0TnIcwNRqLcGb9GmWxWWoR8k1CdCXrMw==",
"cpu": [
"x64"
],
@@ -365,9 +365,9 @@
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.0.tgz",
"integrity": "sha512-7K2kbWbShuifQF/6L/tWSz2DhKfIreHKlBdVOuBTYYOReQMHn5cJxgwuFgQHqMubZ9zcagtHpmo+Wtqd034OKQ==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.7.tgz",
"integrity": "sha512-zWfZ557v2Y+93dVrmqqnbiLeTOb0ptunAG0zGjyE+3oyi8j/4+bL56Fdv94k+dfNF4KrcqcULEcZhKik3/FQ9w==",
"cpu": [
"x64"
],
@@ -4869,33 +4869,33 @@
}
},
"@lancedb/vectordb-darwin-arm64": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.0.tgz",
"integrity": "sha512-Fg+k/cSnqmNQlSWyDp0PpaAJ67kAISfZAD+zZ3mcE8/3ml2I/wM/GVjPy2zeiQX9aR93lG1mZXFSNTDUc74tWQ==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.7.tgz",
"integrity": "sha512-QsDxcbhrumJg+Cyflpnj8EY+bZojbco5K7VSeKvguqeXUGb62ksyOZuUTCn2sqJaCgy1KZ1qC5U8jBqfgZHc2w==",
"optional": true
},
"@lancedb/vectordb-darwin-x64": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.0.tgz",
"integrity": "sha512-CXp4b/brMbnBPZuGzKIOskd9uD90R73rWubaJ0du/Kt6fcyQX1dM1wEhWTLxI6eKf8IDL/R9QLL2cIahm1J86w==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.7.tgz",
"integrity": "sha512-fgv10kI04UycgpmhJLUcCswgvSdgsGuj65o+W5usmVdxYZiWpoXBBXRkWYMjUX5RNe3mY1Ff6QPBbToR0WkSUA==",
"optional": true
},
"@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.0.tgz",
"integrity": "sha512-1bjaRzYcDsWIRUbO2K/f+ohNmNvCgKcrrOhmiXSHVlYY8kH1LUMFZj+BhqBC0Ea0Stt7/1rsRLMRXRtaeVOEHw==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.7.tgz",
"integrity": "sha512-pvw+31+VKEH3YmS/GLKzEGt/Y2+c/IaE6JL6tIjXi2KY+ZcWuyyXpYnYiHHDw2EP7ubKj6+fKIG1P9tlxMcGMQ==",
"optional": true
},
"@lancedb/vectordb-linux-x64-gnu": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.0.tgz",
"integrity": "sha512-BEDIJ6ReGAi+tLTS/RzxIw621yo1UUUiVNTzPGV2didyiJCr1chIGbES+39d/wiFQM43Xs3CBZLNzp+jKkv0/w==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.7.tgz",
"integrity": "sha512-kHFURhfhJRqw4k1auseqQgOzAHB4oYpyzLCX3TCR3uTxqRQ7gFxxlO0TnIcwNRqLcGb9GmWxWWoR8k1CdCXrMw==",
"optional": true
},
"@lancedb/vectordb-win32-x64-msvc": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.0.tgz",
"integrity": "sha512-7K2kbWbShuifQF/6L/tWSz2DhKfIreHKlBdVOuBTYYOReQMHn5cJxgwuFgQHqMubZ9zcagtHpmo+Wtqd034OKQ==",
"version": "0.3.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.7.tgz",
"integrity": "sha512-zWfZ557v2Y+93dVrmqqnbiLeTOb0ptunAG0zGjyE+3oyi8j/4+bL56Fdv94k+dfNF4KrcqcULEcZhKik3/FQ9w==",
"optional": true
},
"@neon-rs/cli": {

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.3.0",
"version": "0.3.7",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
@@ -81,10 +81,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.3.0",
"@lancedb/vectordb-darwin-x64": "0.3.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.3.0",
"@lancedb/vectordb-linux-x64-gnu": "0.3.0",
"@lancedb/vectordb-win32-x64-msvc": "0.3.0"
"@lancedb/vectordb-darwin-arm64": "0.3.7",
"@lancedb/vectordb-darwin-x64": "0.3.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.3.7",
"@lancedb/vectordb-linux-x64-gnu": "0.3.7",
"@lancedb/vectordb-win32-x64-msvc": "0.3.7"
}
}

View File

@@ -20,7 +20,7 @@ import {
Utf8,
type Vector,
FixedSizeList,
vectorFromArray, type Schema, Table as ArrowTable
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
@@ -77,7 +77,9 @@ function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
// Creates the Arrow Type for a Vector column with dimension `dim`
function newVectorType (dim: number): FixedSizeList<Float32> {
const children = new Field<Float32>('item', new Float32())
// Somewhere we always default to have the elements nullable, so we need to set it to true
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
const children = new Field<Float32>('item', new Float32(), true)
return new FixedSizeList(dim, children)
}
@@ -88,6 +90,13 @@ export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown
return Buffer.from(await writer.toUint8Array())
}
// Converts an Array of records into Arrow IPC stream format
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = await convertToTable(data, embeddings)
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Arrow Table into Arrow IPC format
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
if (embeddings !== undefined) {
@@ -105,6 +114,23 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
return Buffer.from(await writer.toUint8Array())
}
// Converts an Arrow Table into Arrow IPC stream format
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
}
const vectors = await embeddings.embed(source.toArray() as T[])
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Creates an empty Arrow Table
export function createEmptyTable (schema: Schema): ArrowTable {
return new ArrowTable(schema)

View File

@@ -23,7 +23,7 @@ import { Query } from './query'
import { isEmbeddingFunction } from './embedding/embedding_function'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles } = require('../native.js')
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
export { Query }
export type { EmbeddingFunction }
@@ -260,6 +260,27 @@ export interface Table<T = number[]> {
* ```
*/
delete: (filter: string) => Promise<void>
/**
* List the indicies on this table.
*/
listIndices: () => Promise<VectorIndex[]>
/**
* Get statistics about an index.
*/
indexStats: (indexUuid: string) => Promise<IndexStats>
}
export interface VectorIndex {
columns: string[]
name: string
uuid: string
}
export interface IndexStats {
numIndexedRows: number | null
numUnindexedRows: number | null
}
/**
@@ -502,6 +523,14 @@ export class LocalTable<T = number[]> implements Table<T> {
return res.metrics
})
}
async listIndices (): Promise<VectorIndex[]> {
return tableListIndices.call(this._tbl)
}
async indexStats (indexUuid: string): Promise<IndexStats> {
return tableIndexStats.call(this._tbl, indexUuid)
}
}
export interface CleanupStats {

View File

@@ -65,8 +65,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
const mirroredPath = path.join(dir, `${tableName}.lance`)
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
if (err != null) throw err
// there should be two dirs
assert.equal(files.length, 2)
// there should be three dirs
assert.equal(files.length, 3)
assert.isTrue(files[0].isDirectory())
assert.isTrue(files[1].isDirectory())
@@ -76,6 +76,12 @@ describe('LanceDB Mirrored Store Integration test', function () {
assert.isTrue(files[0].name.endsWith('.txn'))
})
fs.readdir(path.join(mirroredPath, '_versions'), { withFileTypes: true }, (err, files) => {
if (err != null) throw err
assert.equal(files.length, 1)
assert.isTrue(files[0].name.endsWith('.manifest'))
})
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
if (err != null) throw err
assert.equal(files.length, 1)
@@ -88,8 +94,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
if (err != null) throw err
// there should be two dirs
assert.equal(files.length, 3)
// there should be four dirs
assert.equal(files.length, 4)
assert.isTrue(files[0].isDirectory())
assert.isTrue(files[1].isDirectory())
assert.isTrue(files[2].isDirectory())
@@ -128,12 +134,13 @@ describe('LanceDB Mirrored Store Integration test', function () {
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
if (err != null) throw err
// there should be two dirs
assert.equal(files.length, 4)
// there should be five dirs
assert.equal(files.length, 5)
assert.isTrue(files[0].isDirectory())
assert.isTrue(files[1].isDirectory())
assert.isTrue(files[2].isDirectory())
assert.isTrue(files[3].isDirectory())
assert.isTrue(files[4].isDirectory())
// Three TXs now
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {

View File

@@ -63,6 +63,9 @@ export class HttpLancedbClient {
}
).catch((err) => {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
if (response.status !== 200) {
@@ -86,13 +89,17 @@ export class HttpLancedbClient {
{
headers: {
'Content-Type': 'application/json',
'x-api-key': this._apiKey()
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
params,
timeout: 10000
}
).catch((err) => {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
if (response.status !== 200) {
@@ -108,13 +115,18 @@ export class HttpLancedbClient {
/**
* Sent POST request.
*/
public async post (path: string, data?: any, params?: Record<string, string | number>): Promise<AxiosResponse> {
public async post (
path: string,
data?: any,
params?: Record<string, string | number>,
content?: string | undefined
): Promise<AxiosResponse> {
const response = await axios.post(
`${this._url}${path}`,
data,
{
headers: {
'Content-Type': 'application/json',
'Content-Type': content ?? 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
@@ -123,6 +135,9 @@ export class HttpLancedbClient {
}
).catch((err) => {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
if (response.status !== 200) {

View File

@@ -14,12 +14,16 @@
import {
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
type ConnectionOptions, type CreateTableOptions, type WriteOptions
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
type WriteOptions,
type IndexStats
} from '../index'
import { Query } from '../query'
import { Vector } from 'apache-arrow'
import { Vector, Table as ArrowTable } from 'apache-arrow'
import { HttpLancedbClient } from './client'
import { isEmbeddingFunction } from '../embedding/embedding_function'
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
/**
* Remote connection.
@@ -66,8 +70,60 @@ export class RemoteConnection implements Connection {
}
}
async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
throw new Error('Not implemented')
async createTable<T> (nameOrOpts: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
// Logic copied from LocatlConnection, refactor these to a base class + connectionImpl pattern
let schema
let embeddings: undefined | EmbeddingFunction<T>
let tableName: string
if (typeof nameOrOpts === 'string') {
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
embeddings = optsOrEmbedding
}
tableName = nameOrOpts
} else {
schema = nameOrOpts.schema
embeddings = nameOrOpts.embeddingFunction
tableName = nameOrOpts.name
}
let buffer: Buffer
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
if (data instanceof ArrowTable) {
return data.data.length === 0
}
return data.length === 0
}
if ((data === undefined) || isEmpty(data)) {
if (schema === undefined) {
throw new Error('Either data or schema needs to defined')
}
buffer = await fromTableToStreamBuffer(createEmptyTable(schema))
} else if (data instanceof ArrowTable) {
buffer = await fromTableToStreamBuffer(data, embeddings)
} else {
// data is Array<Record<...>>
buffer = await fromRecordsToStreamBuffer(data, embeddings)
}
const res = await this._client.post(
`/v1/table/${tableName}/create/`,
buffer,
undefined,
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
if (embeddings === undefined) {
return new RemoteTable(this._client, tableName)
} else {
return new RemoteTable(this._client, tableName, embeddings)
}
}
async dropTable (name: string): Promise<void> {
@@ -141,11 +197,39 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async add (data: Array<Record<string, unknown>>): Promise<number> {
throw new Error('Not implemented')
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/insert/`,
buffer,
{
mode: 'append'
},
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
return data.length
}
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
throw new Error('Not implemented')
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/insert/`,
buffer,
{
mode: 'overwrite'
},
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
return data.length
}
async createIndex (indexParams: VectorIndexParams): Promise<any> {
@@ -153,10 +237,28 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async countRows (): Promise<number> {
throw new Error('Not implemented')
const result = await this._client.post(`/v1/table/${this._name}/describe/`)
return result.data?.stats?.num_rows
}
async delete (filter: string): Promise<void> {
throw new Error('Not implemented')
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
}
async listIndices (): Promise<VectorIndex[]> {
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
return results.data.indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
uuid: index.index_uuid
}))
}
async indexStats (indexUuid: string): Promise<IndexStats> {
const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
return {
numIndexedRows: results.data.num_indexed_rows,
numUnindexedRows: results.data.num_unindexed_rows
}
}
}

View File

@@ -328,6 +328,24 @@ describe('LanceDB client', function () {
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
})
it('should be able to list index and stats', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
const indices = await table.listIndices()
expect(indices).to.have.lengthOf(1)
expect(indices[0].name).to.equal('vector_idx')
expect(indices[0].uuid).to.not.be.equal(undefined)
expect(indices[0].columns).to.have.lengthOf(1)
expect(indices[0].columns[0]).to.equal('vector')
const stats = await table.indexStats(indices[0].uuid)
expect(stats.numIndexedRows).to.equal(300)
expect(stats.numUnindexedRows).to.equal(0)
}).timeout(50_000)
})
describe('when using a custom embedding function', function () {
@@ -378,6 +396,40 @@ describe('LanceDB client', function () {
})
})
describe('Remote LanceDB client', function () {
describe('when the server is not reachable', function () {
it('produces a network error', async function () {
const con = await lancedb.connect({
uri: 'db://test-1234',
region: 'asdfasfasfdf',
apiKey: 'some-api-key'
})
// GET
try {
await con.tableNames()
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
}
// POST
try {
await con.createTable({ name: 'vectors', schema: new Schema([]) })
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
}
// Search
const table = await con.openTable('vectors')
try {
await table.search([0.1, 0.3]).execute()
} catch (err) {
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
}
})
})
})
describe('Query object', function () {
it('sets custom parameters', async function () {
const query = new Query([0.1, 0.3])

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.1
current_version = 0.3.3
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

1
python/LICENSE Symbolic link
View File

@@ -0,0 +1 @@
../LICENSE

View File

@@ -16,10 +16,13 @@ from typing import Optional
__version__ = importlib.metadata.version("lancedb")
from .db import URI, DBConnection, LanceDBConnection
from .common import URI
from .db import DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector
from .utils import sentry_log
from .schema import vector # noqa: F401
from .utils import sentry_log # noqa: F401
import requests
def connect(
@@ -69,3 +72,26 @@ def connect(
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region, host_override)
return LanceDBConnection(uri)
def drop_database(uri: URI, api_key: str, region: str = "us-west-2"):
"""Drop a LanceDB database.
Parameters
----------
uri: str or Path
The uri of the database.
Examples
--------
>>> lancedb.drop_database(uri="db://", api_key="sk_...", region="...")
"""
if isinstance(uri, str) and uri.startswith("db://"):
control_plane_url = f"control-plane.{region}.api.lancedb.com"
requests.delete(
f"https://{control_plane_url}/api/v1/auth/token/delete",
json={"api_key": api_key}
)
return LanceDBConnection(uri).drop_database()

View File

@@ -1,4 +1,6 @@
import os
import time
from typing import Any
import numpy as np
import pytest
@@ -38,3 +40,26 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 10
class RateLimitedAPI:
rate_limit = 0.1 # 1 request per 0.1 second
last_request_time = 0
@staticmethod
def make_request():
current_time = time.time()
if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit:
raise Exception("Rate limit exceeded. Please try again later.")
# Simulate a successful request
RateLimitedAPI.last_request_time = current_time
return "Request successful"
@registry.register("test-rate-limited")
class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction):
def generate_embeddings(self, texts):
RateLimitedAPI.make_request()
return [self._compute_one_embedding(row) for row in texts]

View File

@@ -84,7 +84,9 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
context windows that don't cross document boundaries. In this case, we can
pass ``document_id`` as the group by.
>>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_pandas()
>>> (contextualize(data)
... .window(4).stride(2).text_col('token').groupby('document_id')
... .to_pandas())
token document_id
0 The quick brown fox 1
2 brown fox jumped over 1
@@ -92,18 +94,24 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
6 the lazy dog 1
9 I love sandwiches 2
``min_window_size`` determines the minimum size of the context windows that are generated
This can be used to trim the last few context windows which have size less than
``min_window_size``. By default context windows of size 1 are skipped.
``min_window_size`` determines the minimum size of the context windows
that are generated.This can be used to trim the last few context windows
which have size less than ``min_window_size``.
By default context windows of size 1 are skipped.
>>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_pandas()
>>> (contextualize(data)
... .window(6).stride(3).text_col('token').groupby('document_id')
... .to_pandas())
token document_id
0 The quick brown fox jumped over 1
3 fox jumped over the lazy dog 1
6 the lazy dog 1
9 I love sandwiches 2
>>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_pandas()
>>> (contextualize(data)
... .window(6).stride(3).min_window_size(4).text_col('token')
... .groupby('document_id')
... .to_pandas())
token document_id
0 The quick brown fox jumped over 1
3 fox jumped over the lazy dog 1
@@ -113,7 +121,9 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
class Contextualizer:
"""Create context windows from a DataFrame. See [lancedb.context.contextualize][]."""
"""Create context windows from a DataFrame.
See [lancedb.context.contextualize][].
"""
def __init__(self, raw_df):
self._text_col = None
@@ -183,7 +193,7 @@ class Contextualizer:
deprecated_in="0.3.1",
removed_in="0.4.0",
current_version=__version__,
details="Use the bar function instead",
details="Use to_pandas() instead",
)
def to_df(self) -> "pd.DataFrame":
return self.to_pandas()

View File

@@ -14,26 +14,39 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from abc import abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pyarrow as pa
from overrides import EnforceOverrides, override
from pyarrow import fs
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
if TYPE_CHECKING:
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
class DBConnection(ABC):
class DBConnection(EnforceOverrides):
"""An active LanceDB connection interface."""
@abstractmethod
def table_names(self) -> list[str]:
"""List all table names in the database."""
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]:
"""List all table in this database
Parameters
----------
page_token: str, optional
The token to use for pagination. If not present, start from the beginning.
limit: int, default 10
The size of the page to return.
"""
pass
@abstractmethod
@@ -45,6 +58,7 @@ class DBConnection(ABC):
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
@@ -52,12 +66,24 @@ class DBConnection(ABC):
----------
name: str
The name of the table.
data: list, tuple, dict, pd.DataFrame; optional
The data to initialize the table. User must provide at least one of `data` or `schema`.
schema: pyarrow.Schema or LanceModel; optional
The schema of the table.
data: The data to initialize the table, *optional*
User must provide at least one of `data` or `schema`.
Acceptable types are:
- dict or list-of-dict
- pandas.DataFrame
- pyarrow.Table or pyarrow.RecordBatch
schema: The schema of the table, *optional*
Acceptable types are:
- pyarrow.Schema
- [LanceModel][lancedb.pydantic.LanceModel]
mode: str; default "create"
The mode to use when creating the table. Can be either "create" or "overwrite".
The mode to use when creating the table.
Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
on_bad_vectors: str, default "error"
@@ -150,7 +176,8 @@ class DBConnection(ABC):
... for i in range(5):
... yield pa.RecordBatch.from_arrays(
... [
... pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
... pa.array([[3.1, 4.1], [5.9, 26.5]],
... pa.list_(pa.float32(), 2)),
... pa.array(["foo", "bar"]),
... pa.array([10.0, 20.0]),
... ],
@@ -249,12 +276,15 @@ class LanceDBConnection(DBConnection):
def uri(self) -> str:
return self._uri
def table_names(self) -> list[str]:
"""Get the names of all tables in the database.
@override
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]:
"""Get the names of all tables in the database. The names are sorted.
Returns
-------
list of str
Iterator of str.
A list of table names.
"""
try:
@@ -274,6 +304,7 @@ class LanceDBConnection(DBConnection):
for file_info in paths
if file_info.extension == "lance"
]
tables.sort()
return tables
def __len__(self) -> int:
@@ -282,6 +313,7 @@ class LanceDBConnection(DBConnection):
def __contains__(self, name: str) -> bool:
return name in self.table_names()
@override
def create_table(
self,
name: str,
@@ -313,6 +345,7 @@ class LanceDBConnection(DBConnection):
)
return tbl
@override
def open_table(self, name: str) -> LanceTable:
"""Open a table in the database.
@@ -327,6 +360,7 @@ class LanceDBConnection(DBConnection):
"""
return LanceTable.open(self, name)
@override
def drop_table(self, name: str, ignore_missing: bool = False):
"""Drop a table from the database.
@@ -345,6 +379,7 @@ class LanceDBConnection(DBConnection):
if not ignore_missing:
raise
@override
def drop_database(self):
filesystem, path = fs_from_uri(self.uri)
filesystem.delete_dir(path)

View File

@@ -11,15 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .cohere import CohereEmbeddingFunction
from .functions import (
EmbeddingFunction,
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
OpenAIEmbeddings,
OpenClipEmbeddings,
SentenceTransformerEmbeddings,
TextEmbeddingFunction,
)
from .instructor import InstructorEmbeddingFunction
from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry
from .sentence_transformers import SentenceTransformerEmbeddings
from .utils import with_embeddings

View File

@@ -0,0 +1,181 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from abc import ABC, abstractmethod
from typing import List, Union
import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr
from .utils import TEXT, retry_with_exponential_backoff
class EmbeddingFunction(BaseModel, ABC):
"""
An ABC for embedding functions.
All concrete embedding functions must implement the following:
1. compute_query_embeddings() which takes a query and returns a list of embeddings
2. get_source_embeddings() which returns a list of embeddings for the source column
For text data, the two will be the same. For multi-modal data, the source column
might be images and the vector column might be text.
3. ndims method which returns the number of dimensions of the vector column
"""
__slots__ = ("__weakref__",) # pydantic 1.x compatibility
max_retries: int = (
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
)
_ndims: int = PrivateAttr()
@classmethod
def create(cls, **kwargs):
"""
Create an instance of the embedding function
"""
return cls(**kwargs)
@abstractmethod
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query
"""
pass
@abstractmethod
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database
"""
pass
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query with retries
"""
return retry_with_exponential_backoff(
self.compute_query_embeddings, max_retries=self.max_retries
)(
*args,
**kwargs,
)
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database with retries
"""
return retry_with_exponential_backoff(
self.compute_source_embeddings, max_retries=self.max_retries
)(*args, **kwargs)
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(texts, str):
texts = [texts]
elif isinstance(texts, pa.Array):
texts = texts.to_pylist()
elif isinstance(texts, pa.ChunkedArray):
texts = texts.combine_chunks().to_pylist()
return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION
if PYDANTIC_VERSION.major < 2:
return dict(self)
return self.model_dump()
@abstractmethod
def ndims(self):
"""
Return the dimensions of the vector column
"""
pass
def SourceField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the source column for this embedding function
"""
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
def VectorField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the target vector column for this embedding function
"""
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
def __eq__(self, __value: object) -> bool:
if not hasattr(__value, "__dict__"):
return False
return vars(self) == vars(__value)
def __hash__(self) -> int:
return hash(frozenset(vars(self).items()))
class EmbeddingFunctionConfig(BaseModel):
"""
This model encapsulates the configuration for a embedding function
in a lancedb table. It holds the embedding function, the source column,
and the vector column
"""
vector_column: str
source_column: str
function: EmbeddingFunction
class TextEmbeddingFunction(EmbeddingFunction):
"""
A callable ABC for embedding functions that take text as input
"""
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, *args, **kwargs)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
pass

View File

@@ -16,7 +16,8 @@ from typing import ClassVar, List, Union
import numpy as np
from .functions import TextEmbeddingFunction, register
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
@@ -30,7 +31,8 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for a list of available models.
The name of the model to use. See the Cohere documentation for
a list of available models.
Examples
--------
@@ -38,7 +40,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")
cohere = EmbeddingFunctionRegistry
.get_instance()
.get("cohere")
.create(name="embed-multilingual-v2.0")
class TextModel(LanceModel):
text: str = cohere.SourceField()

View File

@@ -1,578 +0,0 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import importlib
import io
import json
import os
import socket
import urllib.error
import urllib.parse as urlparse
import urllib.request
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing.
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.
Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()
>>> @registry.register("my-embedding-function")
... class MyEmbeddingFunction(EmbeddingFunction):
... def ndims(self) -> int:
... return 128
...
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
... return self.compute_source_embeddings(query, *args, **kwargs)
...
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
...
>>> registry.get("my-embedding-function")
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
"""
@classmethod
def get_instance(cls):
return __REGISTRY__
def __init__(self):
self._functions = {}
def register(self, alias: str = None):
"""
This creates a decorator that can be used to register
an EmbeddingFunction.
Parameters
----------
alias : Optional[str]
a human friendly name for the embedding function. If not
provided, the class name will be used.
"""
# This is a decorator for a class that inherits from BaseModel
# It adds the class to the registry
def decorator(cls):
if not issubclass(cls, EmbeddingFunction):
raise TypeError("Must be a subclass of EmbeddingFunction")
if cls.__name__ in self._functions:
raise KeyError(f"{cls.__name__} was already registered")
key = alias or cls.__name__
self._functions[key] = cls
cls.__embedding_function_registry_alias__ = alias
return cls
return decorator
def reset(self):
"""
Reset the registry to its initial state
"""
self._functions = {}
def get(self, name: str):
"""
Fetch an embedding function class by name
Parameters
----------
name : str
The name of the embedding function to fetch
Either the alias or the class name if no alias was provided
during registration
"""
return self._functions[name]
def parse_functions(
self, metadata: Optional[Dict[bytes, bytes]]
) -> Dict[str, "EmbeddingFunctionConfig"]:
"""
Parse the metadata from an arrow table and
return a mapping of the vector column to the
embedding function and source column
Parameters
----------
metadata : Optional[Dict[bytes, bytes]]
The metadata from an arrow table. Note that
the keys and values are bytes (pyarrow api)
Returns
-------
functions : dict
A mapping of vector column name to embedding function.
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
return {
obj["vector_column"]: EmbeddingFunctionConfig(
vector_column=obj["vector_column"],
source_column=obj["source_column"],
function=self.get(obj["name"])(**obj["model"]),
)
for obj in raw_list
}
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
"""
Convert the given embedding function and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
func = conf.function
name = getattr(
func, "__embedding_function_registry_alias__", func.__class__.__name__
)
json_data = func.safe_model_dump()
return {
"name": name,
"model": json_data,
"source_column": conf.source_column,
"vector_column": conf.vector_column,
}
def get_table_metadata(self, func_list):
"""
Convert a list of embedding functions and source / vector configs
into a config dictionary that can be serialized into arrow metadata
"""
if func_list is None or len(func_list) == 0:
return None
json_data = [self.function_to_metadata(func) for func in func_list]
# Note that metadata dictionary values must be bytes
# so we need to json dump then utf8 encode
metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata}
# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
class EmbeddingFunction(BaseModel, ABC):
"""
An ABC for embedding functions.
All concrete embedding functions must implement the following:
1. compute_query_embeddings() which takes a query and returns a list of embeddings
2. get_source_embeddings() which returns a list of embeddings for the source column
For text data, the two will be the same. For multi-modal data, the source column
might be images and the vector column might be text.
3. ndims method which returns the number of dimensions of the vector column
"""
_ndims: int = PrivateAttr()
@classmethod
def create(cls, **kwargs):
"""
Create an instance of the embedding function
"""
return cls(**kwargs)
@abstractmethod
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query
"""
pass
@abstractmethod
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database
"""
pass
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(texts, str):
texts = [texts]
elif isinstance(texts, pa.Array):
texts = texts.to_pylist()
elif isinstance(texts, pa.ChunkedArray):
texts = texts.combine_chunks().to_pylist()
return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION
if PYDANTIC_VERSION.major < 2:
return dict(self)
return self.model_dump()
@abstractmethod
def ndims(self):
"""
Return the dimensions of the vector column
"""
pass
def SourceField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the source column for this embedding function
"""
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
def VectorField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the target vector column for this embedding function
"""
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
class EmbeddingFunctionConfig(BaseModel):
"""
This model encapsulates the configuration for a embedding function
in a lancedb table. It holds the embedding function, the source column,
and the vector column
"""
vector_column: str
source_column: str
function: EmbeddingFunction
class TextEmbeddingFunction(EmbeddingFunction):
"""
A callable ABC for embedding functions that take text as input
"""
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, *args, **kwargs)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
pass
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the sentence-transformers library
https://huggingface.co/sentence-transformers
"""
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
@property
def embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
"""
return self.__class__.get_embedding_model(self.name, self.device)
def ndims(self):
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
return self.embedding_model.encode(
list(texts),
convert_to_numpy=True,
normalize_embeddings=self.normalize,
).tolist()
@classmethod
@cached(cache={})
def get_embedding_model(cls, name, device):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
Parameters
----------
name : str
The name of the model to load
device : str
The device to load the model on
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = cls.safe_import(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(name, device=device)
@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the OpenAI API
https://platform.openai.com/docs/guides/embeddings
"""
name: str = "text-embedding-ada-002"
def ndims(self):
# TODO don't hardcode this
return 1536
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
openai = self.safe_import("openai")
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
return [v["embedding"] for v in rs]
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the OpenClip API
For multi-modal text-to-image search
https://github.com/mlfoundations/open_clip
"""
name: str = "ViT-B-32"
pretrained: str = "laion2b_s34b_b79k"
device: str = "cpu"
batch_size: int = 64
normalize: bool = True
_model = PrivateAttr()
_preprocess = PrivateAttr()
_tokenizer = PrivateAttr()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
model.to(self.device)
self._model, self._preprocess = model, preprocess
self._tokenizer = open_clip.get_tokenizer(self.name)
self._ndims = None
def ndims(self):
if self._ndims is None:
self._ndims = self.generate_text_embeddings("foo").shape[0]
return self._ndims
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
with torch.no_grad():
text_features = self._model.encode_text(text.to(self.device))
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given images
"""
images = self.sanitize_input(images)
embeddings = []
for i in range(0, len(images), self.batch_size):
j = min(i + self.batch_size, len(images))
batch = images[i:j]
embeddings.extend(self._parallel_get(batch))
return embeddings
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
"""
Issue concurrent requests to retrieve the image data
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)
for image in images
]
return [future.result() for future in tqdm(futures)]
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
with torch.no_grad():
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
"""
encode a single image tensor and optionally normalize the output
"""
image_features = self._model.encode_image(image_tensor.to(self.device))
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()
def url_retrieve(url: str):
"""
Parameters
----------
url: str
URL to download from
"""
try:
with urllib.request.urlopen(url) as conn:
return conn.read()
except (socket.gaierror, urllib.error.URLError) as err:
raise ConnectionError("could not download {} due to {}".format(url, err))

View File

@@ -0,0 +1,137 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, weak_lru
@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
variety of tasks, including text classification, sentence similarity, and document retrieval.
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
"Represent the `domain` `text_type` for `task_objective`":
* domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
* text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
* task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
"Represent the document for retreival"
Parameters
----------
name: str
The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
The default model is hkunlp/instructor-base
batch_size: int, default 32
The batch size to use when generating embeddings
device: str, default "cpu"
The device to use when generating embeddings
show_progress_bar: bool, default True
Whether to show a progress bar when generating embeddings
normalize_embeddings: bool, default True
Whether to normalize the embeddings
quantize: bool, default False
Whether to quantize the model
source_instruction: str, default "represent the docuement for retreival"
The instruction for the source column
query_instruction: str, default "represent the document for retreiving the most similar documents"
The instruction for the query
Examples
--------
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction
instructor = get_registry().get("instructor").create(
source_instruction="represent the docuement for retreival",
query_instruction="represent the document for retreiving the most similar documents"
)
class Schema(LanceModel):
vector: Vector(instructor.ndims()) = instructor.VectorField()
text: str = instructor.SourceField()
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=Schema, mode="overwrite")
texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
tbl.add(texts)
"""
name: str = "hkunlp/instructor-base"
batch_size: int = 32
device: str = "cpu"
show_progress_bar: bool = True
normalize_embeddings: bool = True
quantize: bool = False
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
source_instruction: str = "represent the document for retrieval"
query_instruction: str = (
"represent the document for retrieving the most similar documents"
)
@weak_lru(maxsize=1)
def ndims(self):
model = self.get_model()
return model.encode("foo").shape[0]
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.generate_embeddings([[self.query_instruction, query]])
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
texts_formatted = []
for text in texts:
texts_formatted.append([self.source_instruction, text])
return self.generate_embeddings(texts_formatted)
def generate_embeddings(self, texts: List) -> List:
model = self.get_model()
res = model.encode(
texts,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar,
normalize_embeddings=self.normalize_embeddings,
).tolist()
return res
@weak_lru(maxsize=1)
def get_model(self):
instructor_embedding = self.safe_import(
"InstructorEmbedding", "InstructorEmbedding"
)
torch = self.safe_import("torch", "torch")
model = instructor_embedding.INSTRUCTOR(self.name)
if self.quantize:
if (
"qnnpack" in torch.backends.quantized.supported_engines
): # fix for https://github.com/pytorch/pytorch/issues/29327
torch.backends.quantized.engine = "qnnpack"
model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
return model

View File

@@ -0,0 +1,175 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import io
import os
import urllib.parse as urlparse
from typing import List, Union
import numpy as np
import pyarrow as pa
from pydantic import PrivateAttr
from tqdm import tqdm
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the OpenClip API
For multi-modal text-to-image search
https://github.com/mlfoundations/open_clip
"""
name: str = "ViT-B-32"
pretrained: str = "laion2b_s34b_b79k"
device: str = "cpu"
batch_size: int = 64
normalize: bool = True
_model = PrivateAttr()
_preprocess = PrivateAttr()
_tokenizer = PrivateAttr()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
model.to(self.device)
self._model, self._preprocess = model, preprocess
self._tokenizer = open_clip.get_tokenizer(self.name)
self._ndims = None
def ndims(self):
if self._ndims is None:
self._ndims = self.generate_text_embeddings("foo").shape[0]
return self._ndims
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
with torch.no_grad():
text_features = self._model.encode_text(text.to(self.device))
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given images
"""
images = self.sanitize_input(images)
embeddings = []
for i in range(0, len(images), self.batch_size):
j = min(i + self.batch_size, len(images))
batch = images[i:j]
embeddings.extend(self._parallel_get(batch))
return embeddings
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
"""
Issue concurrent requests to retrieve the image data
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)
for image in images
]
return [future.result() for future in tqdm(futures)]
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
with torch.no_grad():
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
"""
encode a single image tensor and optionally normalize the output
"""
image_features = self._model.encode_image(image_tensor.to(self.device))
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()

View File

@@ -0,0 +1,49 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import numpy as np
from .base import TextEmbeddingFunction
from .registry import register
@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the OpenAI API
https://platform.openai.com/docs/guides/embeddings
"""
name: str = "text-embedding-ada-002"
def ndims(self):
# TODO don't hardcode this
return 1536
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
openai = self.safe_import("openai")
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
return [v["embedding"] for v in rs]

View File

@@ -0,0 +1,186 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, Optional
from .base import EmbeddingFunction, EmbeddingFunctionConfig
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing.
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.
NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()
>>> @registry.register("my-embedding-function")
... class MyEmbeddingFunction(EmbeddingFunction):
... def ndims(self) -> int:
... return 128
...
... def compute_query_embeddings(self, query: str, *args, **kwargs):
... return self.compute_source_embeddings(query, *args, **kwargs)
...
... def compute_source_embeddings(self, texts, *args, **kwargs):
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
...
>>> registry.get("my-embedding-function")
<class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
"""
@classmethod
def get_instance(cls):
return __REGISTRY__
def __init__(self):
self._functions = {}
def register(self, alias: str = None):
"""
This creates a decorator that can be used to register
an EmbeddingFunction.
Parameters
----------
alias : Optional[str]
a human friendly name for the embedding function. If not
provided, the class name will be used.
"""
# This is a decorator for a class that inherits from BaseModel
# It adds the class to the registry
def decorator(cls):
if not issubclass(cls, EmbeddingFunction):
raise TypeError("Must be a subclass of EmbeddingFunction")
if cls.__name__ in self._functions:
raise KeyError(f"{cls.__name__} was already registered")
key = alias or cls.__name__
self._functions[key] = cls
cls.__embedding_function_registry_alias__ = alias
return cls
return decorator
def reset(self):
"""
Reset the registry to its initial state
"""
self._functions = {}
def get(self, name: str):
"""
Fetch an embedding function class by name
Parameters
----------
name : str
The name of the embedding function to fetch
Either the alias or the class name if no alias was provided
during registration
"""
return self._functions[name]
def parse_functions(
self, metadata: Optional[Dict[bytes, bytes]]
) -> Dict[str, "EmbeddingFunctionConfig"]:
"""
Parse the metadata from an arrow table and
return a mapping of the vector column to the
embedding function and source column
Parameters
----------
metadata : Optional[Dict[bytes, bytes]]
The metadata from an arrow table. Note that
the keys and values are bytes (pyarrow api)
Returns
-------
functions : dict
A mapping of vector column name to embedding function.
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
return {
obj["vector_column"]: EmbeddingFunctionConfig(
vector_column=obj["vector_column"],
source_column=obj["source_column"],
function=self.get(obj["name"])(**obj["model"]),
)
for obj in raw_list
}
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
"""
Convert the given embedding function and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
func = conf.function
name = getattr(
func, "__embedding_function_registry_alias__", func.__class__.__name__
)
json_data = func.safe_model_dump()
return {
"name": name,
"model": json_data,
"source_column": conf.source_column,
"vector_column": conf.vector_column,
}
def get_table_metadata(self, func_list):
"""
Convert a list of embedding functions and source / vector configs
into a config dictionary that can be serialized into arrow metadata
"""
if func_list is None or len(func_list) == 0:
return None
json_data = [self.function_to_metadata(func) for func in func_list]
# Note that metadata dictionary values must be bytes
# so we need to json dump then utf8 encode
metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata}
# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
def get_registry():
"""
Utility function to get the global instance of the registry
Returns
-------
EmbeddingFunctionRegistry
The global registry instance
Examples
--------
from lancedb.embeddings import get_registry
registry = get_registry()
openai = registry.get("openai").create()
"""
return __REGISTRY__.get_instance()

View File

@@ -0,0 +1,89 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import numpy as np
from cachetools import cached
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the sentence-transformers library
https://huggingface.co/sentence-transformers
"""
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
@property
def embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
"""
return self.get_embedding_model()
def ndims(self):
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
return self.embedding_model.encode(
list(texts),
convert_to_numpy=True,
normalize_embeddings=self.normalize,
).tolist()
@weak_lru(maxsize=1)
def get_embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
Parameters
----------
name : str
The name of the model to load
device : str
The device to load the model on
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = self.safe_import(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(self.name, device=self.device)

View File

@@ -11,9 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
import random
import socket
import sys
from typing import Callable, Union
import time
import urllib.error
import weakref
from typing import Callable, List, Union
import numpy as np
import pyarrow as pa
@@ -24,7 +30,12 @@ from ..util import safe_import_pandas
from ..utils.general import LOGGER
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
def with_embeddings(
@@ -155,6 +166,113 @@ class FunctionWrapper:
yield from _chunker(arr)
def weak_lru(maxsize=128):
"""
LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
is bounded.
Parameters
----------
maxsize : int, default 128
The maximum number of objects to cache.
Returns
-------
Callable
A decorator that can be applied to a method.
Examples
--------
>>> class Foo:
... @weak_lru()
... def bar(self, x):
... return x
>>> foo = Foo()
>>> foo.bar(1)
1
>>> foo.bar(2)
2
>>> foo.bar(1)
1
"""
def wrapper(func):
@functools.lru_cache(maxsize)
def _func(_self, *args, **kwargs):
return func(_self(), *args, **kwargs)
@functools.wraps(func)
def inner(self, *args, **kwargs):
return _func(weakref.ref(self), *args, **kwargs)
return inner
return wrapper
def retry_with_exponential_backoff(
func,
initial_delay: float = 1,
exponential_base: float = 2,
jitter: bool = True,
max_retries: int = 7,
# errors: tuple = (),
):
"""Retry a function with exponential backoff.
Args:
func (function): The function to be retried.
initial_delay (float): Initial delay in seconds (default is 1).
exponential_base (float): The base for exponential backoff (default is 2).
jitter (bool): Whether to add jitter to the delay (default is True).
max_retries (int): Maximum number of retries (default is 10).
errors (tuple): Tuple of specific exceptions to retry on (default is (openai.error.RateLimitError,)).
Returns:
function: The decorated function.
"""
def wrapper(*args, **kwargs):
num_retries = 0
delay = initial_delay
# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
return func(*args, **kwargs)
# Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
# We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
# should check the error message to be sure
except Exception as e:
num_retries += 1
if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
delay *= exponential_base * (1 + jitter * random.random())
LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}")
time.sleep(delay)
return wrapper
def url_retrieve(url: str):
"""
Parameters
----------
url: str
URL to download from
"""
try:
with urllib.request.urlopen(url) as conn:
return conn.read()
except (socket.gaierror, urllib.error.URLError) as err:
raise ConnectionError("could not download {} due to {}".format(url, err))
def api_key_not_found_help(provider):
LOGGER.error(f"Could not find API key for {provider}.")
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")

View File

@@ -19,6 +19,7 @@ import inspect
import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
import numpy as np
@@ -159,6 +160,10 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
return pa.bool_()
elif py_type == bytes:
return pa.binary()
elif py_type == date:
return pa.date32()
elif py_type == datetime:
return pa.timestamp("us")
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
)
@@ -322,7 +327,12 @@ class LanceModel(pydantic.BaseModel):
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
src_func = get_extras(field_info, "source_column_for")
if src_func == func:
if src_func is func:
# note we can't use == here since the function is a pydantic
# model so two instances of the same function are ==, so if you
# have multiple vector columns from multiple sources, both will
# be mapped to the same source column
# GH594
configs.append(
EmbeddingFunctionConfig(
source_column=source, vector_column=vec, function=func

View File

@@ -14,7 +14,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
import deprecation
import numpy as np
@@ -23,14 +23,49 @@ import pydantic
from . import __version__
from .common import VECTOR_COLUMN_NAME
from .pydantic import LanceModel
from .util import safe_import_pandas
if TYPE_CHECKING:
from .pydantic import LanceModel
pd = safe_import_pandas()
class Query(pydantic.BaseModel):
"""A Query"""
"""The LanceDB Query
Attributes
----------
vector : List[float]
the vector to search for
filter : Optional[str]
sql filter to refine the query with, optional
prefilter : bool
if True then apply the filter before vector search
k : int
top k results to return
metric : str
the distance metric between a pair of vectors,
can support L2 (default), Cosine and Dot.
[metric definitions][search]
columns : Optional[List[str]]
which columns to return in the results
nprobes : int
The number of probes used - optional
- A higher number makes search more accurate but also slower.
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
refine_factor : Optional[int]
Refine the results by reading extra elements and re-ranking them in memory - optional
- A higher number makes search more accurate but also slower.
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
"""
vector_column: str = VECTOR_COLUMN_NAME
@@ -61,6 +96,10 @@ class Query(pydantic.BaseModel):
class LanceQueryBuilder(ABC):
"""Build LanceDB query based on specific query type:
vector or full text search.
"""
@classmethod
def create(
cls,
@@ -103,7 +142,7 @@ class LanceQueryBuilder(ABC):
if not isinstance(query, (list, np.ndarray)):
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
query = conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
@@ -114,7 +153,7 @@ class LanceQueryBuilder(ABC):
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
query = conf.function.compute_query_embeddings_with_retry(query)[0]
return query, "vector"
else:
return query, "fts"
@@ -133,11 +172,11 @@ class LanceQueryBuilder(ABC):
deprecated_in="0.3.1",
removed_in="0.4.0",
current_version=__version__,
details="Use the bar function instead",
details="Use to_pandas() instead",
)
def to_df(self) -> "pd.DataFrame":
"""
Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.
*Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.*
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
@@ -226,13 +265,20 @@ class LanceQueryBuilder(ABC):
self._columns = columns
return self
def where(self, where) -> LanceQueryBuilder:
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future.
Returns
-------
@@ -240,13 +286,12 @@ class LanceQueryBuilder(ABC):
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
return self
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
A builder for nearest neighbor queries for LanceDB.
Examples
--------
>>> import lancedb
@@ -302,7 +347,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
See discussion in [Querying an ANN Index][../querying-an-ann-index] for
See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
Parameters
@@ -369,14 +414,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
Parameters
----------
where: str
The where clause.
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future. Currently this is only supported
in OSS and can only be used with a table that does not have an ANN
index.
without warning in the future.
Returns
-------
@@ -389,6 +434,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
def __init__(self, table: "lancedb.table.Table", query: str):
super().__init__(table)
self._query = query

View File

@@ -13,7 +13,7 @@
import functools
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union
import aiohttp
import attrs
@@ -151,9 +151,13 @@ class RestfulLanceDBClient:
return await deserialize(resp)
@_check_not_closed
async def list_tables(self):
async def list_tables(
self, limit: int, page_token: Optional[str] = None
) -> Iterable[str]:
"""List all tables in the database."""
json = await self.get("/v1/table/", {})
if page_token is None:
page_token = ""
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"]
@_check_not_closed

View File

@@ -12,14 +12,19 @@
# limitations under the License.
import asyncio
import inspect
import logging
import uuid
from typing import List, Optional
from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse
import pyarrow as pa
from overrides import override
from ..common import DATA
from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
@@ -52,11 +57,31 @@ class RemoteDBConnection(DBConnection):
def __repr__(self) -> str:
return f"RemoveConnect(name={self.db_name})"
def table_names(self) -> List[str]:
"""List the names of all tables in the database."""
result = self._loop.run_until_complete(self._client.list_tables())
return result
@override
def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
"""List the names of all tables in the database.
Parameters
----------
page_token: str
The last token to start the new page.
Returns
-------
An iterator of table names.
"""
while True:
result = self._loop.run_until_complete(
self._client.list_tables(limit, page_token)
)
if len(result) > 0:
page_token = result[len(result) - 1]
else:
break
for item in result:
yield item
@override
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
@@ -71,23 +96,50 @@ class RemoteDBConnection(DBConnection):
"""
from .table import RemoteTable
# TODO: check if table exists
# check if table exists
try:
self._loop.run_until_complete(
self._client.post(f"/v1/table/{name}/describe/")
)
except Exception:
logging.error(
"Table {name} does not exist."
"Please first call db.create_table({name}, data)"
)
return RemoteTable(self, name)
@override
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:
raise NotImplementedError(
"embedding_functions is not supported for remote databases."
"Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature."
)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
if data is not None:
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
data,
schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
else:
if schema is None:
@@ -109,6 +161,7 @@ class RemoteDBConnection(DBConnection):
)
return RemoteTable(self, name)
@override
def drop_table(self, name: str):
"""Drop a table from the database.
@@ -122,3 +175,8 @@ class RemoteDBConnection(DBConnection):
f"/v1/table/{name}/drop/",
)
)
async def close(self):
"""Close the connection to the database."""
self._loop.close()
await self._client.close()

View File

@@ -44,6 +44,14 @@ class RemoteTable(Table):
schema = json_to_schema(resp["schema"])
return schema
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
return resp["version"]
def to_arrow(self) -> pa.Table:
"""Return the table as an Arrow table."""
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
@@ -99,10 +107,12 @@ class RemoteTable(Table):
return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table:
if query.prefilter:
raise NotImplementedError("Cloud support for prefiltering is coming soon")
result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow()
def delete(self, predicate: str):
raise NotImplementedError
"""Delete rows from the table."""
payload = {"predicate": predicate}
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
)

View File

@@ -16,26 +16,29 @@ from __future__ import annotations
import inspect
import os
from abc import ABC, abstractmethod
from datetime import timedelta
from functools import cached_property
from typing import Any, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
import lance
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from lance import LanceDataset
from lance.dataset import CleanupStats, ReaderLike
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionRegistry
from .embeddings.functions import EmbeddingFunctionConfig
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
from .utils.events import register_event
if TYPE_CHECKING:
from datetime import timedelta
from lance.dataset import CleanupStats, ReaderLike
pd = safe_import_pandas()
@@ -87,7 +90,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
for vector_column, conf in functions.items():
func = conf.function
if vector_column not in data.column_names:
col_data = func.compute_source_embeddings(data[conf.source_column])
col_data = func.compute_source_embeddings_with_retry(
data[conf.source_column]
)
if schema is not None:
dtype = schema.field(vector_column).type
else:
@@ -150,13 +155,13 @@ class Table(ABC):
@property
@abstractmethod
def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
this [Table](Table)
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
of this Table
"""
raise NotImplementedError
def to_pandas(self):
def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
Returns
@@ -192,17 +197,18 @@ class Table(ABC):
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int
num_partitions: int, default 256
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int
num_sub_vectors: int, default 96
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
If True, replace the existing index if it exists.
If False, raise an error if duplicate index exists.
- If True, replace the existing index if it exists.
- If False, raise an error if duplicate index exists.
accelerator: str, default None
If set, use the given accelerator to create the index.
Only support "cuda" for now.
@@ -221,8 +227,14 @@ class Table(ABC):
Parameters
----------
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
data: DATA
The data to insert into the table. Acceptable types are:
- dict or list-of-dict
- pandas.DataFrame
- pyarrow.Table or pyarrow.RecordBatch
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
@@ -243,31 +255,70 @@ class Table(ABC):
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
of the given query vector. We currently support [vector search][search]
and [full-text search][experimental-full-text-search].
All query options are defined in [Query][lancedb.query.Query].
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query, vector_column_name="vector")
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
1 test 3000 [0.3, 6.2, 2.6] 23.089996
Parameters
----------
query: str, list, np.ndarray, PIL.Image.Image, default None
The query to search for. If None then
the select/where/limit clauses are applied to filter
query: list/np.ndarray/str/PIL.Image.Image, default None
The targetted vector to search for.
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector"
vector_column_name: str
The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If `query` is a string, then the query type is "vector" if the
*default "vector"*
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", or "auto"
- If "auto" then the query type is inferred from the query;
- If `query` is a list/np.ndarray then the query type is
"vector";
- If `query` is a PIL.Image.Image then either do vector search,
or raise an error if no corresponding embedding function is found.
- If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns selected columns, the vector,
and also the "_distance" column which is the distance between the query
Once executed, the query returns
- selected columns
- the vector
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
raise NotImplementedError
@@ -286,14 +337,20 @@ class Table(ABC):
Parameters
----------
where: str
The SQL where clause to use when deleting rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
The SQL where clause to use when deleting rows.
- For example, 'x = 2' or 'x IN (1, 2, 3)'.
The filter must not be empty, or it will error.
Examples
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
@@ -377,7 +434,8 @@ class LanceTable(Table):
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table = db.create_table("my_table",
... [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version
2
>>> table.to_pandas()
@@ -424,7 +482,8 @@ class LanceTable(Table):
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table = db.create_table("my_table", [
... {"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version
2
>>> table.to_pandas()
@@ -669,14 +728,39 @@ class LanceTable(Table):
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
of the given query vector. We currently support [vector search][search]
and [full-text search][search].
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query, vector_column_name="vector")
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
1 test 3000 [0.3, 6.2, 2.6] 23.089996
Parameters
----------
query: str, list, np.ndarray, a PIL Image or None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
query: list/np.ndarray/str/PIL.Image.Image, default None
The targetted vector to search for.
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/[where][sql]/limit clauses are applied
to filter the table
vector_column_name: str, default "vector"
The name of the vector column to search.
query_type: str, default "auto"
@@ -685,7 +769,7 @@ class LanceTable(Table):
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If the query is a string, then the query type is "vector" if the
If the `query` is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts"
Returns
@@ -719,8 +803,11 @@ class LanceTable(Table):
Examples
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
@@ -739,7 +826,8 @@ class LanceTable(Table):
The data to insert into the table.
At least one of `data` or `schema` must be provided.
schema: pa.Schema or LanceModel, optional
The schema of the table. If not provided, the schema is inferred from the data.
The schema of the table. If not provided,
the schema is inferred from the data.
At least one of `data` or `schema` must be provided.
mode: str, default "create"
The mode to use when writing the data. Valid values are
@@ -810,7 +898,8 @@ class LanceTable(Table):
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
f"Table {name} does not exist. Please first call db.create_table({name}, data)"
f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)"
)
return tbl
@@ -836,8 +925,11 @@ class LanceTable(Table):
Examples
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
@@ -870,12 +962,6 @@ class LanceTable(Table):
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()
if query.prefilter:
for idx in ds.list_indices():
if query.vector_column in idx["fields"]:
raise NotImplementedError(
"Prefiltering for indexed vector column is coming soon."
)
return ds.to_table(
columns=query.columns,
filter=query.filter,
@@ -1017,7 +1103,8 @@ def _sanitize_vector_column(
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_list(data[vector_column_name].type):
# if it's a variable size list array we make sure the dimensions are all the same
# if it's a variable size list array,
# we make sure the dimensions are all the same
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
if has_jagged_ndims:
data = _sanitize_jagged(

View File

@@ -63,7 +63,8 @@ def set_sentry():
"""
if "exc_info" in hint:
exc_type, exc_value, tb = hint["exc_info"]
if "out of memory" in str(exc_value).lower():
ignored_errors = ["out of memory", "no space left on device", "testing"]
if any(error in str(exc_value).lower() for error in ignored_errors):
return None
if is_git_dir():
@@ -97,7 +98,7 @@ def set_sentry():
dsn="https://c63ef8c64e05d1aa1a96513361f3ca2f@o4505950840946688.ingest.sentry.io/4505950933614592",
debug=False,
include_local_variables=False,
traces_sample_rate=1.0,
traces_sample_rate=0.5,
environment="production", # 'dev' or 'production'
before_send=before_send,
ignore_errors=[KeyboardInterrupt, FileNotFoundError, bdb.BdbQuit],

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.3.1"
version = "0.3.3"
dependencies = [
"deprecation",
"pylance==0.8.3",
"pylance==0.8.10",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.1.0",
@@ -14,7 +14,8 @@ dependencies = [
"cachetools",
"pyyaml>=6.0",
"click>=8.1.7",
"requests>=2.31.0"
"requests>=2.31.0",
"overrides>=0.7"
]
description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -52,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip", "cohere"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
[project.scripts]
lancedb = "lancedb.cli.cli:cli"
@@ -64,6 +65,9 @@ build-backend = "setuptools.build_meta"
[tool.isort]
profile = "black"
[tool.ruff]
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [

View File

@@ -129,7 +129,7 @@ def test_ingest_iterator(tmp_path):
[
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
]
],
# TODO: test pydict separately. it is unique column number and names contraint
]
@@ -150,6 +150,21 @@ def test_ingest_iterator(tmp_path):
run_tests(PydanticSchema)
def test_table_names(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
db.create_table("test2", data=data)
db.create_table("test1", data=data)
db.create_table("test3", data=data)
assert db.table_names() == ["test1", "test2", "test3"]
def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
@@ -287,3 +302,27 @@ def test_replace_index(tmp_path):
num_sub_vectors=4,
replace=True,
)
def test_prefilter_with_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
data = [
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
]
sample_key = data[100]["vector"]
table = db.create_table(
"test",
data,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
)
table = (
table.search(sample_key)
.where("price == 500", prefilter=True)
.limit(5)
.to_arrow()
)
assert table.num_rows == 1

View File

@@ -15,13 +15,16 @@ import sys
import lance
import numpy as np
import pyarrow as pa
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
import lancedb
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
from lancedb.pydantic import LanceModel, Vector
def mock_embed_func(input_data):
@@ -83,3 +86,29 @@ def test_embedding_function(tmp_path):
expected = func.compute_query_embeddings("hello world")
assert np.allclose(actual, expected)
def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
return Schema
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("test-rate-limited").create(max_retries=0)
schema = _get_schema_from_model(model)
table = db.create_table("test", schema=schema, mode="overwrite")
table.add([{"text": "hello world"}])
with pytest.raises(Exception):
table.add([{"text": "hello world"}])
assert len(table) == 1
model = registry.get("test-rate-limited").create()
schema = _get_schema_from_model(model)
table = db.create_table("test", schema=schema, mode="overwrite")
table.add([{"text": "hello world"}])
table.add([{"text": "hello world"}])
assert len(table) == 2

View File

@@ -19,7 +19,7 @@ import pytest
import requests
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
# These are integration tests for embedding functions.
@@ -31,12 +31,15 @@ from lancedb.pydantic import LanceModel, Vector
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get(alias).create()
registry = get_registry()
func = registry.get(alias).create(max_retries=0)
func2 = registry.get(alias).create(max_retries=0)
class Words(LanceModel):
text: str = func.SourceField()
text2: str = func2.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
vector2: Vector(func2.ndims()) = func2.VectorField()
table = db.create_table("words", schema=Words)
table.add(
@@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path):
"foo",
"bar",
"baz",
]
],
"text2": [
"to be or not to be",
"that is the question",
"for whether tis nobler",
"in the mind to suffer",
"the slings and arrows",
"of outrageous fortune",
"or to take arms",
],
}
)
)
@@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path):
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
assert actual.text == expected.text
assert actual.text == "hello world"
assert not np.allclose(actual.vector, actual.vector2)
actual = (
table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
)
assert actual.text != "hello world"
assert not np.allclose(actual.vector, actual.vector2)
@pytest.mark.slow
@@ -69,7 +88,7 @@ def test_openclip(tmp_path):
from PIL import Image
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
registry = get_registry()
func = registry.get("open-clip").create()
class Images(LanceModel):
@@ -132,9 +151,9 @@ def test_openclip(tmp_path):
) # also skip if cohere not installed
def test_cohere_embedding_function():
cohere = (
EmbeddingFunctionRegistry.get_instance()
get_registry()
.get("cohere")
.create(name="embed-multilingual-v2.0")
.create(name="embed-multilingual-v2.0", max_retries=0)
)
class TextModel(LanceModel):
@@ -147,3 +166,19 @@ def test_cohere_embedding_function():
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
@pytest.mark.slow
def test_instructor_embedding(tmp_path):
model = get_registry().get("instructor").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()

View File

@@ -14,6 +14,7 @@
import json
import sys
from datetime import date, datetime
from typing import List, Optional
import pyarrow as pa
@@ -40,10 +41,18 @@ def test_pydantic_to_arrow():
li: List[int]
opt: Optional[str] = None
st: StructModel
dt: date
dtt: datetime
# d: dict
m = TestModel(
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
id=1,
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
)
schema = pydantic_to_schema(TestModel)
@@ -62,6 +71,8 @@ def test_pydantic_to_arrow():
),
False,
),
pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False),
]
)
assert schema == expect_schema
@@ -79,10 +90,18 @@ def test_pydantic_to_arrow_py38():
li: List[int]
opt: Optional[str] = None
st: StructModel
dt: date
dtt: datetime
# d: dict
m = TestModel(
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
id=1,
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
)
schema = pydantic_to_schema(TestModel)
@@ -101,6 +120,8 @@ def test_pydantic_to_arrow_py38():
),
False,
),
pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False),
]
)
assert schema == expect_schema

View File

@@ -458,7 +458,8 @@ def test_compact_cleanup(db):
stats = table.compact_files()
assert len(table) == 3
assert table.version == 4
# Compact_files bump 2 versions.
assert table.version == 5
assert stats.fragments_removed > 0
assert stats.fragments_added == 1
@@ -467,7 +468,7 @@ def test_compact_cleanup(db):
stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True)
assert stats.bytes_removed > 0
assert table.version == 4
assert table.version == 5
with pytest.raises(Exception, match="Version 3 no longer exists"):
table.checkout(3)

View File

@@ -1,6 +1,6 @@
[package]
name = "vectordb-node"
version = "0.3.0"
version = "0.3.7"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
edition = "2018"

View File

@@ -70,7 +70,6 @@ fn get_index_params_builder(
.map(|mt| {
let metric_type = mt.unwrap();
index_builder.metric_type(metric_type);
pq_params.metric_type = metric_type;
});
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;

View File

@@ -74,7 +74,7 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
static LOG: OnceCell<()> = OnceCell::new();
LOG.get_or_init(|| env_logger::init());
LOG.get_or_init(env_logger::init);
RUNTIME.get_or_try_init(|| Runtime::new().or_throw(cx))
}
@@ -148,7 +148,7 @@ fn get_aws_creds(
match (secret_key_id, secret_key, temp_token) {
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
StaticCredentialProvider::new(AwsCredential {
key_id: key_id,
key_id,
secret_key: key,
token: optional_token,
}),
@@ -239,6 +239,8 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
cx.export_function("tableIndexStats", JsTable::js_index_stats)?;
cx.export_function(
"tableCreateVectorIndex",
index::vector::table_create_vector_index,

View File

@@ -70,7 +70,7 @@ impl JsTable {
store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
mode: mode,
mode,
..WriteParams::default()
};
@@ -121,7 +121,7 @@ impl JsTable {
let add_result = table.add(batch_reader, Some(params)).await;
deferred.settle_with(&channel, move |mut cx| {
let _added = add_result.or_throw(&mut cx)?;
add_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
});
});
@@ -247,7 +247,7 @@ impl JsTable {
}
rt.spawn(async move {
let stats = table.compact_files(options).await;
let stats = table.compact_files(options, None).await;
deferred.settle_with(&channel, move |mut cx| {
let stats = stats.or_throw(&mut cx)?;
@@ -276,4 +276,91 @@ impl JsTable {
});
Ok(promise)
}
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
let channel = cx.channel();
let table = js_table.table.clone();
rt.spawn(async move {
let indices = table.load_indices().await;
deferred.settle_with(&channel, move |mut cx| {
let indices = indices.or_throw(&mut cx)?;
let output = JsArray::new(&mut cx, indices.len() as u32);
for (i, index) in indices.iter().enumerate() {
let js_index = JsObject::new(&mut cx);
let index_name = cx.string(index.index_name.clone());
js_index.set(&mut cx, "name", index_name)?;
let index_uuid = cx.string(index.index_uuid.clone());
js_index.set(&mut cx, "uuid", index_uuid)?;
let js_index_columns = JsArray::new(&mut cx, index.columns.len() as u32);
for (j, column) in index.columns.iter().enumerate() {
let js_column = cx.string(column.clone());
js_index_columns.set(&mut cx, j as u32, js_column)?;
}
js_index.set(&mut cx, "columns", js_index_columns)?;
output.set(&mut cx, i as u32, js_index)?;
}
Ok(output)
})
});
Ok(promise)
}
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
let channel = cx.channel();
let table = js_table.table.clone();
rt.spawn(async move {
let load_stats = futures::try_join!(
table.count_indexed_rows(&index_uuid),
table.count_unindexed_rows(&index_uuid)
);
deferred.settle_with(&channel, move |mut cx| {
let (indexed_rows, unindexed_rows) = load_stats.or_throw(&mut cx)?;
let output = JsObject::new(&mut cx);
match indexed_rows {
Some(x) => {
let i = cx.number(x as f64);
output.set(&mut cx, "numIndexedRows", i)?;
}
None => {
let null = cx.null();
output.set(&mut cx, "numIndexedRows", null)?;
}
};
match unindexed_rows {
Some(x) => {
let i = cx.number(x as f64);
output.set(&mut cx, "numUnindexedRows", i)?;
}
None => {
let null = cx.null();
output.set(&mut cx, "numUnindexedRows", null)?;
}
};
Ok(output)
})
});
Ok(promise)
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "vectordb"
version = "0.3.0"
version = "0.3.7"
edition = "2021"
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0"

View File

@@ -18,9 +18,9 @@ use arrow::compute::kernels::{aggregate::bool_and, length::length};
use arrow_array::{
cast::AsArray,
types::{ArrowPrimitiveType, Int32Type, Int64Type},
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
};
use arrow_ord::comparison::eq_dyn_scalar;
use arrow_ord::cmp::eq;
use arrow_schema::DataType;
use num_traits::{ToPrimitive, Zero};
@@ -38,7 +38,8 @@ where
}
let dim = len_arr.as_primitive::<T>().value(0);
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
let datum = PrimitiveArray::<T>::new_scalar(dim);
if bool_and(&eq(len_arr.as_primitive::<T>(), &datum)?) != Some(true) {
Ok(None)
} else {
Ok(Some(dim))

View File

@@ -135,7 +135,7 @@ impl Database {
async fn open_path(path: &str) -> Result<Database> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
}
Ok(Self {
uri: path.to_string(),
@@ -161,7 +161,7 @@ impl Database {
///
/// * A [Vec<String>] with all table names.
pub async fn table_names(&self) -> Result<Vec<String>> {
let f = self
let mut f = self
.object_store
.read_dir(self.base_path.clone())
.await?
@@ -175,7 +175,8 @@ impl Database {
is_lance.unwrap_or(false)
})
.filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from)))
.collect();
.collect::<Vec<String>>();
f.sort();
Ok(f)
}
@@ -312,8 +313,8 @@ mod tests {
let db = Database::connect(uri).await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 2);
assert!(tables.contains(&String::from("table1")));
assert!(tables.contains(&String::from("table2")));
assert!(tables[0].eq(&String::from("table1")));
assert!(tables[1].eq(&String::from("table2")));
}
#[tokio::test]

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lance::format::{Index, Manifest};
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::VectorIndexParams;
@@ -95,10 +96,14 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
}
fn build(&self) -> VectorIndexParams {
let ivf_params = self.ivf_params.clone().unwrap_or(IvfBuildParams::default());
let pq_params = self.pq_params.clone().unwrap_or(PQBuildParams::default());
let ivf_params = self.ivf_params.clone().unwrap_or_default();
let pq_params = self.pq_params.clone().unwrap_or_default();
VectorIndexParams::with_ivf_pq_params(pq_params.metric_type, ivf_params, pq_params)
VectorIndexParams::with_ivf_pq_params(
self.metric_type.unwrap_or(MetricType::L2),
ivf_params,
pq_params,
)
}
fn get_replace(&self) -> bool {
@@ -106,6 +111,27 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
}
}
pub struct VectorIndex {
pub columns: Vec<String>,
pub index_name: String,
pub index_uuid: String,
}
impl VectorIndex {
pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
let fields = index
.fields
.iter()
.map(|i| manifest.schema.fields[*i as usize].name.clone())
.collect();
VectorIndex {
columns: fields,
index_name: index.name.clone(),
index_uuid: index.uuid.to_string(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -158,7 +184,6 @@ mod tests {
pq_params.max_iters = 1;
pq_params.num_bits = 8;
pq_params.num_sub_vectors = 50;
pq_params.metric_type = MetricType::Cosine;
pq_params.max_opq_iters = 2;
index_builder.ivf_params(ivf_params);
index_builder.pq_params(pq_params);
@@ -176,7 +201,6 @@ mod tests {
assert_eq!(pq_params.max_iters, 1);
assert_eq!(pq_params.num_bits, 8);
assert_eq!(pq_params.num_sub_vectors, 50);
assert_eq!(pq_params.metric_type, MetricType::Cosine);
assert_eq!(pq_params.max_opq_iters, 2);
} else {
assert!(false, "Expected second stage to be pq")

View File

@@ -25,7 +25,8 @@ use bytes::Bytes;
use futures::{stream::BoxStream, FutureExt, StreamExt};
use lance::io::object_store::WrappingObjectStore;
use object_store::{
path::Path, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
path::Path, Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore,
Result,
};
use async_trait::async_trait;
@@ -57,7 +58,7 @@ trait PrimaryOnly {
impl PrimaryOnly for Path {
fn primary_only(&self) -> bool {
self.to_string().contains("manifest")
self.filename().unwrap_or("") == "_latest.manifest"
}
}
@@ -118,8 +119,13 @@ impl ObjectStore for MirroringObjectStore {
self.primary.head(location).await
}
// garbage collection on secondary will happen async from other means
async fn delete(&self, location: &Path) -> Result<()> {
if !location.primary_only() {
match self.secondary.delete(location).await {
Err(Error::NotFound { .. }) | Ok(_) => {}
Err(e) => return Err(e),
}
}
self.primary.delete(location).await
}
@@ -132,7 +138,7 @@ impl ObjectStore for MirroringObjectStore {
}
async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
if from.primary_only() {
if to.primary_only() {
self.primary.copy(from, to).await
} else {
self.secondary.copy(from, to).await?;
@@ -142,6 +148,9 @@ impl ObjectStore for MirroringObjectStore {
}
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
if !to.primary_only() {
self.secondary.copy(from, to).await?;
}
self.primary.copy_if_not_exists(from, to).await
}
}
@@ -379,7 +388,7 @@ mod test {
let primary_f = primary_elem.unwrap().unwrap();
// hit manifest, skip, _versions contains all the manifest and should not exist on secondary
let primary_raw_path = primary_f.file_name().to_str().unwrap();
if primary_raw_path.contains("manifest") || primary_raw_path.contains("_versions") {
if primary_raw_path.contains("_latest.manifest") {
primary_elem = primary_iter.next();
continue;
}

View File

@@ -18,14 +18,16 @@ use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
use lance::dataset::{Dataset, WriteParams};
use lance::index::IndexType;
use lance::index::{DatasetIndexExt, IndexType};
use lance::io::object_store::WrappingObjectStore;
use std::path::Path;
use crate::error::{Error, Result};
use crate::index::vector::VectorIndexBuilder;
use crate::index::vector::{VectorIndex, VectorIndexBuilder};
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
@@ -153,6 +155,22 @@ impl Table {
})
}
pub async fn checkout_latest(&self) -> Result<Self> {
let latest_version_id = self.dataset.latest_version_id().await?;
let dataset = if latest_version_id == self.dataset.version().version {
self.dataset.clone()
} else {
Arc::new(self.dataset.checkout_version(latest_version_id).await?)
};
Ok(Table {
name: self.name.clone(),
uri: self.uri.clone(),
dataset,
store_wrapper: self.store_wrapper.clone(),
})
}
fn get_table_name(uri: &str) -> Result<String> {
let path = Path::new(uri);
let name = path
@@ -222,8 +240,6 @@ impl Table {
/// Create index on the table.
pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
use lance::index::DatasetIndexExt;
let mut dataset = self.dataset.as_ref().clone();
dataset
.create_index(
@@ -241,6 +257,14 @@ impl Table {
Ok(())
}
pub async fn optimize_indices(&mut self) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
dataset.optimize_indices().await?;
Ok(())
}
/// Insert records into this Table
///
/// # Arguments
@@ -337,12 +361,45 @@ impl Table {
/// for faster reads.
///
/// This calls into [lance::dataset::optimize::compact_files].
pub async fn compact_files(&mut self, options: CompactionOptions) -> Result<CompactionMetrics> {
pub async fn compact_files(
&mut self,
options: CompactionOptions,
remap_options: Option<Arc<dyn IndexRemapperOptions>>,
) -> Result<CompactionMetrics> {
let mut dataset = self.dataset.as_ref().clone();
let metrics = compact_files(&mut dataset, options).await?;
let metrics = compact_files(&mut dataset, options, remap_options).await?;
self.dataset = Arc::new(dataset);
Ok(metrics)
}
pub fn count_fragments(&self) -> usize {
self.dataset.count_fragments()
}
pub async fn count_deleted_rows(&self) -> Result<usize> {
Ok(self.dataset.count_deleted_rows().await?)
}
pub async fn num_small_files(&self, max_rows_per_group: usize) -> usize {
self.dataset.num_small_files(max_rows_per_group).await
}
pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
Ok(self.dataset.count_indexed_rows(index_uuid).await?)
}
pub async fn count_unindexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
Ok(self.dataset.count_unindexed_rows(index_uuid).await?)
}
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
let (indices, mf) =
futures::try_join!(self.dataset.load_indices(), self.dataset.latest_manifest())?;
Ok(indices
.iter()
.map(|i| VectorIndex::new_from_format(&mf, i))
.collect())
}
}
#[cfg(test)]