Mirror of https://github.com/lancedb/lancedb.git
Synced 2025-12-23 13:29:57 +00:00

Compare commits: python-v0. ... python-v0.

42 Commits
fe89a373a2, 3d3915edef, e2e8b6aee4, 12dbca5248, a6babfa651, 75ede86fab,
becd649130, 9d2fb7d602, fdb5d6fdf1, 2f13fa225f, e933de003d, 05fd387425,
82a1da554c, a7c0d80b9e, 71323a064a, df48454b70, 6603414885, c256f6c502,
cc03f90379, 975da09b02, c32e17b497, 0528abdf97, 1090c311e8, e767cbb374,
3d7c48feca, 08d62550bb, b272408b05, 46ffa87cd4, cd9fc37b95, 431f94e564,
c1a7d65473, 1e5ccb1614, 2e7ab373dc, c7fbc4aaee, 7e023c1ef2, 1d0dd9a8b8,
deb947ddbd, b039765d50, d155e82723, 5d8c91256c, 44c03ebef3, 8ea06fe7f3
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.17
+current_version = 0.4.19
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True
.github/workflows/cargo-publish.yml (vendored, 3 changes)

@@ -8,6 +8,9 @@ env:
   # This env var is used by Swatinem/rust-cache@v2 for the cache
   # key, so we set it to make sure it is always consistent.
   CARGO_TERM_COLOR: always
+  # Up-to-date compilers needed for fp16kernels.
+  CC: gcc-12
+  CXX: g++-12
 
 jobs:
   build:
.github/workflows/pypi-publish.yml (vendored, 6 changes)

@@ -6,6 +6,8 @@ on:
 
 jobs:
   linux:
+    # Only runs on tags that matches the python-make-release action
+    if: startsWith(github.ref, 'refs/tags/python-v')
     name: Python ${{ matrix.config.platform }} manylinux${{ matrix.config.manylinux }}
     timeout-minutes: 60
     strategy:
@@ -44,6 +46,8 @@ jobs:
       token: ${{ secrets.LANCEDB_PYPI_API_TOKEN }}
       repo: "pypi"
   mac:
+    # Only runs on tags that matches the python-make-release action
+    if: startsWith(github.ref, 'refs/tags/python-v')
     timeout-minutes: 60
     runs-on: ${{ matrix.config.runner }}
     strategy:
@@ -76,6 +80,8 @@ jobs:
       token: ${{ secrets.LANCEDB_PYPI_API_TOKEN }}
       repo: "pypi"
   windows:
+    # Only runs on tags that matches the python-make-release action
+    if: startsWith(github.ref, 'refs/tags/python-v')
     timeout-minutes: 60
     runs-on: windows-latest
     strategy:
.gitignore (vendored, 2 changes)

@@ -6,7 +6,7 @@
 venv
 
 .vscode
-
+.zed
 rust/target
 rust/Cargo.lock
 
Cargo.toml (24 changes)

@@ -14,19 +14,19 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]
 
 [workspace.dependencies]
-lance = { "version" = "=0.10.10", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.10.10" }
-lance-linalg = { "version" = "=0.10.10" }
-lance-testing = { "version" = "=0.10.10" }
+lance = { "version" = "=0.10.16", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.10.16" }
+lance-linalg = { "version" = "=0.10.16" }
+lance-testing = { "version" = "=0.10.16" }
 # Note that this one does not include pyarrow
-arrow = { version = "50.0", optional = false }
-arrow-array = "50.0"
-arrow-data = "50.0"
-arrow-ipc = "50.0"
-arrow-ord = "50.0"
-arrow-schema = "50.0"
-arrow-arith = "50.0"
-arrow-cast = "50.0"
+arrow = { version = "51.0", optional = false }
+arrow-array = "51.0"
+arrow-data = "51.0"
+arrow-ipc = "51.0"
+arrow-ord = "51.0"
+arrow-schema = "51.0"
+arrow-arith = "51.0"
+arrow-cast = "51.0"
 async-trait = "0"
 chrono = "0.4.35"
 half = { "version" = "=2.3.1", default-features = false, features = [
@@ -20,7 +20,7 @@
 
 <hr />
 
-LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings.
+LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.
 
 The key features of LanceDB include:
 
@@ -36,7 +36,7 @@ The key features of LanceDB include:
 
 * GPU support in building vector index(*).
 
-* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
+* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/docs/integrations/vectorstores/lancedb/), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
 
 LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.
 
@@ -57,16 +57,6 @@ plugins:
       - https://arrow.apache.org/docs/objects.inv
       - https://pandas.pydata.org/docs/objects.inv
   - mkdocs-jupyter
-  - ultralytics:
-      verbose: True
-      enabled: True
-      default_image: "assets/lancedb_and_lance.png" # Default image for all pages
-      add_image: True # Automatically add meta image
-      add_keywords: True # Add page keywords in the header tag
-      add_share_buttons: True # Add social share buttons
-      add_authors: False # Display page authors
-      add_desc: False
-      add_dates: False
 
 markdown_extensions:
   - admonition
@@ -104,6 +94,14 @@ nav:
     - Overview: hybrid_search/hybrid_search.md
     - Comparing Rerankers: hybrid_search/eval.md
     - Airbnb financial data example: notebooks/hybrid_search.ipynb
+  - Reranking:
+    - Quickstart: reranking/index.md
+    - Cohere Reranker: reranking/cohere.md
+    - Linear Combination Reranker: reranking/linear_combination.md
+    - Cross Encoder Reranker: reranking/cross_encoder.md
+    - ColBERT Reranker: reranking/colbert.md
+    - OpenAI Reranker: reranking/openai.md
+    - Building Custom Rerankers: reranking/custom_reranker.md
   - Filtering: sql.md
   - Versioning & Reproducibility: notebooks/reproducibility.ipynb
   - Configuring Storage: guides/storage.md
@@ -120,9 +118,10 @@ nav:
   - Pandas and PyArrow: python/pandas_and_pyarrow.md
   - Polars: python/polars_arrow.md
   - DuckDB: python/duckdb.md
-  - LangChain 🔗: https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html
-  - LangChain JS/TS 🔗: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb
-  - LlamaIndex 🦙: https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html
+  - LangChain:
+    - LangChain 🔗: https://python.langchain.com/docs/integrations/vectorstores/lancedb/
+    - LangChain JS/TS 🔗: https://js.langchain.com/docs/integrations/vectorstores/lancedb
+  - LlamaIndex 🦙: https://docs.llamaindex.ai/en/stable/examples/vector_stores/LanceDBIndexDemo/
   - Pydantic: python/pydantic.md
   - Voxel51: integrations/voxel51.md
   - PromptTools: integrations/prompttools.md
@@ -170,6 +169,14 @@ nav:
     - Overview: hybrid_search/hybrid_search.md
     - Comparing Rerankers: hybrid_search/eval.md
    - Airbnb financial data example: notebooks/hybrid_search.ipynb
+  - Reranking:
+    - Quickstart: reranking/index.md
+    - Cohere Reranker: reranking/cohere.md
+    - Linear Combination Reranker: reranking/linear_combination.md
+    - Cross Encoder Reranker: reranking/cross_encoder.md
+    - ColBERT Reranker: reranking/colbert.md
+    - OpenAI Reranker: reranking/openai.md
+    - Building Custom Rerankers: reranking/custom_reranker.md
   - Filtering: sql.md
   - Versioning & Reproducibility: notebooks/reproducibility.ipynb
   - Configuring Storage: guides/storage.md
@@ -186,8 +193,8 @@ nav:
   - Pandas and PyArrow: python/pandas_and_pyarrow.md
   - Polars: python/polars_arrow.md
   - DuckDB: python/duckdb.md
-  - LangChain 🦜️🔗↗: https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html
-  - LangChain.js 🦜️🔗↗: https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb
+  - LangChain 🦜️🔗↗: https://python.langchain.com/docs/integrations/vectorstores/lancedb
+  - LangChain.js 🦜️🔗↗: https://js.langchain.com/docs/integrations/vectorstores/lancedb
   - LlamaIndex 🦙↗: https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html
   - Pydantic: python/pydantic.md
   - Voxel51: integrations/voxel51.md
@@ -2,5 +2,4 @@ mkdocs==1.5.3
 mkdocs-jupyter==0.24.1
 mkdocs-material==9.5.3
 mkdocstrings[python]==0.20.0
-pydantic
-mkdocs-ultralytics-plugin==0.0.44
+pydantic
@@ -154,9 +154,12 @@ Allows you to set parameters when registering a `sentence-transformers` object.
 !!! note "BAAI Embeddings example"
     Here is an example that uses BAAI embedding model from the HuggingFace Hub [supported models](https://huggingface.co/models?library=sentence-transformers)
     ```python
+    import lancedb
+    from lancedb.pydantic import LanceModel, Vector
+    from lancedb.embeddings import get_registry
+
     db = lancedb.connect("/tmp/db")
-    registry = EmbeddingFunctionRegistry.get_instance()
-    model = registry.get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
+    model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
 
     class Words(LanceModel):
         text: str = model.SourceField()
@@ -165,7 +168,7 @@ Allows you to set parameters when registering a `sentence-transformers` object.
     table = db.create_table("words", schema=Words)
     table.add(
         [
-            {"text": "hello world"}
+            {"text": "hello world"},
             {"text": "goodbye world"}
         ]
     )
@@ -213,18 +216,21 @@ LanceDB registers the OpenAI embeddings function in the registry by default, as
 
 
 ```python
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+
 db = lancedb.connect("/tmp/db")
-registry = EmbeddingFunctionRegistry.get_instance()
-func = registry.get("openai").create()
+func = get_registry().get("openai").create(name="text-embedding-ada-002")
 
 class Words(LanceModel):
     text: str = func.SourceField()
     vector: Vector(func.ndims()) = func.VectorField()
 
-table = db.create_table("words", schema=Words)
+table = db.create_table("words", schema=Words, mode="overwrite")
 table.add(
     [
-        {"text": "hello world"}
+        {"text": "hello world"},
         {"text": "goodbye world"}
     ]
 )
@@ -353,6 +359,10 @@ Supported parameters (to be passed in `create` method) are:
 Usage Example:
 
 ```python
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+
 model = get_registry().get("bedrock-text").create()
 
 class TextModel(LanceModel):
@@ -387,10 +397,12 @@ This embedding function supports ingesting images as both bytes and urls. You ca
 LanceDB supports ingesting images directly from accessible links.
 
 ```python
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+
 db = lancedb.connect(tmp_path)
-registry = EmbeddingFunctionRegistry.get_instance()
-func = registry.get("open-clip").create()
+func = get_registry().get("open-clip").create()
 
 class Images(LanceModel):
     label: str
@@ -465,9 +477,12 @@ This function is registered as `imagebind` and supports Audio, Video and Text mo
 Below is an example demonstrating how the API works:
 
 ```python
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+
 db = lancedb.connect(tmp_path)
-registry = EmbeddingFunctionRegistry.get_instance()
-func = registry.get("imagebind").create()
+func = get_registry().get("imagebind").create()
 
 class ImageBindModel(LanceModel):
     text: str
@@ -46,7 +46,7 @@ For this purpose, LanceDB introduces an **embedding functions API**, that allow
 ```python
 class Pets(LanceModel):
-    vector: Vector(clip.ndims) = clip.VectorField()
+    vector: Vector(clip.ndims()) = clip.VectorField()
     image_uri: str = clip.SourceField()
 ```
 
@@ -149,7 +149,7 @@ You can also use the integration for adding utility operations in the schema. Fo
 ```python
 class Pets(LanceModel):
-    vector: Vector(clip.ndims) = clip.VectorField()
+    vector: Vector(clip.ndims()) = clip.VectorField()
     image_uri: str = clip.SourceField()
 
     @property
@@ -166,4 +166,4 @@ rs[2].image
 ![](../assets/dog_clip_output.png)
 
 Now that you have the basic idea about LanceDB embedding functions and the embedding function registry,
-let's dive deeper into defining your own [custom functions](./custom_embedding_function.md).
+let's dive deeper into defining your own [custom functions](./custom_embedding_function.md).
@@ -11,4 +11,64 @@ LanceDB supports 3 methods of working with embeddings.
   that extends the default embedding functions.
 
 For python users, there is also a legacy [with_embeddings API](./legacy.md).
-It is retained for compatibility and will be removed in a future version.
+It is retained for compatibility and will be removed in a future version.
+
+## Quickstart
+
+To get started with embeddings, you can use the built-in embedding functions.
+
+### OpenAI Embedding function
+LanceDB registers the OpenAI embeddings function in the registry as `openai`. You can pass any supported model name to the `create`. By default it uses `"text-embedding-ada-002"`.
+
+```python
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+
+db = lancedb.connect("/tmp/db")
+func = get_registry().get("openai").create(name="text-embedding-ada-002")
+
+class Words(LanceModel):
+    text: str = func.SourceField()
+    vector: Vector(func.ndims()) = func.VectorField()
+
+table = db.create_table("words", schema=Words, mode="overwrite")
+table.add(
+    [
+        {"text": "hello world"},
+        {"text": "goodbye world"}
+    ]
+)
+
+query = "greetings"
+actual = table.search(query).limit(1).to_pydantic(Words)[0]
+print(actual.text)
+```
+
+### Sentence Transformers Embedding function
+LanceDB registers the Sentence Transformers embeddings function in the registry as `sentence-transformers`. You can pass any supported model name to the `create`. By default it uses `"sentence-transformers/paraphrase-MiniLM-L6-v2"`.
+
+```python
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+
+db = lancedb.connect("/tmp/db")
+model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
+
+class Words(LanceModel):
+    text: str = model.SourceField()
+    vector: Vector(model.ndims()) = model.VectorField()
+
+table = db.create_table("words", schema=Words)
+table.add(
+    [
+        {"text": "hello world"},
+        {"text": "goodbye world"}
+    ]
+)
+
+query = "greetings"
+actual = table.search(query).limit(1).to_pydantic(Words)[0]
+print(actual.text)
+```
@@ -299,6 +299,14 @@ LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you m
 
 This can also be done with the ``AWS_ENDPOINT`` and ``AWS_DEFAULT_REGION`` environment variables.
 
+!!! tip "Local servers"
+
+    For local development, the server often has a `http` endpoint rather than a
+    secure `https` endpoint. In this case, you must also set the `ALLOW_HTTP`
+    environment variable to `true` to allow non-TLS connections, or pass the
+    storage option `allow_http` as `true`. If you do not do this, you will get
+    an error like `URL scheme is not allowed`.
+
 #### S3 Express
 
 LanceDB supports [S3 Express One Zone](https://aws.amazon.com/s3/storage-classes/express-one-zone/) endpoints, but requires additional configuration. Also, S3 Express endpoints only support connecting from an EC2 instance within the same region.
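To make the tip above concrete, here is a minimal sketch of connecting to a local MinIO server over plain HTTP using the environment variables named in the hunk; the endpoint, region, and bucket name are placeholders, and credentials are assumed to come from the usual AWS environment variables:

```python
import os
import lancedb

# Non-TLS S3-compatible endpoint (e.g. a local MinIO server).
os.environ["AWS_ENDPOINT"] = "http://localhost:9000"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
# Without this, connecting to an http:// endpoint fails with
# an error like "URL scheme is not allowed".
os.environ["ALLOW_HTTP"] = "true"

db = lancedb.connect("s3://my-bucket/my-db")  # hypothetical bucket/path
```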
@@ -36,7 +36,7 @@
     }
    ],
    "source": [
-    "!pip install --quiet openai datasets \n",
+    "!pip install --quiet openai datasets\n",
     "!pip install --quiet -U lancedb"
    ]
   },
@@ -213,7 +213,7 @@
    "if \"OPENAI_API_KEY\" not in os.environ:\n",
    "    # OR set the key here as a variable\n",
    "    os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
-    "    \n",
+    "\n",
    "client = OpenAI()\n",
    "assert len(client.models.list().data) > 0"
   ]
@@ -234,9 +234,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def embed_func(c): \n",
+    "def embed_func(c):\n",
     "    rs = client.embeddings.create(input=c, model=\"text-embedding-ada-002\")\n",
-    "    return [rs.data[0].embedding]"
+    "    return [\n",
+    "        data.embedding\n",
+    "        for data in rs.data\n",
+    "    ]"
   ]
  },
 {
@@ -514,7 +517,7 @@
    "        prompt_start +\n",
    "        \"\\n\\n---\\n\\n\".join(context.text) +\n",
    "        prompt_end\n",
-    "    ) \n",
+    "    )\n",
    "    return prompt"
   ]
  },
@@ -24,7 +24,8 @@ data = [
 table = db.create_table("pd_table", data=data)
 ```
 
-To query the table, first call `to_lance` to convert the table to a "dataset", which is an object that can be queried by DuckDB. Then all you need to do is reference that dataset by the same name in your SQL query.
+The `to_lance` method converts the LanceDB table to a `LanceDataset`, which is accessible to DuckDB through the Arrow compatibility layer.
+To query the resulting Lance dataset in DuckDB, all you need to do is reference the dataset by the same name in your SQL query.
 
 ```python
 import duckdb
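As a quick sketch of the flow just described (reusing the `db` and `table` from the snippet above; illustrative only):

```python
import duckdb

# to_lance() exposes the LanceDB table as a LanceDataset. DuckDB's
# replacement scans resolve `pd_table` from the local Python scope
# through the Arrow compatibility layer.
pd_table = table.to_lance()
print(duckdb.query("SELECT COUNT(*) FROM pd_table").to_df())
```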
docs/src/reranking/cohere.md (new file, 75 lines)

# Cohere Reranker

This re-ranker uses the [Cohere](https://cohere.ai/) API to rerank the search results. You can use this re-ranker by passing `CohereReranker()` to the `rerank()` method. Note that you'll either need to set the `COHERE_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.

!!! note
    Supported Query Types: Hybrid, Vector, FTS

```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = CohereReranker(api_key="key")

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"rerank-english-v2.0"` | The name of the reranker model to use. Available Cohere models are: rerank-english-v2.0, rerank-multilingual-v2.0 |
| `column` | `str` | `"text"` | The name of the column to use as input to the reranker. |
| `top_n` | `int` | `None` | The number of results to return. If None, will return all results. |
| `api_key` | `str` | `None` | The API key for the Cohere API. If not provided, the `COHERE_API_KEY` environment variable is used. |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score` column. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the FTS score (`score`) along with the relevance score (`_relevance_score`) |
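For illustration, a hedged sketch of how `return_score` changes the columns you get back for a vector query, reusing `tbl` from the example above (a valid Cohere API key is assumed to be set in the environment):

```python
from lancedb.rerankers import CohereReranker

# Default: only the _relevance_score column is attached to each row.
rows = tbl.search("hello").rerank(reranker=CohereReranker()).to_list()

# "all": for vector search, _distance is kept alongside _relevance_score.
rows = tbl.search("hello").rerank(reranker=CohereReranker(return_score="all")).to_list()
print(rows[0].keys())
```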
docs/src/reranking/colbert.md (new file, 71 lines)

# ColBERT Reranker

This re-ranker uses a ColBERT model to rerank the search results. You can use this re-ranker by passing `ColbertReranker()` to the `rerank()` method.
!!! note
    Supported Query Types: Hybrid, Vector, FTS

```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import ColbertReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = ColbertReranker()

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"colbert-ir/colbertv2.0"` | The name of the reranker model to use. |
| `column` | `str` | `"text"` | The name of the column to use as input to the reranker. |
| `device` | `str` | `None` | The device to use for the model. If None, will use "cuda" if available, otherwise "cpu". |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score` column. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the FTS score (`score`) along with the relevance score (`_relevance_score`) |
docs/src/reranking/cross_encoder.md (new file, 70 lines)

# Cross Encoder Reranker

This re-ranker uses Cross Encoder models from sentence-transformers to rerank the search results. You can use this re-ranker by passing `CrossEncoderReranker()` to the `rerank()` method.
!!! note
    Supported Query Types: Hybrid, Vector, FTS

```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CrossEncoderReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = CrossEncoderReranker()

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"cross-encoder/ms-marco-TinyBERT-L-6"` | The name of the reranker model to use. |
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
| `device` | `str` | `None` | The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu". |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score` column. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the FTS score (`score`) along with the relevance score (`_relevance_score`) |
docs/src/reranking/custom_reranker.md (new file, 88 lines)

## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Optionally, you can also implement the `rerank_vector()` and `rerank_fts()` methods if you want to support reranking for vector and FTS search separately.
Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.

The `Reranker` base interface comes with a `merge_results()` method that can be used to combine the results of semantic and full-text search. This is a vanilla merging algorithm that simply concatenates the results and removes the duplicates without taking the scores into consideration. It only keeps the first copy of each row encountered. This works well in cases that don't require the scores of semantic and full-text search to combine the results. If you want to use the scores or want to support `return_score="all"`, you'll need to implement your own merging algorithm.

```python
from lancedb.rerankers import Reranker
import pyarrow as pa

class MyReranker(Reranker):
    def __init__(self, param1, param2, ..., return_score="relevance"):
        super().__init__(return_score)
        self.param1 = param1
        self.param2 = param2

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
        # Use the built-in merging function
        combined_result = self.merge_results(vector_results, fts_results)

        # Do something with the combined results
        # ...

        # Return the combined results
        return combined_result

    def rerank_vector(self, query: str, vector_results: pa.Table):
        # Do something with the vector results
        # ...

        # Return the vector results
        return vector_results

    def rerank_fts(self, query: str, fts_results: pa.Table):
        # Do something with the FTS results
        # ...

        # Return the FTS results
        return fts_results
```

### Example of a Custom Reranker
For the sake of simplicity, let's build a custom reranker that just enhances the Cohere Reranker by accepting a filter query, and accepts the other `CohereReranker` params as kwargs.

```python
from typing import List, Union

import pandas as pd
import pyarrow as pa
from lancedb.rerankers import CohereReranker

class ModifiedCohereReranker(CohereReranker):
    def __init__(self, filters: Union[str, List[str]], **kwargs):
        super().__init__(**kwargs)
        filters = filters if isinstance(filters, list) else [filters]
        self.filters = filters

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        combined_result = super().rerank_hybrid(query, vector_results, fts_results)
        df = combined_result.to_pandas()
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)

    def rerank_vector(self, query: str, vector_results: pa.Table) -> pa.Table:
        vector_results = super().rerank_vector(query, vector_results)
        df = vector_results.to_pandas()
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)

    def rerank_fts(self, query: str, fts_results: pa.Table) -> pa.Table:
        fts_results = super().rerank_fts(query, fts_results)
        df = fts_results.to_pandas()
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)
```

!!! tip
    The `vector_results` and `fts_results` are pyarrow tables. Learn more about pyarrow tables [here](https://arrow.apache.org/docs/python). They can be converted to other data types like a pandas dataframe, pydict, pylist, etc.

    For example, you can convert them to pandas dataframes using the `to_pandas()` method and perform any operations you want. After you are done, you can convert the dataframe back to a pyarrow table using the `pa.Table.from_pandas()` method and return it.
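A tiny, self-contained sketch of that pyarrow-to-pandas round trip (the table contents here are made up):

```python
import pyarrow as pa

results = pa.table({
    "text": ["hello world", "goodbye world"],
    "_relevance_score": [0.9, 0.4],
})
df = results.to_pandas()
df = df[df["_relevance_score"] > 0.5]  # any pandas-side filtering or reordering
results = pa.Table.from_pandas(df)     # convert back before returning to LanceDB
```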
docs/src/reranking/index.md (new file, 60 lines)

Reranking is the process of reordering a list of items based on some criteria. In the context of search, reranking reorders the results returned by a search engine, which is useful when the initial ranking is not satisfactory or when the user has provided additional information that can be used to improve it.

LanceDB comes with some built-in rerankers. Some of the rerankers that are available in LanceDB are:

| Reranker | Description | Supported Query Types |
| --- | --- | --- |
| `LinearCombinationReranker` | Reranks search results based on a linear combination of FTS and vector search scores | Hybrid |
| `CohereReranker` | Uses the Cohere Rerank API to rerank results | Vector, FTS, Hybrid |
| `CrossEncoderReranker` | Uses a cross-encoder model to rerank search results | Vector, FTS, Hybrid |
| `ColbertReranker` | Uses a ColBERT model to rerank search results | Vector, FTS, Hybrid |
| `OpenaiReranker` (Experimental) | Uses OpenAI's chat model to rerank search results | Vector, FTS, Hybrid |

## Using a Reranker
Using rerankers is optional for vector and FTS. However, for hybrid search, rerankers are required. To use a reranker, you need to create an instance of the reranker and pass it to the `rerank` method of the query builder.

```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", data)
reranker = CohereReranker(api_key="your_api_key")

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text")
result = tbl.search("hello", query_type="hybrid").rerank(reranker).to_list()
```

## Available Rerankers
LanceDB comes with some built-in rerankers. Here are some of the rerankers that are available in LanceDB:

- [Cohere Reranker](./cohere.md)
- [Cross Encoder Reranker](./cross_encoder.md)
- [ColBERT Reranker](./colbert.md)
- [OpenAI Reranker](./openai.md)
- [Linear Combination Reranker](./linear_combination.md)

## Creating Custom Rerankers

LanceDB also allows you to create custom rerankers by extending the base `Reranker` class. A custom reranker should implement the `rerank_hybrid()` method (and optionally `rerank_vector()` and `rerank_fts()`), taking the search results and returning a reranked version of them. This is covered in more detail in the [Creating Custom Rerankers](./custom_reranker.md) section.
docs/src/reranking/linear_combination.md (new file, 52 lines)

# Linear Combination Reranker

This is the default re-ranker used by LanceDB hybrid search. It combines the results of semantic and full-text search using a linear combination of the scores. The weight for the linear combination can be specified; it defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.

!!! note
    Supported Query Types: Hybrid

```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import LinearCombinationReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = LinearCombinationReranker()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `weight` | `float` | `0.7` | The weight to use for the semantic search score. The weight for the full-text search score is `1 - weight`. |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score` column. If "all", will return all scores from the vector and FTS search along with the relevance score. |

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |
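The combination itself is just a weighted average. A simplified sketch of the arithmetic, assuming both scores have already been normalized so that higher is better (the actual reranker must first put the raw `_distance` and FTS `score` columns on a comparable scale):

```python
def linear_combination(vector_score: float, fts_score: float, weight: float = 0.7) -> float:
    # `weight` goes to the semantic (vector) score, the rest to full-text search.
    return weight * vector_score + (1 - weight) * fts_score

print(linear_combination(0.8, 0.5))  # 0.7 * 0.8 + 0.3 * 0.5 = 0.71
```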
docs/src/reranking/openai.md (new file, 73 lines)

# OpenAI Reranker (Experimental)

This re-ranker uses an OpenAI chat model to rerank the search results. You can use this re-ranker by passing `OpenaiReranker()` to the `rerank()` method.
!!! note
    Supported Query Types: Hybrid, Vector, FTS

!!! warning
    This re-ranker is experimental. OpenAI doesn't have a dedicated reranking model, so we are using the chat model for reranking.

```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import OpenaiReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = OpenaiReranker()

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"gpt-4-turbo-preview"` | The name of the reranker model to use. |
| `column` | `str` | `"text"` | The name of the column to use as input to the model. |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score` column. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |
| `api_key` | `str` | `None` | The API key to use. If None, will use the `OPENAI_API_KEY` environment variable. |

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the FTS score (`score`) along with the relevance score (`_relevance_score`) |
@@ -15,6 +15,7 @@ excluded_globs = [
   "../src/ann_indexes.md",
   "../src/basic.md",
   "../src/hybrid_search/hybrid_search.md",
+  "../src/reranking/*.md",
 ]
 
 python_prefix = "py"
node/package-lock.json (generated, 14 changes)

@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.4.17",
+      "version": "0.4.19",
       "cpu": [
         "x64",
         "arm64"
@@ -52,11 +52,11 @@
       "uuid": "^9.0.0"
     },
     "optionalDependencies": {
-      "@lancedb/vectordb-darwin-arm64": "0.4.17",
-      "@lancedb/vectordb-darwin-x64": "0.4.17",
-      "@lancedb/vectordb-linux-arm64-gnu": "0.4.17",
-      "@lancedb/vectordb-linux-x64-gnu": "0.4.17",
-      "@lancedb/vectordb-win32-x64-msvc": "0.4.17"
+      "@lancedb/vectordb-darwin-arm64": "0.4.19",
+      "@lancedb/vectordb-darwin-x64": "0.4.19",
+      "@lancedb/vectordb-linux-arm64-gnu": "0.4.19",
+      "@lancedb/vectordb-linux-x64-gnu": "0.4.19",
+      "@lancedb/vectordb-win32-x64-msvc": "0.4.19"
     },
     "peerDependencies": {
       "@apache-arrow/ts": "^14.0.2",
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -88,10 +88,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.4.17",
-    "@lancedb/vectordb-darwin-x64": "0.4.17",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.4.17",
-    "@lancedb/vectordb-linux-x64-gnu": "0.4.17",
-    "@lancedb/vectordb-win32-x64-msvc": "0.4.17"
+    "@lancedb/vectordb-darwin-arm64": "0.4.19",
+    "@lancedb/vectordb-darwin-x64": "0.4.19",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.4.19",
+    "@lancedb/vectordb-linux-x64-gnu": "0.4.19",
+    "@lancedb/vectordb-win32-x64-msvc": "0.4.19"
   }
 }
@@ -163,7 +163,7 @@ export interface CreateTableOptions<T> {
 /**
  * Connect to a LanceDB instance at the given URI.
  *
- * Accpeted formats:
+ * Accepted formats:
  *
  * - `/path/to/database` - local database
  * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
@@ -51,7 +51,7 @@ describe('LanceDB Mirrored Store Integration test', function () {
 
     const dir = tmpdir()
     console.log(dir)
-    const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
+    const conn = await lancedb.connect({ uri: `s3://lancedb-integtest?mirroredStore=${dir}`, storageOptions: { allowHttp: 'true' } })
     const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
     data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
     data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))
@@ -140,6 +140,9 @@ export class RemoteConnection implements Connection {
       schema = nameOrOpts.schema
       embeddings = nameOrOpts.embeddingFunction
       tableName = nameOrOpts.name
+      if (data === undefined) {
+        data = nameOrOpts.data
+      }
     }
 
     let buffer: Buffer
nodejs/.gitignore (vendored, new file)

@@ -0,0 +1 @@
+yarn.lock
@@ -20,7 +20,7 @@ import { Table as ArrowTable, Schema } from "apache-arrow";
 /**
  * Connect to a LanceDB instance at the given URI.
  *
- * Accpeted formats:
+ * Accepted formats:
  *
  * - `/path/to/database` - local database
 * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
@@ -77,6 +77,18 @@ export interface OpenTableOptions {
   * The available options are described at https://lancedb.github.io/lancedb/guides/storage/
   */
  storageOptions?: Record<string, string>;
+  /**
+   * Set the size of the index cache, specified as a number of entries
+   *
+   * The exact meaning of an "entry" will depend on the type of index:
+   * - IVF: there is one entry for each IVF partition
+   * - BTREE: there is one entry for the entire index
+   *
+   * This cache applies to the entire opened table, across all indices.
+   * Setting this value higher will increase performance on larger datasets
+   * at the expense of more RAM
+   */
+  indexCacheSize?: number;
 }
 
 export interface TableNamesOptions {
@@ -160,6 +172,7 @@ export class Connection {
     const innerTable = await this.inner.openTable(
       name,
       cleanseStorageOptions(options?.storageOptions),
+      options?.indexCacheSize,
     );
     return new Table(innerTable);
   }
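The Python API exposes the same knob; a hedged sketch, assuming your installed lancedb version supports the `index_cache_size` argument on `open_table`:

```python
import lancedb

db = lancedb.connect("./.lancedb")
# One cache entry per IVF partition, or one entry for a whole BTREE index;
# larger values trade RAM for fewer index reloads on big datasets.
table = db.open_table("my_table", index_cache_size=512)
```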
@@ -169,17 +169,20 @@ export class Table {
  * // If the column has a vector (fixed size list) data type then
  * // an IvfPq vector index will be created.
  * const table = await conn.openTable("my_table");
- * await table.createIndex(["vector"]);
+ * await table.createIndex("vector");
  * @example
  * // For advanced control over vector index creation you can specify
  * // the index type and options.
  * const table = await conn.openTable("my_table");
- * await table.createIndex(["vector"], I)
- *   .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
- *   .build();
+ * await table.createIndex("vector", {
+ *   config: lancedb.Index.ivfPq({
+ *     numPartitions: 128,
+ *     numSubVectors: 16,
+ *   }),
+ * });
  * @example
  * // Or create a Scalar index
- * await table.createIndex("my_float_col").build();
+ * await table.createIndex("my_float_col");
  */
 async createIndex(column: string, options?: Partial<IndexOptions>) {
   // Bit of a hack to get around the fact that TS has no package-scope.
@@ -197,8 +200,7 @@
  * vector similarity, sorting, and more.
  *
  * Note: By default, all columns are returned. For best performance, you should
- * only fetch the columns you need. See [`Query::select_with_projection`] for
- * more details.
+ * only fetch the columns you need.
  *
  * When appropriate, various indices and statistics based pruning will be used to
  * accelerate the query.
@@ -206,10 +208,13 @@
  * // SQL-style filtering
  * //
  * // This query will return up to 1000 rows whose value in the `id` column
- * // is greater than 5. LanceDb supports a broad set of filtering functions.
- * for await (const batch of table.query()
- *   .filter("id > 1").select(["id"]).limit(20)) {
- *   console.log(batch);
+ * // is greater than 5. LanceDb supports a broad set of filtering functions.
+ * for await (const batch of table
+ *   .query()
+ *   .where("id > 1")
+ *   .select(["id"])
+ *   .limit(20)) {
+ *   console.log(batch);
  * }
 * @example
 * // Vector Similarity Search
@@ -218,13 +223,14 @@
  * // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
  * // on the "vector" column then this will perform an ANN search.
  * //
- * // The `refine_factor` and `nprobes` methods are used to control the recall /
+ * // The `refineFactor` and `nprobes` methods are used to control the recall /
  * // latency tradeoff of the search.
- * for await (const batch of table.query()
- *   .nearestTo([1, 2, 3])
- *   .refineFactor(5).nprobe(10)
- *   .limit(10)) {
- *   console.log(batch);
+ * for await (const batch of table
+ *   .query()
+ *   .where("id > 1")
+ *   .select(["id"])
+ *   .limit(20)) {
+ *   console.log(batch);
  * }
 * @example
 * // Scan the full dataset
@@ -286,43 +292,45 @@
   await this.inner.dropColumns(columnNames);
 }
 
-/**
- * Retrieve the version of the table
- *
- * LanceDb supports versioning. Every operation that modifies the table increases
- * version. As long as a version hasn't been deleted you can `[Self::checkout]` that
- * version to view the data at that point. In addition, you can `[Self::restore]` the
- * version to replace the current table with a previous version.
- */
+/** Retrieve the version of the table */
 async version(): Promise<number> {
   return await this.inner.version();
 }
 
 /**
- * Checks out a specific version of the Table
+ * Checks out a specific version of the table. _This is an in-place operation._
  *
- * Any read operation on the table will now access the data at the checked out version.
- * As a consequence, calling this method will disable any read consistency interval
- * that was previously set.
+ * This allows viewing previous versions of the table. If you wish to
+ * keep writing to the dataset starting from an old version, then use
+ * the `restore` function.
  *
- * This is a read-only operation that turns the table into a sort of "view"
- * or "detached head". Other table instances will not be affected. To make the change
- * permanent you can use the `[Self::restore]` method.
- *
- * Any operation that modifies the table will fail while the table is in a checked
- * out state.
- *
- * To return the table to a normal state use `[Self::checkout_latest]`
+ * Calling this method will set the table into time-travel mode. If you
+ * wish to return to standard mode, call `checkoutLatest`.
+ * @param {number} version The version to checkout
+ * @example
+ * ```typescript
+ * import * as lancedb from "@lancedb/lancedb"
+ * const db = await lancedb.connect("./.lancedb");
+ * const table = await db.createTable("my_table", [
+ *   { vector: [1.1, 0.9], type: "vector" },
+ * ]);
+ *
+ * console.log(await table.version()); // 1
+ * console.log(table.display());
+ * await table.add([{ vector: [0.5, 0.2], type: "vector" }]);
+ * await table.checkout(1);
+ * console.log(await table.version()); // 2
+ * ```
  */
 async checkout(version: number): Promise<void> {
   await this.inner.checkout(version);
 }
 
 /**
- * Ensures the table is pointing at the latest version
+ * Checkout the latest version of the table. _This is an in-place operation._
  *
- * This can be used to manually update a table when the read_consistency_interval is None
- * It can also be used to undo a `[Self::checkout]` operation
+ * The table will be set back into standard mode, and will track the latest
+ * version of the table.
  */
 async checkoutLatest(): Promise<void> {
   await this.inner.checkoutLatest();
@@ -344,9 +352,7 @@
   await this.inner.restore();
 }
 
-/**
- * List all indices that have been created with Self::create_index
- */
+/** List all indices that have been created with {@link Table.createIndex} */
 async listIndices(): Promise<IndexConfig[]> {
   return await this.inner.listIndices();
 }
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "darwin"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "darwin"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "linux"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "linux"
   ],
nodejs/package-lock.json (generated file, 86 changes)
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.4.16",
+  "version": "0.4.18",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.4.16",
+      "version": "0.4.18",
       "cpu": [
         "x64",
         "arm64"
@@ -45,13 +45,6 @@
       },
       "engines": {
         "node": ">= 18"
       },
-      "optionalDependencies": {
-        "@lancedb/lancedb-darwin-arm64": "0.4.16",
-        "@lancedb/lancedb-darwin-x64": "0.4.16",
-        "@lancedb/lancedb-linux-arm64-gnu": "0.4.16",
-        "@lancedb/lancedb-linux-x64-gnu": "0.4.16",
-        "@lancedb/lancedb-win32-x64-msvc": "0.4.16"
-      }
     },
     "node_modules/@75lb/deep-merge": {
@@ -2221,81 +2214,6 @@
         "@jridgewell/sourcemap-codec": "^1.4.14"
       }
     },
-    "node_modules/@lancedb/lancedb-darwin-arm64": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-arm64/-/lancedb-darwin-arm64-0.4.16.tgz",
-      "integrity": "sha512-CV65ouIDQbBSNtdHbQSr2fqXflOuqud1cfweUS+EiK7eEOEYl7nO2oiFYO49Jy76MEwZxiP99hW825aCqIQJqg==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-darwin-x64": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-x64/-/lancedb-darwin-x64-0.4.16.tgz",
-      "integrity": "sha512-1CwIYCNdbFmV7fvqM+qUxbYgwxx0slcCV48PC/I19Ejitgtzw/NJiWDCvONhaLqG85lWNZm1xYceRpVv7b8seQ==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-linux-arm64-gnu": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-arm64-gnu/-/lancedb-linux-arm64-gnu-0.4.16.tgz",
-      "integrity": "sha512-CzLEbzoHKS6jV0k52YnvsiVNx0VzLp1Vz/zmbHI6HmB/XbS67qDO93Jk71MDmXq3JDw0FKFCw9ghkg+6YWq7ZA==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-linux-x64-gnu": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-x64-gnu/-/lancedb-linux-x64-gnu-0.4.16.tgz",
-      "integrity": "sha512-nKChybybi8uA0AFRHBFm7Fz3VXcRm8riv5Gs7xQsrsCtYxxf4DT/0BfUvQ0xKbwNJa+fawHRxi9BOQewdj49fg==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-win32-x64-msvc": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-win32-x64-msvc/-/lancedb-win32-x64-msvc-0.4.16.tgz",
-      "integrity": "sha512-KMeBPMpv2g+ZMVsHVibed7BydrBlxje1qS0bZTDrLw9BtZOk6XH2lh1mCDnCJI6sbAscUKNA6fDCdquhQPHL7w==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "win32"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
     "node_modules/@napi-rs/cli": {
       "version": "2.18.0",
       "resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.0.tgz",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "main": "./dist/index.js",
   "types": "./dist/index.d.ts",
   "napi": {
@@ -62,20 +62,14 @@
     "build-release": "npm run build:release && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts",
     "chkformat": "prettier . --check",
     "docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
-    "lint": "eslint lancedb && eslint __test__",
+    "lint": "eslint lancedb __test__",
     "lint-fix": "eslint lancedb __test__ --fix",
     "prepublishOnly": "napi prepublish -t npm",
     "test": "npm run build && jest --verbose",
     "integration": "S3_TEST=1 npm run test",
     "universal": "napi universal",
     "version": "napi version"
   },
-  "optionalDependencies": {
-    "@lancedb/lancedb-darwin-arm64": "0.4.17",
-    "@lancedb/lancedb-darwin-x64": "0.4.17",
-    "@lancedb/lancedb-linux-arm64-gnu": "0.4.17",
-    "@lancedb/lancedb-linux-x64-gnu": "0.4.17",
-    "@lancedb/lancedb-win32-x64-msvc": "0.4.17"
-  },
   "dependencies": {
     "openai": "^4.29.2",
     "apache-arrow": "^15.0.0"
@@ -176,6 +176,7 @@ impl Connection {
         &self,
         name: String,
         storage_options: Option<HashMap<String, String>>,
+        index_cache_size: Option<u32>,
     ) -> napi::Result<Table> {
         let mut builder = self.get_inner()?.open_table(&name);
         if let Some(storage_options) = storage_options {
@@ -183,6 +184,9 @@ impl Connection {
                 builder = builder.storage_option(key, value);
             }
         }
+        if let Some(index_cache_size) = index_cache_size {
+            builder = builder.index_cache_size(index_cache_size);
+        }
         let tbl = builder
             .execute()
             .await
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.6.8
+current_version = 0.6.12
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True
@@ -14,7 +14,7 @@ name = "_lancedb"
 crate-type = ["cdylib"]

 [dependencies]
-arrow = { version = "50.0.0", features = ["pyarrow"] }
+arrow = { version = "51.0.0", features = ["pyarrow"] }
 lancedb = { path = "../rust/lancedb" }
 env_logger = "0.10"
 pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
@@ -1,9 +1,9 @@
 [project]
 name = "lancedb"
-version = "0.6.8"
+version = "0.6.12"
 dependencies = [
     "deprecation",
-    "pylance==0.10.10",
+    "pylance==0.10.12",
     "ratelimiter~=1.0",
     "requests>=2.31.0",
     "retry>=0.9.2",
@@ -65,7 +65,6 @@ docs = [
     "mkdocs-jupyter",
     "mkdocs-material",
     "mkdocstrings[python]",
-    "mkdocs-ultralytics-plugin==0.0.44",
 ]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -83,7 +83,7 @@ def connect(

     >>> db = lancedb.connect("s3://my-bucket/lancedb")

-    Connect to LancdDB cloud:
+    Connect to LanceDB cloud:

     >>> db = lancedb.connect("db://my_database", api_key="ldb_...")

@@ -107,6 +107,9 @@ def connect(
             request_thread_pool=request_thread_pool,
             **kwargs,
         )

+    if kwargs:
+        raise ValueError(f"Unknown keyword arguments: {kwargs}")
     return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
@@ -224,13 +224,23 @@ class DBConnection(EnforceOverrides):
     def __getitem__(self, name: str) -> LanceTable:
         return self.open_table(name)

-    def open_table(self, name: str) -> Table:
+    def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
         """Open a Lance Table in the database.

         Parameters
         ----------
         name: str
             The name of the table.
+        index_cache_size: int, default 256
+            Set the size of the index cache, specified as a number of entries
+
+            The exact meaning of an "entry" will depend on the type of index:
+            * IVF - there is one entry for each IVF partition
+            * BTREE - there is one entry for the entire index
+
+            This cache applies to the entire opened table, across all indices.
+            Setting this value higher will increase performance on larger datasets
+            at the expense of more RAM

         Returns
         -------
@@ -248,6 +258,18 @@ class DBConnection(EnforceOverrides):
         """
         raise NotImplementedError

+    def rename_table(self, cur_name: str, new_name: str):
+        """Rename a table in the database.
+
+        Parameters
+        ----------
+        cur_name: str
+            The current name of the table.
+        new_name: str
+            The new name of the table.
+        """
+        raise NotImplementedError
+
     def drop_database(self):
         """
         Drop database
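The new `index_cache_size` parameter is keyword-only on every connection type. A minimal sketch of how it might be used against the sync API; the table name and path are illustrative, and the behavior matches the tests later in this diff:

```python
import lancedb

# "./.lancedb" is an illustrative local path.
db = lancedb.connect("./.lancedb")
db.create_table("my_table", data=[{"id": 0}])

# Open with the default index cache (256 entries).
tbl = db.open_table("my_table")

# Open with a larger cache, trading RAM for fewer index reloads.
tbl = db.open_table("my_table", index_cache_size=512)
print(tbl.count_rows())  # 1
```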
@@ -407,7 +429,9 @@ class LanceDBConnection(DBConnection):
         return tbl

     @override
-    def open_table(self, name: str) -> LanceTable:
+    def open_table(
+        self, name: str, *, index_cache_size: Optional[int] = None
+    ) -> LanceTable:
         """Open a table in the database.

         Parameters
@@ -419,7 +443,7 @@ class LanceDBConnection(DBConnection):
         -------
         A LanceTable object representing the table.
         """
-        return LanceTable.open(self, name)
+        return LanceTable.open(self, name, index_cache_size=index_cache_size)

     @override
     def drop_table(self, name: str, ignore_missing: bool = False):
@@ -751,7 +775,10 @@ class AsyncConnection(object):
         return AsyncTable(new_table)

     async def open_table(
-        self, name: str, storage_options: Optional[Dict[str, str]] = None
+        self,
+        name: str,
+        storage_options: Optional[Dict[str, str]] = None,
+        index_cache_size: Optional[int] = None,
     ) -> Table:
         """Open a Lance Table in the database.

@@ -764,12 +791,22 @@ class AsyncConnection(object):
             connection will be inherited by the table, but can be overridden here.
             See available options at
             https://lancedb.github.io/lancedb/guides/storage/
+        index_cache_size: int, default 256
+            Set the size of the index cache, specified as a number of entries
+
+            The exact meaning of an "entry" will depend on the type of index:
+            * IVF - there is one entry for each IVF partition
+            * BTREE - there is one entry for the entire index
+
+            This cache applies to the entire opened table, across all indices.
+            Setting this value higher will increase performance on larger datasets
+            at the expense of more RAM

         Returns
         -------
         A LanceTable object representing the table.
         """
-        table = await self._inner.open_table(name, storage_options)
+        table = await self._inner.open_table(name, storage_options, index_cache_size)
         return AsyncTable(table)

     async def drop_table(self, name: str):
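The async connection accepts the same parameter. A short sketch, assuming `count_rows` on the async table behaves like its sync counterpart; the path is illustrative:

```python
import asyncio

import lancedb


async def main() -> None:
    db = await lancedb.connect_async("./.lancedb")
    # index_cache_size is passed straight through to the underlying
    # Rust open-table builder, per the pyo3 change later in this diff.
    tbl = await db.open_table("my_table", index_cache_size=512)
    print(await tbl.count_rows())


asyncio.run(main())
```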
@@ -255,7 +255,13 @@ def retry_with_exponential_backoff(
                 )

                 delay *= exponential_base * (1 + jitter * random.random())
-                logging.info("Retrying in %s seconds...", delay)
+                logging.warning(
+                    "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
+                    e,
+                    delay,
+                    num_retries,
+                    max_retries,
+                )
                 time.sleep(delay)

     return wrapper
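For intuition, the delay line above compounds multiplicatively on every retry. A standalone sketch of that schedule; the initial delay, base, and jitter values here are illustrative, not the library's defaults:

```python
import random


def backoff_delays(initial=1.0, exponential_base=2.0, jitter=1.0, retries=5):
    """Reproduce the schedule: delay *= base * (1 + jitter * random())."""
    delay = initial
    delays = []
    for _ in range(retries):
        delay *= exponential_base * (1 + jitter * random.random())
        delays.append(delay)
    return delays


# With base 2 and jitter 1, each retry waits roughly 2x-4x longer
# than the previous one.
print(backoff_delays())
```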
@@ -30,6 +30,7 @@ from typing import (

 import deprecation
 import numpy as np
 import pyarrow as pa
+import pyarrow.fs as pa_fs
 import pydantic

 from . import __version__
@@ -37,7 +38,7 @@ from .arrow import AsyncRecordBatchReader
 from .common import VEC
 from .rerankers.base import Reranker
 from .rerankers.linear_combination import LinearCombinationReranker
-from .util import safe_import_pandas
+from .util import fs_from_uri, safe_import_pandas

 if TYPE_CHECKING:
     import PIL
@@ -665,6 +666,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):

         # get the index path
         index_path = self._table._get_fts_index_path()

+        # Check that we are on local filesystem
+        fs, _path = fs_from_uri(index_path)
+        if not isinstance(fs, pa_fs.LocalFileSystem):
+            raise NotImplementedError(
+                "Full-text search is only supported on the local filesystem"
+            )
+
         # check if the index exist
         if not Path(index_path).exists():
             raise FileNotFoundError(
@@ -94,7 +94,7 @@ class RemoteDBConnection(DBConnection):
             yield item

     @override
-    def open_table(self, name: str) -> Table:
+    def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
         """Open a Lance Table in the database.

         Parameters
@@ -110,6 +110,12 @@ class RemoteDBConnection(DBConnection):

         self._client.mount_retry_adapter_for_table(name)

+        if index_cache_size is not None:
+            logging.info(
+                "index_cache_size is ignored in LanceDb Cloud"
+                " (there is no local cache to configure)"
+            )
+
         # check if table exists
         if self._table_cache.get(name) is None:
             self._client.post(f"/v1/table/{name}/describe/")
@@ -281,6 +287,24 @@ class RemoteDBConnection(DBConnection):
             )
         self._table_cache.pop(name)

+    @override
+    def rename_table(self, cur_name: str, new_name: str):
+        """Rename a table in the database.
+
+        Parameters
+        ----------
+        cur_name: str
+            The current name of the table.
+        new_name: str
+            The new name of the table.
+        """
+        self._client.post(
+            f"/v1/table/{cur_name}/rename/",
+            json={"new_table_name": new_name},
+        )
+        self._table_cache.pop(cur_name)
+        self._table_cache[new_name] = True
+
     async def close(self):
         """Close the connection to the database."""
         self._client.close()
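Given the new `/rename/` endpoint, renaming becomes a one-liner on a remote connection; note that in this diff the local `DBConnection` base class still raises `NotImplementedError`. A sketch with illustrative names:

```python
import lancedb

# rename_table is implemented for LanceDB Cloud connections only here.
db = lancedb.connect("db://my_database", api_key="ldb_...")  # illustrative
db.rename_table("old_table", "new_table")
```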
@@ -72,7 +72,7 @@ class RemoteTable(Table):
         return resp

     def index_stats(self, index_uuid: str):
-        """List all the indices on the table"""
+        """List all the stats of a specified index"""
         resp = self._conn._client.post(
             f"/v1/table/{self._name}/index/{index_uuid}/stats/"
         )
@@ -1,4 +1,5 @@
 import os
+import semver
 from functools import cached_property
 from typing import Union

@@ -42,6 +43,14 @@ class CohereReranker(Reranker):
     @cached_property
     def _client(self):
         cohere = attempt_import_or_raise("cohere")
+        # ensure version is at least 5.0.0
+        if (
+            hasattr(cohere, "__version__")
+            and semver.compare(cohere.__version__, "5.0.0") < 0
+        ):
+            raise ValueError(
+                f"cohere version must be at least 5.0.0, found {cohere.__version__}"
+            )
         if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
             raise ValueError(
                 "COHERE_API_KEY not set. Either set it in your environment or \
@@ -51,11 +60,14 @@ class CohereReranker(Reranker):

     def _rerank(self, result_set: pa.Table, query: str):
         docs = result_set[self.column].to_pylist()
-        results = self._client.rerank(
+        response = self._client.rerank(
             query=query,
             documents=docs,
             top_n=self.top_n,
             model=self.model_name,
         )
+        results = (
+            response.results
+        )  # returns a list with (text, idx, relevance) attributes, sorted descending by score
         indices, scores = list(
             zip(*[(result.index, result.relevance_score) for result in results])
         )
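The version gate relies on the `semver` package's `compare`, which returns -1, 0, or 1. A minimal sketch of the same pattern for any dependency; the module and floor version are illustrative:

```python
import semver


def require_at_least(module, floor: str) -> None:
    """Raise if module.__version__ is older than the floor version."""
    version = getattr(module, "__version__", None)
    if version is not None and semver.compare(version, floor) < 0:
        raise ValueError(
            f"{module.__name__} must be at least {floor}, found {version}"
        )


# For example: semver.compare("4.9.0", "5.0.0") == -1, so a module
# reporting __version__ "4.9.0" would fail a "5.0.0" floor.
```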
@@ -806,6 +806,7 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
     """Reference to the latest version of a LanceDataset."""

     uri: str
+    index_cache_size: Optional[int] = None
     read_consistency_interval: Optional[timedelta] = None
     last_consistency_check: Optional[float] = None
     _dataset: Optional[LanceDataset] = None
@@ -813,7 +814,9 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
     @property
     def dataset(self) -> LanceDataset:
         if not self._dataset:
-            self._dataset = lance.dataset(self.uri)
+            self._dataset = lance.dataset(
+                self.uri, index_cache_size=self.index_cache_size
+            )
             self.last_consistency_check = time.monotonic()
         elif self.read_consistency_interval is not None:
             now = time.monotonic()
@@ -842,12 +845,15 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
 class _LanceTimeTravelRef(_LanceDatasetRef):
     uri: str
     version: int
+    index_cache_size: Optional[int] = None
     _dataset: Optional[LanceDataset] = None

     @property
     def dataset(self) -> LanceDataset:
         if not self._dataset:
-            self._dataset = lance.dataset(self.uri, version=self.version)
+            self._dataset = lance.dataset(
+                self.uri, version=self.version, index_cache_size=self.index_cache_size
+            )
         return self._dataset

     @dataset.setter
@@ -884,6 +890,8 @@ class LanceTable(Table):
         connection: "LanceDBConnection",
         name: str,
         version: Optional[int] = None,
+        *,
+        index_cache_size: Optional[int] = None,
     ):
         self._conn = connection
         self.name = name
@@ -892,11 +900,13 @@ class LanceTable(Table):
             self._ref = _LanceTimeTravelRef(
                 uri=self._dataset_uri,
                 version=version,
+                index_cache_size=index_cache_size,
             )
         else:
             self._ref = _LanceLatestDatasetRef(
                 uri=self._dataset_uri,
                 read_consistency_interval=connection.read_consistency_interval,
+                index_cache_size=index_cache_size,
             )

     @classmethod
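The latest-version ref above opens the dataset lazily and only re-checks freshness once `read_consistency_interval` has elapsed. A stripped-down sketch of that caching pattern, independent of lance; all names here are illustrative:

```python
import time
from typing import Callable, Optional


class CachedHandle:
    """Cache an expensive handle, refreshing it at most once per interval."""

    def __init__(self, open_fn: Callable[[], object], interval: Optional[float]):
        self._open_fn = open_fn
        self._interval = interval
        self._handle: Optional[object] = None
        self._last_check: Optional[float] = None

    @property
    def handle(self) -> object:
        now = time.monotonic()
        if self._handle is None or (
            self._interval is not None and now - self._last_check >= self._interval
        ):
            # e.g. lance.dataset(uri, index_cache_size=...) in the real code
            self._handle = self._open_fn()
            self._last_check = now
        return self._handle
```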
@@ -1199,6 +1209,11 @@ class LanceTable(Table):
                 raise ValueError("Index already exists. Use replace=True to overwrite.")
             fs.delete_dir(path)

+        if not isinstance(fs, pa_fs.LocalFileSystem):
+            raise NotImplementedError(
+                "Full-text search is only supported on the local filesystem"
+            )
+
         index = create_index(
             self._get_fts_index_path(),
             field_names,
@@ -368,6 +368,15 @@ async def test_create_exist_ok_async(tmp_path):
     #     await db.create_table("test", schema=bad_schema, exist_ok=True)


+def test_open_table_sync(tmp_path):
+    db = lancedb.connect(tmp_path)
+    db.create_table("test", data=[{"id": 0}])
+    assert db.open_table("test").count_rows() == 1
+    assert db.open_table("test", index_cache_size=0).count_rows() == 1
+    with pytest.raises(FileNotFoundError, match="does not exist"):
+        db.open_table("does_not_exist")
+
+
 @pytest.mark.asyncio
 async def test_open_table(tmp_path):
     db = await lancedb.connect_async(tmp_path)
@@ -397,6 +406,10 @@ async def test_open_table(tmp_path):
         }
     )

+    # No way to verify this yet, but at least make sure we
+    # can pass the parameter
+    await db.open_table("test", index_cache_size=0)
+
     with pytest.raises(ValueError, match="was not found"):
         await db.open_table("does_not_exist")
@@ -213,7 +213,7 @@ def test_syntax(table):
     # https://github.com/lancedb/lancedb/issues/769
     table.create_fts_index("text")
     with pytest.raises(ValueError, match="Syntax Error"):
-        table.search("they could have been dogs OR cats").limit(10).to_list()
+        table.search("they could have been dogs OR").limit(10).to_list()

     # these should work
@@ -134,17 +134,21 @@ impl Connection {
         })
     }

-    #[pyo3(signature = (name, storage_options = None))]
+    #[pyo3(signature = (name, storage_options = None, index_cache_size = None))]
     pub fn open_table(
         self_: PyRef<'_, Self>,
         name: String,
         storage_options: Option<HashMap<String, String>>,
+        index_cache_size: Option<u32>,
     ) -> PyResult<&PyAny> {
         let inner = self_.get_inner()?.clone();
         let mut builder = inner.open_table(name);
         if let Some(storage_options) = storage_options {
             builder = builder.storage_options(storage_options);
         }
+        if let Some(index_cache_size) = index_cache_size {
+            builder = builder.index_cache_size(index_cache_size);
+        }
         future_into_py(self_.py(), async move {
             let table = builder.execute().await.infer_error()?;
             Ok(Table::new(table))
@@ -35,21 +35,16 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
         match &self {
             Ok(_) => Ok(self.unwrap()),
             Err(err) => match err {
-                LanceError::InvalidInput { .. } => self.value_error(),
-                LanceError::InvalidTableName { .. } => self.value_error(),
-                LanceError::TableNotFound { .. } => self.value_error(),
-                LanceError::Schema { .. } => self.value_error(),
+                LanceError::InvalidInput { .. }
+                | LanceError::InvalidTableName { .. }
+                | LanceError::TableNotFound { .. }
+                | LanceError::Schema { .. } => self.value_error(),
                 LanceError::CreateDir { .. } => self.os_error(),
                 LanceError::TableAlreadyExists { .. } => self.runtime_error(),
                 LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
-                LanceError::Lance { .. } => self.runtime_error(),
-                LanceError::Runtime { .. } => self.runtime_error(),
-                LanceError::Http { .. } => self.runtime_error(),
-                LanceError::Arrow { .. } => self.runtime_error(),
                 LanceError::NotSupported { .. } => {
                     Err(PyNotImplementedError::new_err(err.to_string()))
                 }
-                LanceError::Other { .. } => self.runtime_error(),
+                _ => self.runtime_error(),
             },
         }
     }
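On the Python side this mapping means callers can catch ordinary built-in exceptions. A sketch against the async API; per the mapping above, `TableNotFound` surfaces as `ValueError`, which is exactly what the tests earlier in this diff assert:

```python
import asyncio

import lancedb


async def main() -> None:
    db = await lancedb.connect_async("./.lancedb")  # illustrative path
    try:
        await db.open_table("does_not_exist")
    except ValueError as e:
        # errors.rs maps LanceError::TableNotFound to a Python ValueError.
        print(f"missing table: {e}")


asyncio.run(main())
```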
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.4.17"
+version = "0.4.19"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true
@@ -59,7 +59,7 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
     for handle in storage_options_js {
         let obj = handle.downcast::<JsArray, _>(&mut cx).unwrap();
         let key = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
-        let value = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
+        let value = obj.get::<JsString, _, _>(&mut cx, 1)?.value(&mut cx);

         storage_options.push((key, value));
     }
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.4.17"
+version = "0.4.19"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -40,6 +40,8 @@ serde = { version = "^1" }
 serde_json = { version = "1" }
 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
+polars-arrow = { version = ">=0.37", optional = true }
+polars = { version = ">=0.37", optional = true}

 [dev-dependencies]
 tempfile = "3.5.0"
@@ -52,7 +54,8 @@ aws-sdk-kms = { version = "1.0" }
 aws-config = { version = "1.0" }

 [features]
-default = ["remote"]
+default = []
 remote = ["dep:reqwest"]
 fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
+polars = ["dep:polars-arrow", "dep:polars"]
@@ -14,10 +14,12 @@

 use std::{pin::Pin, sync::Arc};

 pub use arrow_array;
 pub use arrow_schema;
 use futures::{Stream, StreamExt};

+#[cfg(feature = "polars")]
+use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
+
 use crate::error::Result;

 /// An iterator of batches that also has a schema
@@ -114,8 +116,183 @@ pub trait IntoArrow {
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
 }

+pub type BoxedRecordBatchReader = Box<dyn arrow_array::RecordBatchReader + Send>;
+
 impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
         Ok(Box::new(self))
     }
 }
+
+impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
+    pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
+        Self { schema, stream }
+    }
+}
+
+#[cfg(feature = "polars")]
+/// An iterator of record batches formed from a Polars DataFrame.
+pub struct PolarsDataFrameRecordBatchReader {
+    chunks: std::vec::IntoIter<ArrowChunk>,
+    arrow_schema: Arc<arrow_schema::Schema>,
+}
+
+#[cfg(feature = "polars")]
+impl PolarsDataFrameRecordBatchReader {
+    /// Creates a new `PolarsDataFrameRecordBatchReader` from a given Polars DataFrame.
+    /// If the input dataframe does not have aligned chunks, this function undergoes
+    /// the costly operation of reallocating each series as a single contiguous chunk.
+    pub fn new(mut df: DataFrame) -> Result<Self> {
+        df.align_chunks();
+        let arrow_schema =
+            polars_arrow_convertors::convert_polars_df_schema_to_arrow_rb_schema(df.schema())?;
+        Ok(Self {
+            chunks: df
+                .iter_chunks(polars_arrow_convertors::POLARS_ARROW_FLAVOR)
+                .collect::<Vec<ArrowChunk>>()
+                .into_iter(),
+            arrow_schema,
+        })
+    }
+}
+
+#[cfg(feature = "polars")]
+impl Iterator for PolarsDataFrameRecordBatchReader {
+    type Item = std::result::Result<arrow_array::RecordBatch, arrow_schema::ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.chunks.next().map(|chunk| {
+            let columns: std::result::Result<Vec<arrow_array::ArrayRef>, arrow_schema::ArrowError> =
+                chunk
+                    .into_arrays()
+                    .into_iter()
+                    .zip(self.arrow_schema.fields.iter())
+                    .map(|(polars_array, arrow_field)| {
+                        polars_arrow_convertors::convert_polars_arrow_array_to_arrow_rs_array(
+                            polars_array,
+                            arrow_field.data_type().clone(),
+                        )
+                    })
+                    .collect();
+            arrow_array::RecordBatch::try_new(self.arrow_schema.clone(), columns?)
+        })
+    }
+}
+
+#[cfg(feature = "polars")]
+impl arrow_array::RecordBatchReader for PolarsDataFrameRecordBatchReader {
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        self.arrow_schema.clone()
+    }
+}
+
+/// A trait for converting the result of a LanceDB query into a Polars DataFrame with aligned
+/// chunks. The resulting Polars DataFrame will have aligned chunks, but the series's
+/// chunks are not guaranteed to be contiguous.
+#[cfg(feature = "polars")]
+pub trait IntoPolars {
+    fn into_polars(self) -> impl std::future::Future<Output = Result<DataFrame>> + Send;
+}
+
+#[cfg(feature = "polars")]
+impl IntoPolars for SendableRecordBatchStream {
+    async fn into_polars(mut self) -> Result<DataFrame> {
+        let polars_schema =
+            polars_arrow_convertors::convert_arrow_rb_schema_to_polars_df_schema(&self.schema())?;
+        let mut acc_df: DataFrame = DataFrame::from(&polars_schema);
+        while let Some(record_batch) = self.next().await {
+            let new_df = polars_arrow_convertors::convert_arrow_rb_to_polars_df(
+                &record_batch?,
+                &polars_schema,
+            )?;
+            acc_df = acc_df.vstack(&new_df)?;
+        }
+        Ok(acc_df)
+    }
+}
+
+#[cfg(all(test, feature = "polars"))]
+mod tests {
+    use super::SendableRecordBatchStream;
+    use crate::arrow::{
+        IntoArrow, IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream,
+    };
+    use polars::prelude::{DataFrame, NamedFrom, Series};
+
+    fn get_record_batch_reader_from_polars() -> Box<dyn arrow_array::RecordBatchReader + Send> {
+        let mut string_series = Series::new("string", &["ab"]);
+        let mut int_series = Series::new("int", &[1]);
+        let mut float_series = Series::new("float", &[1.0]);
+        let df1 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
+
+        string_series = Series::new("string", &["bc"]);
+        int_series = Series::new("int", &[2]);
+        float_series = Series::new("float", &[2.0]);
+        let df2 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
+
+        PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap())
+            .unwrap()
+            .into_arrow()
+            .unwrap()
+    }
+
+    #[test]
+    fn from_polars_to_arrow() {
+        let record_batch_reader = get_record_batch_reader_from_polars();
+        let schema = record_batch_reader.schema();
+
+        // Test schema conversion
+        assert_eq!(
+            schema
+                .fields
+                .iter()
+                .map(|field| (field.name().as_str(), field.data_type()))
+                .collect::<Vec<_>>(),
+            vec![
+                ("string", &arrow_schema::DataType::LargeUtf8),
+                ("int", &arrow_schema::DataType::Int32),
+                ("float", &arrow_schema::DataType::Float64)
+            ]
+        );
+        let record_batches: Vec<arrow_array::RecordBatch> =
+            record_batch_reader.map(|result| result.unwrap()).collect();
+        assert_eq!(record_batches.len(), 2);
+        assert_eq!(schema, record_batches[0].schema());
+        assert_eq!(record_batches[0].schema(), record_batches[1].schema());
+
+        // Test number of rows
+        assert_eq!(record_batches[0].num_rows(), 1);
+        assert_eq!(record_batches[1].num_rows(), 1);
+    }
+
+    #[tokio::test]
+    async fn from_arrow_to_polars() {
+        let record_batch_reader = get_record_batch_reader_from_polars();
+        let schema = record_batch_reader.schema();
+        let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
+            schema: schema.clone(),
+            stream: futures::stream::iter(
+                record_batch_reader
+                    .into_iter()
+                    .map(|r| r.map_err(Into::into)),
+            ),
+        });
+        let df = stream.into_polars().await.unwrap();
+
+        // Test number of chunks and rows
+        assert_eq!(df.n_chunks(), 2);
+        assert_eq!(df.height(), 2);
+
+        // Test schema conversion
+        assert_eq!(
+            df.schema()
+                .into_iter()
+                .map(|(name, datatype)| (name.to_string(), datatype))
+                .collect::<Vec<_>>(),
+            vec![
+                ("string".to_string(), polars::prelude::DataType::String),
+                ("int".to_owned(), polars::prelude::DataType::Int32),
+                ("float".to_owned(), polars::prelude::DataType::Float64)
+            ]
+        );
+    }
+}
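A similar round trip is already possible from Python, since LanceDB query results come back as Arrow and the `polars` Python package can wrap Arrow data zero-copy where the types allow. A hedged sketch; the table and query values are illustrative, and `to_arrow` is the Python query builder's Arrow export:

```python
import lancedb
import polars as pl

db = lancedb.connect("./.lancedb")
tbl = db.open_table("my_table")  # assumed to have a 2-d "vector" column

# Run a vector query, materialize as an Arrow table, then wrap it in Polars.
arrow_table = tbl.search([1.1, 0.9]).limit(10).to_arrow()
df = pl.from_arrow(arrow_table)
print(df.schema)
```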
@@ -27,12 +27,18 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem};
 use snafu::prelude::*;

 use crate::arrow::IntoArrow;
+use crate::embeddings::{
+    EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
+};
 use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
 use crate::io::object_store::MirroringObjectStoreWrapper;
-use crate::table::{NativeTable, WriteOptions};
+use crate::table::{NativeTable, TableDefinition, WriteOptions};
 use crate::utils::validate_table_name;
 use crate::Table;

+#[cfg(feature = "remote")]
+use log::warn;
+
 pub const LANCE_FILE_EXTENSION: &str = "lance";

 pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
@@ -130,9 +136,10 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
     parent: Arc<dyn ConnectionInternal>,
     pub(crate) name: String,
     pub(crate) data: Option<T>,
-    pub(crate) schema: Option<SchemaRef>,
     pub(crate) mode: CreateTableMode,
     pub(crate) write_options: WriteOptions,
+    pub(crate) table_definition: Option<TableDefinition>,
+    pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
 }

 // Builder methods that only apply when we have initial data
@@ -142,9 +149,10 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
             parent,
             name,
             data: Some(data),
-            schema: None,
             mode: CreateTableMode::default(),
             write_options: WriteOptions::default(),
+            table_definition: None,
+            embeddings: Vec::new(),
         }
     }

@@ -172,24 +180,43 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
             parent: self.parent,
             name: self.name,
             data: None,
-            schema: self.schema,
+            table_definition: self.table_definition,
             mode: self.mode,
             write_options: self.write_options,
+            embeddings: self.embeddings,
         };
         Ok((data, builder))
     }

+    pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result<Self> {
+        // Early verification of the embedding name
+        let embedding_func = self
+            .parent
+            .embedding_registry()
+            .get(&definition.embedding_name)
+            .ok_or_else(|| Error::EmbeddingFunctionNotFound {
+                name: definition.embedding_name.to_string(),
+                reason: "No embedding function found in the connection's embedding_registry"
+                    .to_string(),
+            })?;
+
+        self.embeddings.push((definition, embedding_func));
+        Ok(self)
+    }
 }

 // Builder methods that only apply when we do not have initial data
 impl CreateTableBuilder<false, NoData> {
     fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
+        let table_definition = TableDefinition::new_from_schema(schema);
         Self {
             parent,
             name,
             data: None,
-            schema: Some(schema),
+            table_definition: Some(table_definition),
            mode: CreateTableMode::default(),
             write_options: WriteOptions::default(),
+            embeddings: Vec::new(),
         }
     }

@@ -347,6 +374,7 @@ impl OpenTableBuilder {
 pub(crate) trait ConnectionInternal:
     Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
 {
+    fn embedding_registry(&self) -> &dyn EmbeddingRegistry;
     async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
     async fn do_create_table(
         &self,
@@ -363,7 +391,7 @@ pub(crate) trait ConnectionInternal:
     ) -> Result<Table> {
         let batches = Box::new(RecordBatchIterator::new(
             vec![],
-            options.schema.as_ref().unwrap().clone(),
+            options.table_definition.clone().unwrap().schema.clone(),
         ));
         self.do_create_table(options, batches).await
     }
@@ -450,6 +478,13 @@ impl Connection {
     pub async fn drop_db(&self) -> Result<()> {
         self.internal.drop_db().await
     }
+
+    /// Get the in-memory embedding registry.
+    /// It's important to note that the embedding registry is not persisted across connections.
+    /// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered
+    pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
+        self.internal.embedding_registry()
+    }
 }

 #[derive(Debug)]
@@ -483,6 +518,7 @@ pub struct ConnectBuilder {
     /// consistency only applies to read operations. Write operations are
     /// always consistent.
     read_consistency_interval: Option<std::time::Duration>,
+    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
 }

 impl ConnectBuilder {
@@ -495,6 +531,7 @@ impl ConnectBuilder {
             host_override: None,
             read_consistency_interval: None,
             storage_options: HashMap::new(),
+            embedding_registry: None,
         }
     }

@@ -513,6 +550,12 @@ impl ConnectBuilder {
         self
     }

+    /// Provide a custom [`EmbeddingRegistry`] to use for this connection.
+    pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
+        self.embedding_registry = Some(registry);
+        self
+    }
+
     /// [`AwsCredential`] to use when connecting to S3.
     #[deprecated(note = "Pass through storage_options instead")]
     pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
@@ -579,6 +622,7 @@ impl ConnectBuilder {
         let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
             message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
         })?;
+        warn!("The rust implementation of the remote client is not yet ready for use.");
         let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
             &self.uri,
             &api_key,
@@ -638,6 +682,7 @@ struct Database {

     // Storage options to be inherited by tables created from this connection
     storage_options: HashMap<String, String>,
+    embedding_registry: Arc<dyn EmbeddingRegistry>,
 }

 impl std::fmt::Display for Database {
@@ -671,7 +716,12 @@ impl Database {
         // TODO: pass params regardless of OS
         match parse_res {
             Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
-                Self::open_path(uri, options.read_consistency_interval).await
+                Self::open_path(
+                    uri,
+                    options.read_consistency_interval,
+                    options.embedding_registry.clone(),
+                )
+                .await
             }
             Ok(mut url) => {
                 // iter thru the query params and extract the commit store param
@@ -741,6 +791,10 @@ impl Database {
                     None => None,
                 };

+                let embedding_registry = options
+                    .embedding_registry
+                    .clone()
+                    .unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
                 Ok(Self {
                     uri: table_base_uri,
                     query_string,
@@ -749,20 +803,33 @@ impl Database {
                     store_wrapper: write_store_wrapper,
                     read_consistency_interval: options.read_consistency_interval,
                     storage_options,
+                    embedding_registry,
                 })
             }
-            Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
+            Err(_) => {
+                Self::open_path(
+                    uri,
+                    options.read_consistency_interval,
+                    options.embedding_registry.clone(),
+                )
+                .await
+            }
         }
     }

     async fn open_path(
         path: &str,
         read_consistency_interval: Option<std::time::Duration>,
+        embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
     ) -> Result<Self> {
         let (object_store, base_path) = ObjectStore::from_uri(path).await?;
         if object_store.is_local() {
             Self::try_create_dir(path).context(CreateDirSnafu { path })?;
         }

+        let embedding_registry =
+            embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
+
         Ok(Self {
             uri: path.to_string(),
             query_string: None,
@@ -771,6 +838,7 @@ impl Database {
             store_wrapper: None,
             read_consistency_interval,
             storage_options: HashMap::new(),
+            embedding_registry,
         })
     }

@@ -811,6 +879,9 @@ impl Database {

 #[async_trait::async_trait]
 impl ConnectionInternal for Database {
+    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
+        self.embedding_registry.as_ref()
+    }
     async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
         let mut f = self
             .object_store
@@ -847,7 +918,7 @@ impl ConnectionInternal for Database {
         data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<Table> {
         let table_uri = self.table_uri(&options.name)?;
-
+        let embedding_registry = self.embedding_registry.clone();
         // Inherit storage options from the connection
         let storage_options = options
             .write_options
@@ -862,6 +933,11 @@ impl ConnectionInternal for Database {
                 storage_options.insert(key.clone(), value.clone());
             }
         }
+        let data = if options.embeddings.is_empty() {
+            data
+        } else {
+            Box::new(WithEmbeddings::new(data, options.embeddings))
+        };

         let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
         if matches!(&options.mode, CreateTableMode::Overwrite) {
@@ -878,7 +954,10 @@ impl ConnectionInternal for Database {
             )
             .await
         {
-            Ok(table) => Ok(Table::new(Arc::new(table))),
+            Ok(table) => Ok(Table::new_with_embedding_registry(
+                Arc::new(table),
+                embedding_registry,
+            )),
             Err(Error::TableAlreadyExists { name }) => match options.mode {
                 CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
                 CreateTableMode::ExistOk(callback) => {
@@ -909,12 +988,23 @@ impl ConnectionInternal for Database {
         }
     }

+        // Some ReadParams are exposed in the OpenTableBuilder, but we also
+        // let the user provide their own ReadParams.
+        //
+        // If we have a user provided ReadParams use that
+        // If we don't then start with the default ReadParams and customize it with
+        // the options from the OpenTableBuilder
+        let read_params = options.lance_read_params.unwrap_or_else(|| ReadParams {
+            index_cache_size: options.index_cache_size as usize,
+            ..Default::default()
+        });
+
         let native_table = Arc::new(
             NativeTable::open_with_params(
                 &table_uri,
                 &options.name,
                 self.store_wrapper.clone(),
-                options.lance_read_params,
+                Some(read_params),
                 self.read_consistency_interval,
             )
             .await?,
@@ -1032,7 +1122,6 @@ mod tests {
     }

     #[tokio::test]
-    #[ignore = "this can't pass due to https://github.com/lancedb/lancedb/issues/1019, enable it after the bug fixed"]
     async fn test_open_table() {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
rust/lancedb/src/embeddings.rs (new file, 307 lines)
@@ -0,0 +1,307 @@
+// Copyright 2024 LanceDB Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use lance::arrow::RecordBatchExt;
+use std::{
+    borrow::Cow,
+    collections::{HashMap, HashSet},
+    sync::{Arc, RwLock},
+};
+
+use arrow_array::{Array, RecordBatch, RecordBatchReader};
+use arrow_schema::{DataType, Field, SchemaBuilder};
+// use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    error::Result,
+    table::{ColumnDefinition, ColumnKind, TableDefinition},
+    Error,
+};
+
+/// Trait for embedding functions
+///
+/// An embedding function is a function that is applied to a column of input data
+/// to produce an "embedding" of that input. This embedding is then stored in the
+/// database alongside (or instead of) the original input.
+///
+/// An "embedding" is often a lower-dimensional representation of the input data.
+/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional
+/// vector space. This is useful for tasks like similarity search, where we want to find
+/// similar sentences to a query sentence.
+///
+/// To use an embedding function you must first register it with the `EmbeddingsRegistry`.
+/// Then you can define it on a column in the table schema. That embedding will then be used
+/// to embed the data in that column.
+pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
+    fn name(&self) -> &str;
+    /// The type of the input data
+    fn source_type(&self) -> Result<Cow<DataType>>;
+    /// The type of the output data
+    /// This should **always** match the output of the `embed` function
+    fn dest_type(&self) -> Result<Cow<DataType>>;
+    /// Embed the input
+    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
+}
+
+/// Defines an embedding from input data into a lower-dimensional space
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
+pub struct EmbeddingDefinition {
+    /// The name of the column in the input data
+    pub source_column: String,
+    /// The name of the embedding column; if not specified
+    /// it will be the source column with `_embedding` appended
+    pub dest_column: Option<String>,
+    /// The name of the embedding function to apply
+    pub embedding_name: String,
+}
+
+impl EmbeddingDefinition {
+    pub fn new<S: Into<String>>(source_column: S, embedding_name: S, dest: Option<S>) -> Self {
+        Self {
+            source_column: source_column.into(),
+            dest_column: dest.map(|d| d.into()),
+            embedding_name: embedding_name.into(),
+        }
+    }
+}
+
+/// A registry of embedding functions
+pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug {
+    /// Return the names of all registered embedding functions
+    fn functions(&self) -> HashSet<String>;
+    /// Register a new [`EmbeddingFunction`].
+    /// Returns an error if the function can not be registered
+    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()>;
+    /// Get an embedding function by name
+    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>>;
+}
+
+/// A [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s
+#[derive(Debug, Default, Clone)]
+pub struct MemoryRegistry {
+    functions: Arc<RwLock<HashMap<String, Arc<dyn EmbeddingFunction>>>>,
+}
+
+impl EmbeddingRegistry for MemoryRegistry {
+    fn functions(&self) -> HashSet<String> {
+        self.functions.read().unwrap().keys().cloned().collect()
+    }
+    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()> {
+        self.functions
+            .write()
+            .unwrap()
+            .insert(name.to_string(), function);
+
+        Ok(())
+    }
+
+    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
+        self.functions.read().unwrap().get(name).cloned()
+    }
+}
+
+impl MemoryRegistry {
+    /// Create a new `MemoryRegistry`
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+/// A record batch reader that has embeddings applied to it.
+/// This is a wrapper around another record batch reader that applies an embedding function
+/// when reading from the record batch
+pub struct WithEmbeddings<R: RecordBatchReader> {
+    inner: R,
+    embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
+}
+
+/// A record batch that might have embeddings applied to it.
+pub enum MaybeEmbedded<R: RecordBatchReader> {
+    /// The record batch reader has embeddings applied to it
+    Yes(WithEmbeddings<R>),
+    /// The record batch reader does not have embeddings applied to it.
+    /// The inner record batch reader is returned as-is
+    No(R),
+}
+
+impl<R: RecordBatchReader> MaybeEmbedded<R> {
+    /// Create a new RecordBatchReader with embeddings applied to it if the table definition
+    /// specifies an embedding column and the registry contains an embedding function with that name.
+    /// Otherwise, this is a no-op and the inner RecordBatchReader is returned.
+    pub fn try_new(
+        inner: R,
+        table_definition: TableDefinition,
+        registry: Option<Arc<dyn EmbeddingRegistry>>,
+    ) -> Result<Self> {
+        if let Some(registry) = registry {
+            let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
+            for cd in table_definition.column_definitions.iter() {
+                if let ColumnKind::Embedding(embedding_def) = &cd.kind {
+                    match registry.get(&embedding_def.embedding_name) {
+                        Some(func) => {
+                            embeddings.push((embedding_def.clone(), func));
+                        }
+                        None => {
+                            return Err(Error::EmbeddingFunctionNotFound {
+                                name: embedding_def.embedding_name.to_string(),
+                                reason: format!(
+                                    "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
+                                    embedding_def.embedding_name
+                                ),
+                            });
+                        }
+                    }
+                }
+            }
+
+            if !embeddings.is_empty() {
+                return Ok(Self::Yes(WithEmbeddings { inner, embeddings }));
+            }
+        };
+
+        // No embeddings to apply
+        Ok(Self::No(inner))
+    }
+}
+
+impl<R: RecordBatchReader> WithEmbeddings<R> {
+    pub fn new(
+        inner: R,
+        embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
+    ) -> Self {
+        Self { inner, embeddings }
+    }
+}
+
+impl<R: RecordBatchReader> WithEmbeddings<R> {
+    fn dest_fields(&self) -> Result<Vec<Field>> {
+        let schema = self.inner.schema();
+        self.embeddings
+            .iter()
+            .map(|(ed, func)| {
+                let src_field = schema.field_with_name(&ed.source_column).unwrap();
+
+                let field_name = ed
+                    .dest_column
+                    .clone()
+                    .unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
+                Ok(Field::new(
+                    field_name,
+                    func.dest_type()?.into_owned(),
+                    src_field.is_nullable(),
+                ))
+            })
+            .collect()
+    }
+
+    fn column_defs(&self) -> Vec<ColumnDefinition> {
+        let base_schema = self.inner.schema();
+        base_schema
+            .fields()
+            .iter()
+            .map(|_| ColumnDefinition {
+                kind: ColumnKind::Physical,
+            })
+            .chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition {
+                kind: ColumnKind::Embedding(ed.clone()),
+            }))
+            .collect::<Vec<_>>()
+    }
+
+    pub fn table_definition(&self) -> Result<TableDefinition> {
+        let base_schema = self.inner.schema();
+
+        let output_fields = self.dest_fields()?;
+        let column_definitions = self.column_defs();
+
+        let mut sb: SchemaBuilder = base_schema.as_ref().into();
+        sb.extend(output_fields);
+
+        let schema = Arc::new(sb.finish());
+        Ok(TableDefinition {
+            schema,
+            column_definitions,
+        })
+    }
+}
+
+impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
+    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            Self::Yes(inner) => inner.next(),
+            Self::No(inner) => inner.next(),
+        }
+    }
+}
+
+impl<R: RecordBatchReader> RecordBatchReader for MaybeEmbedded<R> {
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        match self {
+            Self::Yes(inner) => inner.schema(),
+            Self::No(inner) => inner.schema(),
+        }
+    }
+}
+
+impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
+    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let batch = self.inner.next()?;
+        match batch {
+            Ok(mut batch) => {
+                // todo: parallelize this
+                for (fld, func) in self.embeddings.iter() {
+                    let src_column = batch.column_by_name(&fld.source_column).unwrap();
+                    let embedding = match func.embed(src_column.clone()) {
+                        Ok(embedding) => embedding,
+                        Err(e) => {
+                            return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
+                                "Error computing embedding: {}",
+                                e
+                            ))))
+                        }
+                    };
+                    let dst_field_name = fld
+                        .dest_column
+                        .clone()
+                        .unwrap_or_else(|| format!("{}_embedding", &fld.source_column));
+
+                    let dst_field = Field::new(
+                        dst_field_name,
+                        embedding.data_type().clone(),
+                        embedding.nulls().is_some(),
+                    );
+
+                    match batch.try_with_column(dst_field.clone(), embedding) {
+                        Ok(b) => batch = b,
+                        Err(e) => return Some(Err(e)),
+                    };
+                }
+                Some(Ok(batch))
+            }
+            Err(e) => Some(Err(e)),
+        }
+    }
+}
+
+impl<R: RecordBatchReader> RecordBatchReader for WithEmbeddings<R> {
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        self.table_definition()
+            .expect("table definition should be infallible at this point")
+            .into_rich_schema()
+    }
+}
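At its core the registry is just a thread-safe name-to-function map. A minimal Python sketch of the same pattern, mirroring (not wrapping) the Rust `MemoryRegistry` above; the class, types, and toy embedding function are all illustrative:

```python
import threading
from typing import Callable, Dict, List, Optional

# An "embedding function" here: list of strings -> list of float vectors.
EmbedFn = Callable[[List[str]], List[List[float]]]


class MemoryRegistry:
    """In-memory registry mapping embedding-function names to callables."""

    def __init__(self) -> None:
        self._functions: Dict[str, EmbedFn] = {}
        self._lock = threading.Lock()

    def register(self, name: str, fn: EmbedFn) -> None:
        with self._lock:
            self._functions[name] = fn

    def get(self, name: str) -> Optional[EmbedFn]:
        with self._lock:
            return self._functions.get(name)


registry = MemoryRegistry()
registry.register("toy", lambda texts: [[float(len(t))] for t in texts])
embed = registry.get("toy")
print(embed(["hello", "hi"]))  # [[5.0], [2.0]]
```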
@@ -26,6 +26,9 @@ pub enum Error {
     InvalidInput { message: String },
     #[snafu(display("Table '{name}' was not found"))]
     TableNotFound { name: String },
+    #[snafu(display("Embedding function '{name}' was not found. : {reason}"))]
+    EmbeddingFunctionNotFound { name: String, reason: String },

     #[snafu(display("Table '{name}' already exists"))]
     TableAlreadyExists { name: String },
     #[snafu(display("Unable to created lance dataset at {path}: {source}"))]
@@ -112,3 +115,13 @@ impl From<url::ParseError> for Error {
     }
 }

+#[cfg(feature = "polars")]
+impl From<polars::prelude::PolarsError> for Error {
+    fn from(source: polars::prelude::PolarsError) -> Self {
+        Self::Other {
+            message: "Error in Polars DataFrame integration.".to_string(),
+            source: Some(Box::new(source)),
+        }
+    }
+}
@@ -46,10 +46,18 @@ impl VectorIndex {
     }
 }

+#[derive(Debug, Deserialize)]
+pub struct VectorIndexMetadata {
+    pub metric_type: String,
+    pub index_type: String,
+}
+
 #[derive(Debug, Deserialize)]
 pub struct VectorIndexStatistics {
     pub num_indexed_rows: usize,
     pub num_unindexed_rows: usize,
+    pub index_type: String,
+    pub indices: Vec<VectorIndexMetadata>,
 }

 /// Builder for an IVF PQ index.
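These structs deserialize the JSON returned by the index-stats endpoint, the same one `RemoteTable.index_stats` posts to earlier in this diff. A hedged sketch of consuming such a payload in Python; the shape follows the struct fields above, but every value here is made up:

```python
import json

# Shape mirrors VectorIndexStatistics / VectorIndexMetadata above.
payload = json.loads(
    '{"num_indexed_rows": 1000, "num_unindexed_rows": 5,'
    ' "index_type": "IVF_PQ",'
    ' "indices": [{"metric_type": "l2", "index_type": "IVF_PQ"}]}'
)
print(payload["num_indexed_rows"], payload["indices"][0]["metric_type"])
```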
@@ -350,8 +350,16 @@ mod test {

     #[tokio::test]
     async fn test_e2e() {
-        let dir1 = tempfile::tempdir().unwrap().into_path();
-        let dir2 = tempfile::tempdir().unwrap().into_path();
+        let dir1 = tempfile::tempdir()
+            .unwrap()
+            .into_path()
+            .canonicalize()
+            .unwrap();
+        let dir2 = tempfile::tempdir()
+            .unwrap()
+            .into_path()
+            .canonicalize()
+            .unwrap();

         let secondary_store = LocalFileSystem::new_with_prefix(dir2.to_str().unwrap()).unwrap();
         let object_store_wrapper = Arc::new(MirroringObjectStoreWrapper {
@@ -34,6 +34,16 @@
 //! cargo install lancedb
 //! ```
 //!
+//! ## Crate Features
+//!
+//! ### Experimental Features
+//!
+//! These features are not enabled by default. They are experimental or in-development features that
+//! are not yet ready to be released.
+//!
+//! - `remote` - Enable remote client to connect to LanceDB cloud. This is not yet fully implemented
+//!   and should not be enabled.
+//!
 //! ### Quick Start
 //!
 //! #### Connect to a database.
@@ -184,10 +194,13 @@
 pub mod arrow;
 pub mod connection;
 pub mod data;
+pub mod embeddings;
 pub mod error;
 pub mod index;
 pub mod io;
 pub mod ipc;
+#[cfg(feature = "polars")]
+mod polars_arrow_convertors;
 pub mod query;
 #[cfg(feature = "remote")]
 pub(crate) mod remote;
123
rust/lancedb/src/polars_arrow_convertors.rs
Normal file
123
rust/lancedb/src/polars_arrow_convertors.rs
Normal file
@@ -0,0 +1,123 @@
//! Polars and LanceDB both use Arrow for their in-memory representation, but
//! use different Rust Arrow implementations. LanceDB uses the arrow-rs crate
//! and Polars uses the polars-arrow crate.
//!
//! This module defines zero-copy conversions (of the underlying buffers)
//! between polars-arrow and arrow-rs using the Arrow C FFI.
//!
//! The polars-arrow crate does implement conversions to and from arrow-rs, but
//! that requires a feature-flagged dependency on arrow-rs. The version of
//! arrow-rs depended on by polars-arrow and by LanceDB may not be compatible,
//! which necessitates going through the C FFI instead.
use crate::error::Result;
use polars::prelude::{DataFrame, Series};
use std::{mem, sync::Arc};

/// When interpreting Polars dataframes as polars-arrow record batches,
/// one must decide whether to use Arrow string/binary view types
/// instead of the standard Arrow string/binary types.
/// For now, we do not use string view types because conversions
/// for string view types from polars-arrow to arrow-rs are not yet implemented.
/// See: https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt for the
/// differences between the types.
pub const POLARS_ARROW_FLAVOR: bool = false;
const IS_ARRAY_NULLABLE: bool = true;

/// Converts a Polars DataFrame schema to an Arrow RecordBatch schema.
pub fn convert_polars_df_schema_to_arrow_rb_schema(
    polars_df_schema: polars::prelude::Schema,
) -> Result<Arc<arrow_schema::Schema>> {
    let arrow_fields: Result<Vec<arrow_schema::Field>> = polars_df_schema
        .into_iter()
        .map(|(name, df_dtype)| {
            let polars_arrow_dtype = df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
            let polars_field =
                polars_arrow::datatypes::Field::new(name, polars_arrow_dtype, IS_ARRAY_NULLABLE);
            convert_polars_arrow_field_to_arrow_rs_field(polars_field)
        })
        .collect();
    Ok(Arc::new(arrow_schema::Schema::new(arrow_fields?)))
}

/// Converts an Arrow RecordBatch schema to a Polars DataFrame schema.
pub fn convert_arrow_rb_schema_to_polars_df_schema(
    arrow_schema: &arrow_schema::Schema,
) -> Result<polars::prelude::Schema> {
    let polars_df_fields: Result<Vec<polars::prelude::Field>> = arrow_schema
        .fields()
        .iter()
        .map(|arrow_rs_field| {
            let polars_arrow_field = convert_arrow_rs_field_to_polars_arrow_field(arrow_rs_field)?;
            Ok(polars::prelude::Field::new(
                arrow_rs_field.name(),
                polars::datatypes::DataType::from(polars_arrow_field.data_type()),
            ))
        })
        .collect();
    Ok(polars::prelude::Schema::from_iter(polars_df_fields?))
}

/// Converts an Arrow RecordBatch to a Polars DataFrame, using a provided Polars DataFrame schema.
pub fn convert_arrow_rb_to_polars_df(
    arrow_rb: &arrow::record_batch::RecordBatch,
    polars_schema: &polars::prelude::Schema,
) -> Result<DataFrame> {
    let mut columns: Vec<Series> = Vec::with_capacity(arrow_rb.num_columns());

    for (i, column) in arrow_rb.columns().iter().enumerate() {
        let polars_df_dtype = polars_schema.try_get_at_index(i)?.1;
        let polars_arrow_dtype = polars_df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
        let polars_array =
            convert_arrow_rs_array_to_polars_arrow_array(column, polars_arrow_dtype)?;
        columns.push(Series::from_arrow(
            polars_schema.try_get_at_index(i)?.0,
            polars_array,
        )?);
    }

    Ok(DataFrame::from_iter(columns))
}

/// Converts a polars-arrow Arrow array to an arrow-rs Arrow array.
pub fn convert_polars_arrow_array_to_arrow_rs_array(
    polars_array: Box<dyn polars_arrow::array::Array>,
    arrow_datatype: arrow_schema::DataType,
) -> std::result::Result<arrow_array::ArrayRef, arrow_schema::ArrowError> {
    let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array);
    let arrow_c_array = unsafe { mem::transmute(polars_c_array) };
    Ok(arrow_array::make_array(unsafe {
        arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype)
    }?))
}

/// Converts an arrow-rs Arrow array to a polars-arrow Arrow array.
fn convert_arrow_rs_array_to_polars_arrow_array(
    arrow_rs_array: &Arc<dyn arrow_array::Array>,
    polars_arrow_dtype: polars::datatypes::ArrowDataType,
) -> Result<Box<dyn polars_arrow::array::Array>> {
    let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data());
    let polars_c_array = unsafe { mem::transmute(arrow_c_array) };
    Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?)
}

fn convert_polars_arrow_field_to_arrow_rs_field(
    polars_arrow_field: polars_arrow::datatypes::Field,
) -> Result<arrow_schema::Field> {
    let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field);
    let arrow_c_schema: arrow::ffi::FFI_ArrowSchema = unsafe { mem::transmute(polars_c_schema) };
    let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?;
    Ok(arrow_schema::Field::new(
        polars_arrow_field.name,
        arrow_rs_dtype,
        IS_ARRAY_NULLABLE,
    ))
}

fn convert_arrow_rs_field_to_polars_arrow_field(
    arrow_rs_field: &arrow_schema::Field,
) -> Result<polars_arrow::datatypes::Field> {
    let arrow_rs_dtype = arrow_rs_field.data_type();
    let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?;
    let polars_c_schema: polars_arrow::ffi::ArrowSchema = unsafe { mem::transmute(arrow_c_schema) };
    Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?)
}
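A minimal usage sketch for the conversion helpers above (the `batch` argument is hypothetical; only functions defined in this module are called):

fn record_batch_to_df(
    batch: &arrow::record_batch::RecordBatch,
) -> Result<polars::prelude::DataFrame> {
    // Derive the Polars schema from the arrow-rs schema, then convert the
    // batch column-by-column through the Arrow C FFI.
    let polars_schema = convert_arrow_rb_schema_to_polars_df_schema(&batch.schema())?;
    convert_arrow_rb_to_polars_df(batch, &polars_schema)
}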
@@ -23,6 +23,7 @@ use tokio::task::spawn_blocking;
use crate::connection::{
    ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
};
use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;

@@ -87,14 +88,16 @@ impl ConnectionInternal for RemoteDatabase {
            .await
            .unwrap()?;

        self.client
            .post(&format!("/v1/table/{}/create", options.name))
        let rsp = self
            .client
            .post(&format!("/v1/table/{}/create/", options.name))
            .body(data_buffer)
            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
            // This is currently expected by LanceDB Cloud but will be removed soon.
            .header("x-request-id", "na")
            .send()
            .await?;
        self.client.check_response(rsp).await?;

        Ok(Table::new(Arc::new(RemoteTable::new(
            self.client.clone(),
@@ -113,4 +116,8 @@ impl ConnectionInternal for RemoteDatabase {
    async fn drop_db(&self) -> Result<()> {
        todo!()
    }

    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        todo!()
    }
}
@@ -10,7 +10,7 @@ use crate::{
    query::{Query, QueryExecutionOptions, VectorQuery},
    table::{
        merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
        TableInternal, UpdateBuilder,
        TableDefinition, TableInternal, UpdateBuilder,
    },
};

@@ -120,4 +120,7 @@ impl TableInternal for RemoteTable {
    async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
        todo!()
    }
    async fn table_definition(&self) -> Result<TableDefinition> {
        todo!()
    }
}
@@ -41,10 +41,12 @@ use lance::io::WrappingObjectStore;
use lance_index::IndexType;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use serde::{Deserialize, Serialize};
use snafu::whatever;

use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::IndexConfig;
@@ -63,6 +65,79 @@ use self::merge::MergeInsertBuilder;
pub(crate) mod dataset;
pub mod merge;

/// Defines the type of a column
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColumnKind {
    /// Columns populated by data from the user (this is the most common case)
    Physical,
    /// Columns populated by applying an embedding function to the input
    Embedding(EmbeddingDefinition),
}

/// Defines a column in a table
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDefinition {
    /// The source of the column data
    pub kind: ColumnKind,
}

#[derive(Debug, Clone)]
pub struct TableDefinition {
    pub column_definitions: Vec<ColumnDefinition>,
    pub schema: SchemaRef,
}

impl TableDefinition {
    pub fn new(schema: SchemaRef, column_definitions: Vec<ColumnDefinition>) -> Self {
        Self {
            column_definitions,
            schema,
        }
    }

    pub fn new_from_schema(schema: SchemaRef) -> Self {
        let column_definitions = schema
            .fields()
            .iter()
            .map(|_| ColumnDefinition {
                kind: ColumnKind::Physical,
            })
            .collect();
        Self::new(schema, column_definitions)
    }

    pub fn try_from_rich_schema(schema: SchemaRef) -> Result<Self> {
        let column_definitions = schema.metadata.get("lancedb::column_definitions");
        if let Some(column_definitions) = column_definitions {
            let column_definitions: Vec<ColumnDefinition> =
                serde_json::from_str(column_definitions).map_err(|e| Error::Runtime {
                    message: format!("Failed to deserialize column definitions: {}", e),
                })?;
            Ok(Self::new(schema, column_definitions))
        } else {
            let column_definitions = schema
                .fields()
                .iter()
                .map(|_| ColumnDefinition {
                    kind: ColumnKind::Physical,
                })
                .collect();
            Ok(Self::new(schema, column_definitions))
        }
    }

    pub fn into_rich_schema(self) -> SchemaRef {
        // We have full control over the structure of the column definitions,
        // so serialization should not fail except in the presence of a bug.
        let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap();
        let mut schema_with_metadata = (*self.schema).clone();
        schema_with_metadata
            .metadata
            .insert("lancedb::column_definitions".to_string(), lancedb_metadata);
        Arc::new(schema_with_metadata)
    }
}
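A sketch of the round-trip these methods implement (the `schema` value is illustrative): `into_rich_schema` serializes the column definitions to JSON under the "lancedb::column_definitions" metadata key, and `try_from_rich_schema` recovers them on load, defaulting every column to `ColumnKind::Physical` when the key is absent.

let def = TableDefinition::new_from_schema(schema.clone());
let rich = def.into_rich_schema();
assert!(rich.metadata.contains_key("lancedb::column_definitions"));
let roundtrip = TableDefinition::try_from_rich_schema(rich)?;
assert_eq!(roundtrip.column_definitions.len(), schema.fields().len());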
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -132,6 +207,7 @@ pub struct AddDataBuilder<T: IntoArrow> {
    pub(crate) data: T,
    pub(crate) mode: AddDataMode,
    pub(crate) write_options: WriteOptions,
    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}

impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
@@ -163,6 +239,7 @@ impl<T: IntoArrow> AddDataBuilder<T> {
            mode: self.mode,
            parent: self.parent,
            write_options: self.write_options,
            embedding_registry: self.embedding_registry,
        };
        parent.add(without_data, data).await
    }
@@ -280,6 +357,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
async fn checkout(&self, version: u64) -> Result<()>;
|
||||
async fn checkout_latest(&self) -> Result<()>;
|
||||
async fn restore(&self) -> Result<()>;
|
||||
async fn table_definition(&self) -> Result<TableDefinition>;
|
||||
}
|
||||
|
||||
/// A Table is a collection of strongly typed rows.
@@ -288,6 +366,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[derive(Clone)]
pub struct Table {
    inner: Arc<dyn TableInternal>,
    embedding_registry: Arc<dyn EmbeddingRegistry>,
}

impl std::fmt::Display for Table {
@@ -298,7 +377,20 @@ impl std::fmt::Display for Table {

impl Table {
    pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
        Self { inner }
        Self {
            inner,
            embedding_registry: Arc::new(MemoryRegistry::new()),
        }
    }

    pub(crate) fn new_with_embedding_registry(
        inner: Arc<dyn TableInternal>,
        embedding_registry: Arc<dyn EmbeddingRegistry>,
    ) -> Self {
        Self {
            inner,
            embedding_registry,
        }
    }

    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
@@ -340,6 +432,7 @@ impl Table {
            data: batches,
            mode: AddDataMode::Append,
            write_options: WriteOptions::default(),
            embedding_registry: Some(self.embedding_registry.clone()),
        }
    }
@@ -743,11 +836,10 @@

impl From<NativeTable> for Table {
    fn from(table: NativeTable) -> Self {
        Self {
            inner: Arc::new(table),
        }
        Self::new(Arc::new(table))
    }
}

/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -918,7 +1010,6 @@ impl NativeTable {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
        };

        let storage_options = params
            .store_params
            .clone()
@@ -1061,6 +1152,26 @@ impl NativeTable {
        }
    }

    pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> {
        match self.load_index_stats(index_uuid).await? {
            Some(stats) => Ok(Some(stats.index_type)),
            None => Ok(None),
        }
    }

    pub async fn get_distance_type(&self, index_uuid: &str) -> Result<Option<String>> {
        match self.load_index_stats(index_uuid).await? {
            // Note: this concatenates the metric types of all delta indices
            // into a single String (String implements FromIterator<String>).
            Some(stats) => Ok(Some(
                stats
                    .indices
                    .iter()
                    .map(|i| i.metric_type.clone())
                    .collect(),
            )),
            None => Ok(None),
        }
    }

    pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
        let dataset = self.dataset.get().await?;
        let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
@@ -1322,6 +1433,11 @@ impl TableInternal for NativeTable {
        Ok(Arc::new(Schema::from(&lance_schema)))
    }

    async fn table_definition(&self) -> Result<TableDefinition> {
        let schema = self.schema().await?;
        TableDefinition::try_from_rich_schema(schema)
    }

    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
        Ok(self.dataset.get().await?.count_rows(filter).await?)
    }
@@ -1331,6 +1447,9 @@
        add: AddDataBuilder<NoData>,
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<()> {
        let data =
            MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?;

        let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
            mode: match add.mode {
                AddDataMode::Append => WriteMode::Append,
@@ -1358,8 +1477,8 @@
        };

        self.dataset.ensure_mutable().await?;

        let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;

        self.dataset.set_latest(dataset).await;
        Ok(())
    }
320
rust/lancedb/tests/embedding_registry_test.rs
Normal file
@@ -0,0 +1,320 @@
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    iter::repeat,
    sync::Arc,
};

use arrow::buffer::NullBuffer;
use arrow_array::{
    Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
    StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
    arrow::IntoArrow,
    connect,
    embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
    query::ExecutableQuery,
    Error, Result,
};

#[tokio::test]
async fn test_custom_func() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();
    let db = connect(tempdir).execute().await?;
    let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
    db.embedding_registry()
        .register("embed_fun", Arc::new(embed_fun.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &embed_fun.name,
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    // Now make sure the embeddings are applied when
    // we add new records too.
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    Ok(())
}
#[tokio::test]
async fn test_custom_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir)
        .embedding_registry(Arc::new(MyRegistry::default()))
        .execute()
        .await?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "func_1",
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(
            embeddings.data_type(),
            MockEmbed::new("func_1".to_string(), 1)
                .dest_type()?
                .as_ref()
        );
    }
    Ok(())
}
#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    let func_1 = MockEmbed::new("func_1".to_string(), 1);
    let func_2 = MockEmbed::new("func_2".to_string(), 10);
    db.embedding_registry()
        .register(&func_1.name, Arc::new(func_1.clone()))?;
    db.embedding_registry()
        .register(&func_2.name, Arc::new(func_2.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_1.name,
            Some("first_embeddings"),
        ))?
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_2.name,
            Some("second_embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }

    // Now make sure the embeddings are applied when
    // we add new records too.
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }
    Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;

    let res = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ));
    assert!(res.is_err());
    assert!(matches!(
        res.err().unwrap(),
        Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    db.embedding_registry().register(
        "some_func",
        Arc::new(MockEmbed::new("some_func".to_string(), 1)),
    )?;

    db.create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ))?
        .execute()
        .await?;

    let db = connect(tempdir).execute().await?;

    let tbl = db.open_table("test").execute().await?;
    // This should fail because 'tbl' expects "some_func" to be in the registry.
    let res = tbl.add(create_some_records()?).execute().await;
    assert!(res.is_err());
    assert!(matches!(
        res.unwrap_err(),
        Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}
fn create_some_records() -> Result<impl IntoArrow> {
    const TOTAL: usize = 2;

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, true),
    ]));

    // Create a RecordBatch stream.
    let batches = RecordBatchIterator::new(
        vec![RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
                Arc::new(StringArray::from_iter(
                    repeat(Some("hello world".to_string())).take(TOTAL),
                )),
            ],
        )
        .unwrap()]
        .into_iter()
        .map(Ok),
        schema.clone(),
    );
    Ok(Box::new(batches))
}
#[derive(Debug)]
struct MyRegistry {
    functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
    fn default() -> Self {
        let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
            Arc::new(MockEmbed::new("func_1".to_string(), 1)),
            Arc::new(MockEmbed::new("func_2".to_string(), 10)),
        ];
        Self {
            functions: funcs
                .into_iter()
                .map(|f| (f.name().to_string(), f))
                .collect(),
        }
    }
}

/// A read-only mock registry, pre-populated with `func_1` and `func_2`.
impl EmbeddingRegistry for MyRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.keys().cloned().collect()
    }

    fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        Err(Error::Other {
            message: "MyRegistry is read-only".to_string(),
            source: None,
        })
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.get(name).cloned()
    }
}
#[derive(Debug, Clone)]
struct MockEmbed {
    source_type: DataType,
    dest_type: DataType,
    name: String,
    dim: usize,
}

impl MockEmbed {
    pub fn new(name: String, dim: usize) -> Self {
        Self {
            source_type: DataType::Utf8,
            dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
            name,
            dim,
        }
    }
}

impl EmbeddingFunction for MockEmbed {
    fn name(&self) -> &str {
        &self.name
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.source_type))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.dest_type))
    }
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // We can't use the FixedSizeListBuilder here because it always adds a
        // null bitmap and we want to explicitly work with non-nullable arrays.
        let len = source.len();
        let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
        let field = Field::new("item", inner.data_type().clone(), false);
        let arr = FixedSizeListArray::new(
            Arc::new(field),
            self.dim as _,
            inner,
            Some(NullBuffer::new_valid(len)),
        );

        Ok(Arc::new(arr))
    }
}