mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
Compare commits
9 Commits
ayush/jina
...
lance-14.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffcf632abb | ||
|
|
3c6c21c137 | ||
|
|
fd5ca20f34 | ||
|
|
ef30f87fd1 | ||
|
|
08d25c5a80 | ||
|
|
a5ff623443 | ||
|
|
b8ccea9f71 | ||
|
|
46c6ff889d | ||
|
|
12b3c87964 |
17
Cargo.toml
17
Cargo.toml
@@ -20,11 +20,18 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.13.0", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.13.0" }
|
||||
lance-linalg = { "version" = "=0.13.0" }
|
||||
lance-testing = { "version" = "=0.13.0" }
|
||||
lance-datafusion = { "version" = "=0.13.0" }
|
||||
# lance = { "version" = "=0.14.0", "features" = ["dynamodb"] }
|
||||
# lance-index = { "version" = "=0.14.0" }
|
||||
# lance-linalg = { "version" = "=0.14.0" }
|
||||
# lance-testing = { "version" = "=0.14.0" }
|
||||
# lance-datafusion = { "version" = "=0.14.0" }
|
||||
|
||||
lance = { path = "../lance/rust/lance", "features" = ["dynamodb"] }
|
||||
lance-index = { path = "../lance/rust/lance-index" }
|
||||
lance-linalg = { path = "../lance/rust/lance-linalg" }
|
||||
lance-testing = { path = "../lance/rust/lance-testing" }
|
||||
lance-datafusion = { path = "../lance/rust/lance-datafusion" }
|
||||
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "51.0", optional = false }
|
||||
arrow-array = "51.0"
|
||||
|
||||
@@ -125,10 +125,11 @@ nav:
|
||||
- DuckDB: python/duckdb.md
|
||||
- LangChain:
|
||||
- LangChain 🔗: integrations/langchain.md
|
||||
- LangChain demo: notebooks/langchain_demo.ipynb
|
||||
- LangChain JS/TS 🔗: https://js.langchain.com/docs/integrations/vectorstores/lancedb
|
||||
- LlamaIndex 🦙:
|
||||
- LlamaIndex docs: integrations/llamaIndex.md
|
||||
- LlamaIndex demo: https://docs.llamaindex.ai/en/stable/examples/vector_stores/LanceDBIndexDemo/
|
||||
- LlamaIndex demo: notebooks/llamaIndex_demo.ipynb
|
||||
- Pydantic: python/pydantic.md
|
||||
- Voxel51: integrations/voxel51.md
|
||||
- PromptTools: integrations/prompttools.md
|
||||
@@ -204,9 +205,9 @@ nav:
|
||||
- Pandas and PyArrow: python/pandas_and_pyarrow.md
|
||||
- Polars: python/polars_arrow.md
|
||||
- DuckDB: python/duckdb.md
|
||||
- LangChain 🦜️🔗↗: https://python.langchain.com/docs/integrations/vectorstores/lancedb
|
||||
- LangChain 🦜️🔗↗: integrations/langchain.md
|
||||
- LangChain.js 🦜️🔗↗: https://js.langchain.com/docs/integrations/vectorstores/lancedb
|
||||
- LlamaIndex 🦙↗: https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html
|
||||
- LlamaIndex 🦙↗: integrations/llamaIndex.md
|
||||
- Pydantic: python/pydantic.md
|
||||
- Voxel51: integrations/voxel51.md
|
||||
- PromptTools: integrations/prompttools.md
|
||||
|
||||
@@ -68,6 +68,39 @@ table.add(
|
||||
]
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
```
|
||||
|
||||
### Jina Embeddings
|
||||
LanceDB registers the JinaAI embeddings function in the registry as `jina`. You can pass any supported model name to the `create`. By default it uses `"jina-clip-v1"`.
|
||||
`jina-clip-v1` can handle both text and images and other models only support `text`.
|
||||
|
||||
You need to pass `JINA_API_KEY` in the environment variable or pass it as `api_key` to `create` method.
|
||||
|
||||
```python
|
||||
import os
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import get_registry
|
||||
os.environ['JINA_API_KEY'] = "jina_*"
|
||||
|
||||
db = lancedb.connect("/tmp/db")
|
||||
func = get_registry().get("jina").create(name="jina-clip-v1")
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words, mode="overwrite")
|
||||
table.add(
|
||||
[
|
||||
{"text": "hello world"},
|
||||
{"text": "goodbye world"}
|
||||
]
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||

|
||||
|
||||
## Quick Start
|
||||
You can load your document data using langchain's loaders, for this example we are using `TextLoader` and `OpenAIEmbeddings` as the embedding model.
|
||||
You can load your document data using langchain's loaders, for this example we are using `TextLoader` and `OpenAIEmbeddings` as the embedding model. Checkout Complete example here - [LangChain demo](../notebooks/langchain_example.ipynb)
|
||||
```python
|
||||
import os
|
||||
from langchain.document_loaders import TextLoader
|
||||
@@ -38,6 +38,8 @@ The exhaustive list of parameters for `LanceDB` vector store are :
|
||||
- `api_key`: (Optional) API key to use for LanceDB cloud database. Defaults to `None`.
|
||||
- `region`: (Optional) Region to use for LanceDB cloud database. Only for LanceDB Cloud, defaults to `None`.
|
||||
- `mode`: (Optional) Mode to use for adding data to the table. Defaults to `'overwrite'`.
|
||||
- `reranker`: (Optional) The reranker to use for LanceDB.
|
||||
- `relevance_score_fn`: (Optional[Callable[[float], float]]) Langchain relevance score function to be used. Defaults to `None`.
|
||||
|
||||
```python
|
||||
db_url = "db://lang_test" # url of db you created
|
||||
@@ -54,12 +56,14 @@ vector_store = LanceDB(
|
||||
```
|
||||
|
||||
### Methods
|
||||
To add texts and store respective embeddings automatically:
|
||||
|
||||
##### add_texts()
|
||||
- `texts`: `Iterable` of strings to add to the vectorstore.
|
||||
- `metadatas`: Optional `list[dict()]` of metadatas associated with the texts.
|
||||
- `ids`: Optional `list` of ids to associate with the texts.
|
||||
- `kwargs`: `Any`
|
||||
|
||||
This method adds texts and stores respective embeddings automatically.
|
||||
|
||||
```python
|
||||
vector_store.add_texts(texts = ['test_123'], metadatas =[{'source' :'wiki'}])
|
||||
@@ -74,7 +78,6 @@ pd_df.to_csv("docsearch.csv", index=False)
|
||||
# you can also create a new vector store object using an older connection object:
|
||||
vector_store = LanceDB(connection=tbl, embedding=embeddings)
|
||||
```
|
||||
For index creation make sure your table has enough data in it. An ANN index is ususally not needed for datasets ~100K vectors. For large-scale (>1M) or higher dimension vectors, it is beneficial to create an ANN index.
|
||||
##### create_index()
|
||||
- `col_name`: `Optional[str] = None`
|
||||
- `vector_col`: `Optional[str] = None`
|
||||
@@ -82,6 +85,8 @@ For index creation make sure your table has enough data in it. An ANN index is u
|
||||
- `num_sub_vectors`: `Optional[int] = 96`
|
||||
- `index_cache_size`: `Optional[int] = None`
|
||||
|
||||
This method creates an index for the vector store. For index creation make sure your table has enough data in it. An ANN index is ususally not needed for datasets ~100K vectors. For large-scale (>1M) or higher dimension vectors, it is beneficial to create an ANN index.
|
||||
|
||||
```python
|
||||
# for creating vector index
|
||||
vector_store.create_index(vector_col='vector', metric = 'cosine')
|
||||
@@ -89,4 +94,108 @@ vector_store.create_index(vector_col='vector', metric = 'cosine')
|
||||
# for creating scalar index(for non-vector columns)
|
||||
vector_store.create_index(col_name='text')
|
||||
|
||||
```
|
||||
```
|
||||
|
||||
##### similarity_search()
|
||||
- `query`: `str`
|
||||
- `k`: `Optional[int] = None`
|
||||
- `filter`: `Optional[Dict[str, str]] = None`
|
||||
- `fts`: `Optional[bool] = False`
|
||||
- `name`: `Optional[str] = None`
|
||||
- `kwargs`: `Any`
|
||||
|
||||
Return documents most similar to the query without relevance scores
|
||||
|
||||
```python
|
||||
docs = docsearch.similarity_search(query)
|
||||
print(docs[0].page_content)
|
||||
```
|
||||
|
||||
##### similarity_search_by_vector()
|
||||
- `embedding`: `List[float]`
|
||||
- `k`: `Optional[int] = None`
|
||||
- `filter`: `Optional[Dict[str, str]] = None`
|
||||
- `name`: `Optional[str] = None`
|
||||
- `kwargs`: `Any`
|
||||
|
||||
Returns documents most similar to the query vector.
|
||||
|
||||
```python
|
||||
docs = docsearch.similarity_search_by_vector(query)
|
||||
print(docs[0].page_content)
|
||||
```
|
||||
|
||||
##### similarity_search_with_score()
|
||||
- `query`: `str`
|
||||
- `k`: `Optional[int] = None`
|
||||
- `filter`: `Optional[Dict[str, str]] = None`
|
||||
- `kwargs`: `Any`
|
||||
|
||||
Returns documents most similar to the query string with relevance scores, gets called by base class's `similarity_search_with_relevance_scores` which selects relevance score based on our `_select_relevance_score_fn`.
|
||||
|
||||
```python
|
||||
docs = docsearch.similarity_search_with_relevance_scores(query)
|
||||
print("relevance score - ", docs[0][1])
|
||||
print("text- ", docs[0][0].page_content[:1000])
|
||||
```
|
||||
|
||||
##### similarity_search_by_vector_with_relevance_scores()
|
||||
- `embedding`: `List[float]`
|
||||
- `k`: `Optional[int] = None`
|
||||
- `filter`: `Optional[Dict[str, str]] = None`
|
||||
- `name`: `Optional[str] = None`
|
||||
- `kwargs`: `Any`
|
||||
|
||||
Return documents most similar to the query vector with relevance scores.
|
||||
Relevance score
|
||||
|
||||
```python
|
||||
docs = docsearch.similarity_search_by_vector_with_relevance_scores(query_embedding)
|
||||
print("relevance score - ", docs[0][1])
|
||||
print("text- ", docs[0][0].page_content[:1000])
|
||||
```
|
||||
|
||||
##### max_marginal_relevance_search()
|
||||
- `query`: `str`
|
||||
- `k`: `Optional[int] = None`
|
||||
- `fetch_k` : Number of Documents to fetch to pass to MMR algorithm, `Optional[int] = None`
|
||||
- `lambda_mult`: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5. `float = 0.5`
|
||||
- `filter`: `Optional[Dict[str, str]] = None`
|
||||
- `kwargs`: `Any`
|
||||
|
||||
Returns docs selected using the maximal marginal relevance(MMR).
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents.
|
||||
|
||||
Similarly, `max_marginal_relevance_search_by_vector()` function returns docs most similar to the embedding passed to the function using MMR. instead of a string query you need to pass the embedding to be searched for.
|
||||
|
||||
```python
|
||||
result = docsearch.max_marginal_relevance_search(
|
||||
query="text"
|
||||
)
|
||||
result_texts = [doc.page_content for doc in result]
|
||||
print(result_texts)
|
||||
|
||||
## search by vector :
|
||||
result = docsearch.max_marginal_relevance_search_by_vector(
|
||||
embeddings.embed_query("text")
|
||||
)
|
||||
result_texts = [doc.page_content for doc in result]
|
||||
print(result_texts)
|
||||
```
|
||||
|
||||
##### add_images()
|
||||
- `uris` : File path to the image. `List[str]`.
|
||||
- `metadatas` : Optional list of metadatas. `(Optional[List[dict]], optional)`
|
||||
- `ids` : Optional list of IDs. `(Optional[List[str]], optional)`
|
||||
|
||||
Adds images by automatically creating their embeddings and adds them to the vectorstore.
|
||||
|
||||
```python
|
||||
vec_store.add_images(uris=image_uris)
|
||||
# here image_uris are local fs paths to the images.
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||

|
||||
|
||||
## Quick start
|
||||
You would need to install the integration via `pip install llama-index-vector-stores-lancedb` in order to use it. You can run the below script to try it out :
|
||||
You would need to install the integration via `pip install llama-index-vector-stores-lancedb` in order to use it.
|
||||
You can run the below script to try it out :
|
||||
```python
|
||||
import logging
|
||||
import sys
|
||||
@@ -43,6 +44,8 @@ retriever = index.as_retriever(vector_store_kwargs={"where": lance_filter})
|
||||
response = retriever.retrieve("What did the author do growing up?")
|
||||
```
|
||||
|
||||
Checkout Complete example here - [LlamaIndex demo](../notebooks/LlamaIndex_example.ipynb)
|
||||
|
||||
### Filtering
|
||||
For metadata filtering, you can use a Lance SQL-like string filter as demonstrated in the example above. Additionally, you can also filter using the `MetadataFilters` class from LlamaIndex:
|
||||
```python
|
||||
|
||||
538
docs/src/notebooks/LlamaIndex_example.ipynb
Normal file
538
docs/src/notebooks/LlamaIndex_example.ipynb
Normal file
@@ -0,0 +1,538 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "2db56c9b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/docs/examples/vector_stores/LanceDBIndexDemo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "db0855d0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LanceDB Vector Store\n",
|
||||
"In this notebook we are going to show how to use [LanceDB](https://www.lancedb.com) to perform vector searches in LlamaIndex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "f44170b2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6c84199c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install llama-index llama-index-vector-stores-lancedb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a90ce34",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install lancedb==0.6.13 #Only required if the above cell installs an older version of lancedb (pypi package may not be released yet)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "39c62671",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Refresh vector store URI if restarting or re-using the same notebook\n",
|
||||
"! rm -rf ./lancedb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "59b54276",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"# Uncomment to see debug logs\n",
|
||||
"# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)\n",
|
||||
"# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from llama_index.core import SimpleDirectoryReader, Document, StorageContext\n",
|
||||
"from llama_index.core import VectorStoreIndex\n",
|
||||
"from llama_index.vector_stores.lancedb import LanceDBVectorStore\n",
|
||||
"import textwrap"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "26c71b6d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup OpenAI\n",
|
||||
"The first step is to configure the openai key. It will be used to created embeddings for the documents loaded into the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "67b86621",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"\n",
|
||||
"openai.api_key = \"sk-\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "073f0a68",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Download Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eef1b911",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--2024-06-11 16:42:37-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt\n",
|
||||
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...\n",
|
||||
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
|
||||
"HTTP request sent, awaiting response... 200 OK\n",
|
||||
"Length: 75042 (73K) [text/plain]\n",
|
||||
"Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n",
|
||||
"\n",
|
||||
"data/paul_graham/pa 100%[===================>] 73.28K --.-KB/s in 0.02s \n",
|
||||
"\n",
|
||||
"2024-06-11 16:42:37 (3.97 MB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!mkdir -p 'data/paul_graham/'\n",
|
||||
"!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "f7010b1d-d1bb-4f08-9309-a328bb4ea396",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading documents\n",
|
||||
"Load the documents stored in the `data/paul_graham/` using the SimpleDirectoryReader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c154dd4b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Document ID: cac1ba78-5007-4cf8-89ba-280264790115 Document Hash: fe2d4d3ef3a860780f6c2599808caa587c8be6516fe0ba4ca53cf117044ba953\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"documents = SimpleDirectoryReader(\"./data/paul_graham/\").load_data()\n",
|
||||
"print(\"Document ID:\", documents[0].doc_id, \"Document Hash:\", documents[0].hash)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "c0232fd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create the index\n",
|
||||
"Here we create an index backed by LanceDB using the documents loaded previously. LanceDBVectorStore takes a few arguments.\n",
|
||||
"- uri (str, required): Location where LanceDB will store its files.\n",
|
||||
"- table_name (str, optional): The table name where the embeddings will be stored. Defaults to \"vectors\".\n",
|
||||
"- nprobes (int, optional): The number of probes used. A higher number makes search more accurate but also slower. Defaults to 20.\n",
|
||||
"- refine_factor: (int, optional): Refine the results by reading extra elements and re-ranking them in memory. Defaults to None\n",
|
||||
"\n",
|
||||
"- More details can be found at [LanceDB docs](https://lancedb.github.io/lancedb/ann_indexes)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f2e20ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### For LanceDB cloud :\n",
|
||||
"```python\n",
|
||||
"vector_store = LanceDBVectorStore( \n",
|
||||
" uri=\"db://db_name\", # your remote DB URI\n",
|
||||
" api_key=\"sk_..\", # lancedb cloud api key\n",
|
||||
" region=\"your-region\" # the region you configured\n",
|
||||
" ...\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8731da62",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vector_store = LanceDBVectorStore(\n",
|
||||
" uri=\"./lancedb\", mode=\"overwrite\", query_type=\"hybrid\"\n",
|
||||
")\n",
|
||||
"storage_context = StorageContext.from_defaults(vector_store=vector_store)\n",
|
||||
"\n",
|
||||
"index = VectorStoreIndex.from_documents(\n",
|
||||
" documents, storage_context=storage_context\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "8ee4473a-094f-4d0a-a825-e1213db07240",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Query the index\n",
|
||||
"We can now ask questions using our index. We can use filtering via `MetadataFilters` or use native lance `where` clause."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5eb6419b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from llama_index.core.vector_stores import (\n",
|
||||
" MetadataFilters,\n",
|
||||
" FilterOperator,\n",
|
||||
" FilterCondition,\n",
|
||||
" MetadataFilter,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"query_filters = MetadataFilters(\n",
|
||||
" filters=[\n",
|
||||
" MetadataFilter(\n",
|
||||
" key=\"creation_date\",\n",
|
||||
" operator=FilterOperator.EQ,\n",
|
||||
" value=datetime.now().strftime(\"%Y-%m-%d\"),\n",
|
||||
" ),\n",
|
||||
" MetadataFilter(\n",
|
||||
" key=\"file_size\", value=75040, operator=FilterOperator.GT\n",
|
||||
" ),\n",
|
||||
" ],\n",
|
||||
" condition=FilterCondition.AND,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ee201930",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Hybrid Search\n",
|
||||
"\n",
|
||||
"LanceDB offers hybrid search with reranking capabilities. For complete documentation, refer [here](https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/).\n",
|
||||
"\n",
|
||||
"This example uses the `colbert` reranker. The following cell installs the necessary dependencies for `colbert`. If you choose a different reranker, make sure to adjust the dependencies accordingly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e12d1454",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install -U torch transformers tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c742cb07",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"if you want to add a reranker at vector store initialization, you can pass it in the arguments like below :\n",
|
||||
"```\n",
|
||||
"from lancedb.rerankers import ColbertReranker\n",
|
||||
"reranker = ColbertReranker()\n",
|
||||
"vector_store = LanceDBVectorStore(uri=\"./lancedb\", reranker=reranker, mode=\"overwrite\")\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "27ea047b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import lancedb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8414517f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from lancedb.rerankers import ColbertReranker\n",
|
||||
"\n",
|
||||
"reranker = ColbertReranker()\n",
|
||||
"vector_store._add_reranker(reranker)\n",
|
||||
"\n",
|
||||
"query_engine = index.as_query_engine(\n",
|
||||
" filters=query_filters,\n",
|
||||
" # vector_store_kwargs={\n",
|
||||
" # \"query_type\": \"fts\",\n",
|
||||
" # },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = query_engine.query(\"How much did Viaweb charge per month?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dc6ccb7a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Viaweb charged $100 a month for a small store and $300 a month for a big one.\n",
|
||||
"metadata - {'65ed5f07-5b8a-4143-a939-e8764884828e': {'file_path': '/Users/raghavdixit/Desktop/open_source/llama_index_lance/docs/docs/examples/vector_stores/data/paul_graham/paul_graham_essay.txt', 'file_name': 'paul_graham_essay.txt', 'file_type': 'text/plain', 'file_size': 75042, 'creation_date': '2024-06-11', 'last_modified_date': '2024-06-11'}, 'be231827-20b8-4988-ac75-94fa79b3c22e': {'file_path': '/Users/raghavdixit/Desktop/open_source/llama_index_lance/docs/docs/examples/vector_stores/data/paul_graham/paul_graham_essay.txt', 'file_name': 'paul_graham_essay.txt', 'file_type': 'text/plain', 'file_size': 75042, 'creation_date': '2024-06-11', 'last_modified_date': '2024-06-11'}}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(response)\n",
|
||||
"print(\"metadata -\", response.metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0c1c6c73",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### lance filters(SQL like) directly via the `where` clause :"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0a2bcc07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"lance_filter = \"metadata.file_name = 'paul_graham_essay.txt' \"\n",
|
||||
"retriever = index.as_retriever(vector_store_kwargs={\"where\": lance_filter})\n",
|
||||
"response = retriever.retrieve(\"What did the author do growing up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7ac47cf9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"What I Worked On\n",
|
||||
"\n",
|
||||
"February 2021\n",
|
||||
"\n",
|
||||
"Before college the two main things I worked on, outside of school, were writing and programming. I didn't write essays. I wrote what beginning writers were supposed to write then, and probably still are: short stories. My stories were awful. They had hardly any plot, just characters with strong feelings, which I imagined made them deep.\n",
|
||||
"\n",
|
||||
"The first programs I tried writing were on the IBM 1401 that our school district used for what was then called \"data processing.\" This was in 9th grade, so I was 13 or 14. The school district's 1401 happened to be in the basement of our junior high school, and my friend Rich Draves and I got permission to use it. It was like a mini Bond villain's lair down there, with all these alien-looking machines — CPU, disk drives, printer, card reader — sitting up on a raised floor under bright fluorescent lights.\n",
|
||||
"\n",
|
||||
"The language we used was an early version of Fortran. You had to type programs on punch cards, then stack them in the card reader and press a button to load the program into memory and run it. The result would ordinarily be to print something on the spectacularly loud printer.\n",
|
||||
"\n",
|
||||
"I was puzzled by the 1401. I couldn't figure out what to do with it. And in retrospect there's not much I could have done with it. The only form of input to programs was data stored on punched cards, and I didn't have any data stored on punched cards. The only other option was to do things that didn't rely on any input, like calculate approximations of pi, but I didn't know enough math to do anything interesting of that type. So I'm not surprised I can't remember any programs I wrote, because they can't have done much. My clearest memory is of the moment I learned it was possible for programs not to terminate, when one of mine didn't. On a machine without time-sharing, this was a social as well as a technical error, as the data center manager's expression made clear.\n",
|
||||
"\n",
|
||||
"With microcomputers, everything changed. Now you could have a computer sitting right in front of you, on a desk, that could respond to your keystrokes as it was running instead of just churning through a stack of punch cards and then stopping. [1]\n",
|
||||
"\n",
|
||||
"The first of my friends to get a microcomputer built it himself. It was sold as a kit by Heathkit. I remember vividly how impressed and envious I felt watching him sitting in front of it, typing programs right into the computer.\n",
|
||||
"\n",
|
||||
"Computers were expensive in those days and it took me years of nagging before I convinced my father to buy one, a TRS-80, in about 1980. The gold standard then was the Apple II, but a TRS-80 was good enough. This was when I really started programming. I wrote simple games, a program to predict how high my model rockets would fly, and a word processor that my father used to write at least one book. There was only room in memory for about 2 pages of text, so he'd write 2 pages at a time and then print them out, but it was a lot better than a typewriter.\n",
|
||||
"\n",
|
||||
"Though I liked programming, I didn't plan to study it in college. In college I was going to study philosophy, which sounded much more powerful. It seemed, to my naive high school self, to be the study of the ultimate truths, compared to which the things studied in other fields would be mere domain knowledge. What I discovered when I got to college was that the other fields took up so much of the space of ideas that there wasn't much left for these supposed ultimate truths. All that seemed left for philosophy were edge cases that people in other fields felt could safely be ignored.\n",
|
||||
"\n",
|
||||
"I couldn't have put this into words when I was 18. All I knew at the time was that I kept taking philosophy courses and they kept being boring. So I decided to switch to AI.\n",
|
||||
"\n",
|
||||
"AI was in the air in the mid 1980s, but there were two things especially that made me want to work on it: a novel by Heinlein called The Moon is a Harsh Mistress, which featured an intelligent computer called Mike, and a PBS documentary that showed Terry Winograd using SHRDLU. I haven't tried rereading The Moon is a Harsh Mistress, so I don't know how well it has aged, but when I read it I was drawn entirely into its world.\n",
|
||||
"metadata - {'file_path': '/Users/raghavdixit/Desktop/open_source/llama_index_lance/docs/docs/examples/vector_stores/data/paul_graham/paul_graham_essay.txt', 'file_name': 'paul_graham_essay.txt', 'file_type': 'text/plain', 'file_size': 75042, 'creation_date': '2024-06-11', 'last_modified_date': '2024-06-11'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(response[0].get_content())\n",
|
||||
"print(\"metadata -\", response[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "6afc84ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Appending data\n",
|
||||
"You can also add data to an existing index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "759a532e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nodes = [node.node for node in response]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "069fc099",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"del index\n",
|
||||
"\n",
|
||||
"index = VectorStoreIndex.from_documents(\n",
|
||||
" [Document(text=\"The sky is purple in Portland, Maine\")],\n",
|
||||
" uri=\"/tmp/new_dataset\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a64ed441",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"index.insert_nodes(nodes)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b5cffcfe",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Portland, Maine\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"query_engine = index.as_query_engine()\n",
|
||||
"response = query_engine.query(\"Where is the sky purple?\")\n",
|
||||
"print(textwrap.fill(str(response), 100))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ec548a02",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can also create an index from an existing table"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dc99404d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"del index\n",
|
||||
"\n",
|
||||
"vec_store = LanceDBVectorStore.from_table(vector_store._table)\n",
|
||||
"index = VectorStoreIndex.from_vector_store(vec_store)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7b2e8cca",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The author started Viaweb and Aspra.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"query_engine = index.as_query_engine()\n",
|
||||
"response = query_engine.query(\"What companies did the author start?\")\n",
|
||||
"print(textwrap.fill(str(response), 100))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
566
docs/src/notebooks/langchain_example.ipynb
Normal file
566
docs/src/notebooks/langchain_example.ipynb
Normal file
@@ -0,0 +1,566 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "683953b3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LanceDB\n",
|
||||
"\n",
|
||||
">[LanceDB](https://lancedb.com/) is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings. Fully open source.\n",
|
||||
"\n",
|
||||
"This notebook shows how to use functionality related to the `LanceDB` vector database based on the Lance data format."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1051ba9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install tantivy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "88ac92c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install -U langchain-openai langchain-community"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a1c84d6-a10f-428c-95cd-46d3a1702e07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install lancedb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "99134dd1-b91e-486f-8d90-534248e43b9d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We want to use OpenAIEmbeddings so we have to get the OpenAI API Key. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "a0361f5c-e6f4-45f4-b829-11680cf03cec",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d114ed78",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! rm -rf /tmp/lancedb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "a3c3999a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.document_loaders import TextLoader\n",
|
||||
"from langchain_community.vectorstores import LanceDB\n",
|
||||
"from langchain_openai import OpenAIEmbeddings\n",
|
||||
"from langchain_text_splitters import CharacterTextSplitter\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../how_to/state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"\n",
|
||||
"documents = CharacterTextSplitter().split_documents(documents)\n",
|
||||
"embeddings = OpenAIEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e9517bb0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### For LanceDB cloud, you can invoke the vector store as follows :\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"db_url = \"db://lang_test\" # url of db you created\n",
|
||||
"api_key = \"xxxxx\" # your API key\n",
|
||||
"region=\"us-east-1-dev\" # your selected region\n",
|
||||
"\n",
|
||||
"vector_store = LanceDB(\n",
|
||||
" uri=db_url,\n",
|
||||
" api_key=api_key,\n",
|
||||
" region=region,\n",
|
||||
" embedding=embeddings,\n",
|
||||
" table_name='langchain_test'\n",
|
||||
" )\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"You can also add `region`, `api_key`, `uri` to `from_documents()` classmethod\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "6e104aee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from lancedb.rerankers import LinearCombinationReranker\n",
|
||||
"\n",
|
||||
"reranker = LinearCombinationReranker(weight=0.3)\n",
|
||||
"\n",
|
||||
"docsearch = LanceDB.from_documents(documents, embeddings, reranker=reranker)\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "259c7988",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"relevance score - 0.7066475030191711\n",
|
||||
"text- They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n",
|
||||
"\n",
|
||||
"Officer Mora was 27 years old. \n",
|
||||
"\n",
|
||||
"Officer Rivera was 22. \n",
|
||||
"\n",
|
||||
"Both Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n",
|
||||
"\n",
|
||||
"I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. \n",
|
||||
"\n",
|
||||
"I’ve worked on these issues a long time. \n",
|
||||
"\n",
|
||||
"I know what works: Investing in crime prevention and community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety. \n",
|
||||
"\n",
|
||||
"So let’s not abandon our streets. Or choose between safety and equal justice. \n",
|
||||
"\n",
|
||||
"Let’s come together to protect our communities, restore trust, and hold law enforcement accountable. \n",
|
||||
"\n",
|
||||
"That’s why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers. \n",
|
||||
"\n",
|
||||
"That’s why the American Rescue \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"docs = docsearch.similarity_search_with_relevance_scores(query)\n",
|
||||
"print(\"relevance score - \", docs[0][1])\n",
|
||||
"print(\"text- \", docs[0][0].page_content[:1000])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "9fa29dae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"distance - 0.30000001192092896\n",
|
||||
"text- My administration is providing assistance with job training and housing, and now helping lower-income veterans get VA care debt-free. \n",
|
||||
"\n",
|
||||
"Our troops in Iraq and Afghanistan faced many dangers. \n",
|
||||
"\n",
|
||||
"One was stationed at bases and breathing in toxic smoke from “burn pits” that incinerated wastes of war—medical and hazard material, jet fuel, and more. \n",
|
||||
"\n",
|
||||
"When they came home, many of the world’s fittest and best trained warriors were never the same. \n",
|
||||
"\n",
|
||||
"Headaches. Numbness. Dizziness. \n",
|
||||
"\n",
|
||||
"A cancer that would put them in a flag-draped coffin. \n",
|
||||
"\n",
|
||||
"I know. \n",
|
||||
"\n",
|
||||
"One of those soldiers was my son Major Beau Biden. \n",
|
||||
"\n",
|
||||
"We don’t know for sure if a burn pit was the cause of his brain cancer, or the diseases of so many of our troops. \n",
|
||||
"\n",
|
||||
"But I’m committed to finding out everything we can. \n",
|
||||
"\n",
|
||||
"Committed to military families like Danielle Robinson from Ohio. \n",
|
||||
"\n",
|
||||
"The widow of Sergeant First Class Heath Robinson. \n",
|
||||
"\n",
|
||||
"He was born a soldier. Army National Guard. Combat medic in Kosovo and Iraq. \n",
|
||||
"\n",
|
||||
"Stationed near Baghdad, just ya\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"docs = docsearch.similarity_search_with_score(query=\"Headaches\", query_type=\"hybrid\")\n",
|
||||
"print(\"distance - \", docs[0][1])\n",
|
||||
"print(\"text- \", docs[0][0].page_content[:1000])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "e70ad201",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"reranker : <lancedb.rerankers.linear_combination.LinearCombinationReranker object at 0x107ef1130>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"reranker : \", docsearch._reranker)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f5e1cdfd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Additionaly, to explore the table you can load it into a df or save it in a csv file: \n",
|
||||
"```python\n",
|
||||
"tbl = docsearch.get_table()\n",
|
||||
"print(\"tbl:\", tbl)\n",
|
||||
"pd_df = tbl.to_pandas()\n",
|
||||
"# pd_df.to_csv(\"docsearch.csv\", index=False)\n",
|
||||
"\n",
|
||||
"# you can also create a new vector store object using an older connection object:\n",
|
||||
"vector_store = LanceDB(connection=tbl, embedding=embeddings)\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "9c608226",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"metadata : {'source': '../../how_to/state_of_the_union.txt'}\n",
|
||||
"\n",
|
||||
"SQL filtering :\n",
|
||||
"\n",
|
||||
"They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n",
|
||||
"\n",
|
||||
"Officer Mora was 27 years old. \n",
|
||||
"\n",
|
||||
"Officer Rivera was 22. \n",
|
||||
"\n",
|
||||
"Both Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n",
|
||||
"\n",
|
||||
"I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. \n",
|
||||
"\n",
|
||||
"I’ve worked on these issues a long time. \n",
|
||||
"\n",
|
||||
"I know what works: Investing in crime prevention and community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety. \n",
|
||||
"\n",
|
||||
"So let’s not abandon our streets. Or choose between safety and equal justice. \n",
|
||||
"\n",
|
||||
"Let’s come together to protect our communities, restore trust, and hold law enforcement accountable. \n",
|
||||
"\n",
|
||||
"That’s why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers. \n",
|
||||
"\n",
|
||||
"That’s why the American Rescue Plan provided $350 Billion that cities, states, and counties can use to hire more police and invest in proven strategies like community violence interruption—trusted messengers breaking the cycle of violence and trauma and giving young people hope. \n",
|
||||
"\n",
|
||||
"We should all agree: The answer is not to Defund the police. The answer is to FUND the police with the resources and training they need to protect our communities. \n",
|
||||
"\n",
|
||||
"I ask Democrats and Republicans alike: Pass my budget and keep our neighborhoods safe. \n",
|
||||
"\n",
|
||||
"And I will keep doing everything in my power to crack down on gun trafficking and ghost guns you can buy online and make at home—they have no serial numbers and can’t be traced. \n",
|
||||
"\n",
|
||||
"And I ask Congress to pass proven measures to reduce gun violence. Pass universal background checks. Why should anyone on a terrorist list be able to purchase a weapon? \n",
|
||||
"\n",
|
||||
"Ban assault weapons and high-capacity magazines. \n",
|
||||
"\n",
|
||||
"Repeal the liability shield that makes gun manufacturers the only industry in America that can’t be sued. \n",
|
||||
"\n",
|
||||
"These laws don’t infringe on the Second Amendment. They save lives. \n",
|
||||
"\n",
|
||||
"The most fundamental right in America is the right to vote – and to have it counted. And it’s under assault. \n",
|
||||
"\n",
|
||||
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
|
||||
"\n",
|
||||
"We cannot let this happen. \n",
|
||||
"\n",
|
||||
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
|
||||
"\n",
|
||||
"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
|
||||
"\n",
|
||||
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
|
||||
"\n",
|
||||
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. \n",
|
||||
"\n",
|
||||
"A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n",
|
||||
"\n",
|
||||
"And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n",
|
||||
"\n",
|
||||
"We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling. \n",
|
||||
"\n",
|
||||
"We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers. \n",
|
||||
"\n",
|
||||
"We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"docs = docsearch.similarity_search(\n",
|
||||
" query=query, filter={\"metadata.source\": \"../../how_to/state_of_the_union.txt\"}\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"metadata :\", docs[0].metadata)\n",
|
||||
"\n",
|
||||
"# or you can directly supply SQL string filters :\n",
|
||||
"\n",
|
||||
"print(\"\\nSQL filtering :\\n\")\n",
|
||||
"docs = docsearch.similarity_search(query=query, filter=\"text LIKE '%Officer Rivera%'\")\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9a173c94",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Adding images "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "05f669d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install -U langchain-experimental"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ed69810",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install open_clip_torch torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "2cacb5ee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! rm -rf '/tmp/multimmodal_lance'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "b3456e2c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_experimental.open_clip import OpenCLIPEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "3848eba2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"# List of image URLs to download\n",
|
||||
"image_urls = [\n",
|
||||
" \"https://github.com/raghavdixit99/assets/assets/34462078/abf47cc4-d979-4aaa-83be-53a2115bf318\",\n",
|
||||
" \"https://github.com/raghavdixit99/assets/assets/34462078/93be928e-522b-4e37-889d-d4efd54b2112\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"texts = [\"bird\", \"dragon\"]\n",
|
||||
"\n",
|
||||
"# Directory to save images\n",
|
||||
"dir_name = \"./photos/\"\n",
|
||||
"\n",
|
||||
"# Create directory if it doesn't exist\n",
|
||||
"os.makedirs(dir_name, exist_ok=True)\n",
|
||||
"\n",
|
||||
"image_uris = []\n",
|
||||
"# Download and save each image\n",
|
||||
"for i, url in enumerate(image_urls, start=1):\n",
|
||||
" response = requests.get(url)\n",
|
||||
" path = os.path.join(dir_name, f\"image{i}.jpg\")\n",
|
||||
" image_uris.append(path)\n",
|
||||
" with open(path, \"wb\") as f:\n",
|
||||
" f.write(response.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "3d62c2a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.vectorstores import LanceDB\n",
|
||||
"\n",
|
||||
"vec_store = LanceDB(\n",
|
||||
" table_name=\"multimodal_test\",\n",
|
||||
" embedding=OpenCLIPEmbeddings(),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"id": "ebbb4881",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['b673620b-01f0-42ca-a92e-d033bb92c0a6',\n",
|
||||
" '99c3a5b0-b577-417a-8177-92f4a655dbfb']"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vec_store.add_images(uris=image_uris)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "3c29dea3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['f7adde5d-a4a3-402b-9e73-088b230722c3',\n",
|
||||
" 'cbed59da-0aec-4bff-8820-9e59d81a2140']"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vec_store.add_texts(texts)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "8b2f25ce",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img_embed = vec_store._embedding.embed_query(\"bird\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "87a24079",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Document(page_content='bird', metadata={'id': 'f7adde5d-a4a3-402b-9e73-088b230722c3'})"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vec_store.similarity_search_by_vector(img_embed)[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"id": "78557867",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LanceTable(connection=LanceDBConnection(/tmp/lancedb), name=\"multimodal_test\")"
|
||||
]
|
||||
},
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vec_store._table"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
78
docs/src/reranking/jina.md
Normal file
78
docs/src/reranking/jina.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# Jina Reranker
|
||||
|
||||
This re-ranker uses the [Jina](https://jina.ai/reranker/) API to rerank the search results. You can use this re-ranker by passing `JinaReranker()` to the `rerank()` method. Note that you'll either need to set the `JINA_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.
|
||||
|
||||
|
||||
!!! note
|
||||
Supported Query Types: Hybrid, Vector, FTS
|
||||
|
||||
|
||||
```python
|
||||
import os
|
||||
import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import JinaReranker
|
||||
|
||||
os.environ['JINA_API_KEY'] = "jina_*"
|
||||
|
||||
|
||||
embedder = get_registry().get("jina").create()
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = embedder.SourceField()
|
||||
vector: Vector(embedder.ndims()) = embedder.VectorField()
|
||||
|
||||
data = [
|
||||
{"text": "hello world"},
|
||||
{"text": "goodbye world"}
|
||||
]
|
||||
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
tbl.add(data)
|
||||
reranker = JinaReranker(api_key="key")
|
||||
|
||||
# Run vector search with a reranker
|
||||
result = tbl.search("hello").rerank(reranker=reranker).to_list()
|
||||
|
||||
# Run FTS search with a reranker
|
||||
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()
|
||||
|
||||
# Run hybrid search with a reranker
|
||||
tbl.create_fts_index("text", replace=True)
|
||||
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
|
||||
|
||||
```
|
||||
|
||||
Accepted Arguments
|
||||
----------------
|
||||
| Argument | Type | Default | Description |
|
||||
| --- | --- | --- | --- |
|
||||
| `model_name` | `str` | `"jina-reranker-v2-base-multilingual"` | The name of the reranker model to use. You can find the list of available models in https://jina.ai/reranker/|
|
||||
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
|
||||
| `top_n` | `str` | `None` | The number of results to return. If None, will return all results. |
|
||||
| `api_key` | `str` | `None` | The API key for the Jina API. If not provided, the `JINA_API_KEY` environment variable is used. |
|
||||
| `return_score` | str | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score. If "all" is supported, will return relevance score along with the vector and/or fts scores depending on query type |
|
||||
|
||||
|
||||
|
||||
## Supported Scores for each query type
|
||||
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:
|
||||
|
||||
### Hybrid Search
|
||||
|`return_score`| Status | Description |
|
||||
| --- | --- | --- |
|
||||
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||
| `all` | ❌ Not Supported | Returns have vector(`_distance`) and FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
|
||||
|
||||
### Vector Search
|
||||
|`return_score`| Status | Description |
|
||||
| --- | --- | --- |
|
||||
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||
| `all` | ✅ Supported | Returns have vector(`_distance`) along with Hybrid Search score(`_relevance_score`) |
|
||||
|
||||
### FTS Search
|
||||
|`return_score`| Status | Description |
|
||||
| --- | --- | --- |
|
||||
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
|
||||
| `all` | ✅ Supported | Returns have FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
|
||||
@@ -706,10 +706,10 @@ describe("table.search", () => {
|
||||
const data = [{ text: "hello world" }, { text: "goodbye world" }];
|
||||
const table = await db.createTable("test", data, { schema });
|
||||
|
||||
const results = await table.search("greetings").then((r) => r.toArray());
|
||||
const results = await table.search("greetings").toArray();
|
||||
expect(results[0].text).toBe(data[0].text);
|
||||
|
||||
const results2 = await table.search("farewell").then((r) => r.toArray());
|
||||
const results2 = await table.search("farewell").toArray();
|
||||
expect(results2[0].text).toBe(data[1].text);
|
||||
});
|
||||
|
||||
@@ -721,7 +721,7 @@ describe("table.search", () => {
|
||||
];
|
||||
const table = await db.createTable("test", data);
|
||||
|
||||
expect(table.search("hello")).rejects.toThrow(
|
||||
expect(table.search("hello").toArray()).rejects.toThrow(
|
||||
"No embedding functions are defined in the table",
|
||||
);
|
||||
});
|
||||
@@ -745,3 +745,27 @@ describe("table.search", () => {
|
||||
expect(results[0].text).toBe(data[1].text);
|
||||
});
|
||||
});
|
||||
|
||||
describe("when calling explainPlan", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let table: Table;
|
||||
let queryVec: number[];
|
||||
beforeEach(async () => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
const con = await connect(tmpDir.name);
|
||||
table = await con.createTable("vectors", [{ id: 1, vector: [0.1, 0.2] }]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
tmpDir.removeCallback();
|
||||
});
|
||||
|
||||
it("retrieves query plan", async () => {
|
||||
queryVec = Array(2)
|
||||
.fill(1)
|
||||
.map(() => Math.random());
|
||||
const plan = await table.query().nearestTo(queryVec).explainPlan(true);
|
||||
|
||||
expect(plan).toMatch("KNN");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -97,7 +97,11 @@ export type TableLike =
|
||||
| ArrowTable
|
||||
| { schema: SchemaLike; batches: RecordBatchLike[] };
|
||||
|
||||
export type IntoVector = Float32Array | Float64Array | number[];
|
||||
export type IntoVector =
|
||||
| Float32Array
|
||||
| Float64Array
|
||||
| number[]
|
||||
| Promise<Float32Array | Float64Array | number[]>;
|
||||
|
||||
export function isArrowTable(value: object): value is TableLike {
|
||||
if (value instanceof ArrowTable) return true;
|
||||
|
||||
@@ -181,7 +181,7 @@ export abstract class EmbeddingFunction<
|
||||
/**
|
||||
Compute the embeddings for a single query
|
||||
*/
|
||||
async computeQueryEmbeddings(data: T): Promise<IntoVector> {
|
||||
async computeQueryEmbeddings(data: T): Promise<Awaited<IntoVector>> {
|
||||
return this.computeSourceEmbeddings([data]).then(
|
||||
(embeddings) => embeddings[0],
|
||||
);
|
||||
|
||||
@@ -89,15 +89,26 @@ export interface QueryExecutionOptions {
|
||||
}
|
||||
|
||||
/** Common methods supported by all query types */
|
||||
export class QueryBase<
|
||||
NativeQueryType extends NativeQuery | NativeVectorQuery,
|
||||
QueryType,
|
||||
> implements AsyncIterable<RecordBatch>
|
||||
export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
|
||||
implements AsyncIterable<RecordBatch>
|
||||
{
|
||||
protected constructor(protected inner: NativeQueryType) {
|
||||
protected constructor(
|
||||
protected inner: NativeQueryType | Promise<NativeQueryType>,
|
||||
) {
|
||||
// intentionally empty
|
||||
}
|
||||
|
||||
// call a function on the inner (either a promise or the actual object)
|
||||
protected doCall(fn: (inner: NativeQueryType) => void) {
|
||||
if (this.inner instanceof Promise) {
|
||||
this.inner = this.inner.then((inner) => {
|
||||
fn(inner);
|
||||
return inner;
|
||||
});
|
||||
} else {
|
||||
fn(this.inner);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* A filter statement to be applied to this query.
|
||||
*
|
||||
@@ -110,16 +121,16 @@ export class QueryBase<
|
||||
* Filtering performance can often be improved by creating a scalar index
|
||||
* on the filter column(s).
|
||||
*/
|
||||
where(predicate: string): QueryType {
|
||||
this.inner.onlyIf(predicate);
|
||||
return this as unknown as QueryType;
|
||||
where(predicate: string): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.onlyIf(predicate));
|
||||
return this;
|
||||
}
|
||||
/**
|
||||
* A filter statement to be applied to this query.
|
||||
* @alias where
|
||||
* @deprecated Use `where` instead
|
||||
*/
|
||||
filter(predicate: string): QueryType {
|
||||
filter(predicate: string): this {
|
||||
return this.where(predicate);
|
||||
}
|
||||
|
||||
@@ -155,7 +166,7 @@ export class QueryBase<
|
||||
*/
|
||||
select(
|
||||
columns: string[] | Map<string, string> | Record<string, string> | string,
|
||||
): QueryType {
|
||||
): this {
|
||||
let columnTuples: [string, string][];
|
||||
if (typeof columns === "string") {
|
||||
columns = [columns];
|
||||
@@ -167,8 +178,10 @@ export class QueryBase<
|
||||
} else {
|
||||
columnTuples = Object.entries(columns);
|
||||
}
|
||||
this.inner.select(columnTuples);
|
||||
return this as unknown as QueryType;
|
||||
this.doCall((inner: NativeQueryType) => {
|
||||
inner.select(columnTuples);
|
||||
});
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -177,15 +190,19 @@ export class QueryBase<
|
||||
* By default, a plain search has no limit. If this method is not
|
||||
* called then every valid row from the table will be returned.
|
||||
*/
|
||||
limit(limit: number): QueryType {
|
||||
this.inner.limit(limit);
|
||||
return this as unknown as QueryType;
|
||||
limit(limit: number): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.limit(limit));
|
||||
return this;
|
||||
}
|
||||
|
||||
protected nativeExecute(
|
||||
options?: Partial<QueryExecutionOptions>,
|
||||
): Promise<NativeBatchIterator> {
|
||||
return this.inner.execute(options?.maxBatchLength);
|
||||
if (this.inner instanceof Promise) {
|
||||
return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
|
||||
} else {
|
||||
return this.inner.execute(options?.maxBatchLength);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -214,7 +231,13 @@ export class QueryBase<
|
||||
/** Collect the results as an Arrow @see {@link ArrowTable}. */
|
||||
async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
|
||||
const batches = [];
|
||||
for await (const batch of new RecordBatchIterable(this.inner, options)) {
|
||||
let inner;
|
||||
if (this.inner instanceof Promise) {
|
||||
inner = await this.inner;
|
||||
} else {
|
||||
inner = this.inner;
|
||||
}
|
||||
for await (const batch of new RecordBatchIterable(inner, options)) {
|
||||
batches.push(batch);
|
||||
}
|
||||
return new ArrowTable(batches);
|
||||
@@ -226,6 +249,24 @@ export class QueryBase<
|
||||
const tbl = await this.toArrow(options);
|
||||
return tbl.toArray();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates an explanation of the query execution plan.
|
||||
*
|
||||
* @example
|
||||
* import * as lancedb from "@lancedb/lancedb"
|
||||
* const db = await lancedb.connect("./.lancedb");
|
||||
* const table = await db.createTable("my_table", [
|
||||
* { vector: [1.1, 0.9], id: "1" },
|
||||
* ]);
|
||||
* const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
|
||||
*
|
||||
* @param verbose - If true, provides a more detailed explanation. Defaults to false.
|
||||
* @returns A Promise that resolves to a string containing the query execution plan explanation.
|
||||
*/
|
||||
async explainPlan(verbose = false): Promise<string> {
|
||||
return await this.inner.explainPlan(verbose);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -240,8 +281,8 @@ export interface ExecutableQuery {}
|
||||
*
|
||||
* This builder can be reused to execute the query many times.
|
||||
*/
|
||||
export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
constructor(inner: NativeVectorQuery) {
|
||||
export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
||||
constructor(inner: NativeVectorQuery | Promise<NativeVectorQuery>) {
|
||||
super(inner);
|
||||
}
|
||||
|
||||
@@ -268,7 +309,8 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* you the desired recall.
|
||||
*/
|
||||
nprobes(nprobes: number): VectorQuery {
|
||||
this.inner.nprobes(nprobes);
|
||||
super.doCall((inner) => inner.nprobes(nprobes));
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -282,7 +324,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* whose data type is a fixed-size-list of floats.
|
||||
*/
|
||||
column(column: string): VectorQuery {
|
||||
this.inner.column(column);
|
||||
super.doCall((inner) => inner.column(column));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -303,7 +345,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
distanceType(
|
||||
distanceType: Required<IvfPqOptions>["distanceType"],
|
||||
): VectorQuery {
|
||||
this.inner.distanceType(distanceType);
|
||||
super.doCall((inner) => inner.distanceType(distanceType));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -337,7 +379,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* distance between the query vector and the actual uncompressed vector.
|
||||
*/
|
||||
refineFactor(refineFactor: number): VectorQuery {
|
||||
this.inner.refineFactor(refineFactor);
|
||||
super.doCall((inner) => inner.refineFactor(refineFactor));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -362,7 +404,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* factor can often help restore some of the results lost by post filtering.
|
||||
*/
|
||||
postfilter(): VectorQuery {
|
||||
this.inner.postfilter();
|
||||
super.doCall((inner) => inner.postfilter());
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -376,13 +418,13 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
* calculate your recall to select an appropriate value for nprobes.
|
||||
*/
|
||||
bypassVectorIndex(): VectorQuery {
|
||||
this.inner.bypassVectorIndex();
|
||||
super.doCall((inner) => inner.bypassVectorIndex());
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
/** A builder for LanceDB queries. */
|
||||
export class Query extends QueryBase<NativeQuery, Query> {
|
||||
export class Query extends QueryBase<NativeQuery> {
|
||||
constructor(tbl: NativeTable) {
|
||||
super(tbl.query());
|
||||
}
|
||||
@@ -425,7 +467,37 @@ export class Query extends QueryBase<NativeQuery, Query> {
|
||||
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||
*/
|
||||
nearestTo(vector: IntoVector): VectorQuery {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
return new VectorQuery(vectorQuery);
|
||||
if (this.inner instanceof Promise) {
|
||||
const nativeQuery = this.inner.then(async (inner) => {
|
||||
if (vector instanceof Promise) {
|
||||
const arr = await vector.then((v) => Float32Array.from(v));
|
||||
return inner.nearestTo(arr);
|
||||
} else {
|
||||
return inner.nearestTo(Float32Array.from(vector));
|
||||
}
|
||||
});
|
||||
return new VectorQuery(nativeQuery);
|
||||
}
|
||||
if (vector instanceof Promise) {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.nearestTo(arr);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
return inner;
|
||||
} catch (e) {
|
||||
return Promise.reject(e);
|
||||
}
|
||||
})();
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
return new VectorQuery(vectorQuery);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,9 +122,8 @@ export class RemoteTable extends Table {
|
||||
query(): import("..").Query {
|
||||
throw new Error("query() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
search(query: IntoVector): VectorQuery;
|
||||
search(query: string): Promise<VectorQuery>;
|
||||
search(_query: string | IntoVector): VectorQuery | Promise<VectorQuery> {
|
||||
|
||||
search(_query: string | IntoVector): VectorQuery {
|
||||
throw new Error("search() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
vectorSearch(_vector: unknown): import("..").VectorQuery {
|
||||
|
||||
@@ -244,9 +244,9 @@ export abstract class Table {
|
||||
* Create a search query to find the nearest neighbors
|
||||
* of the given query vector
|
||||
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
|
||||
* @rejects {Error} If no embedding functions are defined in the table
|
||||
* @note If no embedding functions are defined in the table, this will error when collecting the results.
|
||||
*/
|
||||
abstract search(query: string): Promise<VectorQuery>;
|
||||
abstract search(query: string): VectorQuery;
|
||||
/**
|
||||
* Create a search query to find the nearest neighbors
|
||||
* of the given query vector
|
||||
@@ -502,28 +502,26 @@ export class LocalTable extends Table {
|
||||
query(): Query {
|
||||
return new Query(this.inner);
|
||||
}
|
||||
|
||||
search(query: string): Promise<VectorQuery>;
|
||||
|
||||
search(query: IntoVector): VectorQuery;
|
||||
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
|
||||
search(query: string | IntoVector): VectorQuery {
|
||||
if (typeof query !== "string") {
|
||||
return this.vectorSearch(query);
|
||||
} else {
|
||||
return this.getEmbeddingFunctions().then(async (functions) => {
|
||||
// TODO: Support multiple embedding functions
|
||||
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
|
||||
.values()
|
||||
.next().value;
|
||||
if (!embeddingFunc) {
|
||||
return Promise.reject(
|
||||
new Error("No embedding functions are defined in the table"),
|
||||
);
|
||||
}
|
||||
const embeddings =
|
||||
await embeddingFunc.function.computeQueryEmbeddings(query);
|
||||
return this.query().nearestTo(embeddings);
|
||||
});
|
||||
const queryPromise = this.getEmbeddingFunctions().then(
|
||||
async (functions) => {
|
||||
// TODO: Support multiple embedding functions
|
||||
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
|
||||
.values()
|
||||
.next().value;
|
||||
if (!embeddingFunc) {
|
||||
return Promise.reject(
|
||||
new Error("No embedding functions are defined in the table"),
|
||||
);
|
||||
}
|
||||
return await embeddingFunc.function.computeQueryEmbeddings(query);
|
||||
},
|
||||
);
|
||||
|
||||
return this.query().nearestTo(queryPromise);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,13 @@ impl Query {
|
||||
})?;
|
||||
Ok(RecordBatchIterator::new(inner_stream))
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn explain_plan(&self, verbose: bool) -> napi::Result<String> {
|
||||
self.inner.explain_plan(verbose).await.map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to retrieve the query plan: {}", e))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
@@ -154,4 +161,11 @@ impl VectorQuery {
|
||||
})?;
|
||||
Ok(RecordBatchIterator::new(inner_stream))
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn explain_plan(&self, verbose: bool) -> napi::Result<String> {
|
||||
self.inner.explain_plan(verbose).await.map_err(|e| {
|
||||
napi::Error::from_reason(format!("Failed to retrieve the query plan: {}", e))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ name = "lancedb"
|
||||
# version in Cargo.toml
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.13.0",
|
||||
"pylance==0.14.0",
|
||||
"ratelimiter~=1.0",
|
||||
"requests>=2.31.0",
|
||||
"retry>=0.9.2",
|
||||
|
||||
@@ -25,3 +25,4 @@ from .gte import GteEmbeddings
|
||||
from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
|
||||
from .imagebind import ImageBindEmbeddings
|
||||
from .utils import with_embeddings
|
||||
from .jinaai import JinaEmbeddings
|
||||
|
||||
172
python/python/lancedb/embeddings/jinaai.py
Normal file
172
python/python/lancedb/embeddings/jinaai.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import io
|
||||
import requests
|
||||
import base64
|
||||
import urllib.parse as urlparse
|
||||
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help, TEXT, IMAGES, url_retrieve
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
|
||||
API_URL = "https://api.jina.ai/v1/embeddings"
|
||||
|
||||
|
||||
@register("jina")
|
||||
class JinaEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the Jina API
|
||||
|
||||
https://jina.ai/embeddings/
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "jina-clip-v1". Note that some models support both image
|
||||
and text embeddings and some just text embedding
|
||||
|
||||
api_key: str, default None
|
||||
The api key to access Jina API. If you pass None, you can set JINA_API_KEY
|
||||
environment variable
|
||||
|
||||
"""
|
||||
|
||||
name: str = "jina-clip-v1"
|
||||
api_key: Optional[str] = None
|
||||
_session: ClassVar = None
|
||||
|
||||
def ndims(self):
|
||||
# TODO: fix hardcoding
|
||||
return 768
|
||||
|
||||
def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(inputs, (str, bytes)):
|
||||
inputs = [inputs]
|
||||
elif isinstance(inputs, pa.Array):
|
||||
inputs = inputs.to_pylist()
|
||||
elif isinstance(inputs, pa.ChunkedArray):
|
||||
inputs = inputs.combine_chunks().to_pylist()
|
||||
return inputs
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return self.generate_text_embeddings([query])
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError(
|
||||
"JinaEmbeddingFunction supports str or PIL Image as query"
|
||||
)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_text_embeddings(texts)
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
image = {"image": base64.b64encode(image).decode("utf-8")}
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_bytes = buffered.getvalue()
|
||||
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
pil_image = PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
buffered = io.BytesIO()
|
||||
pil_image.save(buffered, format="PNG")
|
||||
image_bytes = buffered.getvalue()
|
||||
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
return self._generate_embeddings(input=[image])[0]
|
||||
|
||||
def generate_text_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
return self._generate_embeddings(input=texts)
|
||||
|
||||
def _generate_embeddings(self, input: List, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
self._init_client()
|
||||
resp = JinaEmbeddings._session.post( # type: ignore
|
||||
API_URL, json={"input": input, "model": self.name}
|
||||
).json()
|
||||
if "data" not in resp:
|
||||
raise RuntimeError(resp["detail"])
|
||||
|
||||
embeddings = resp["data"]
|
||||
|
||||
# Sort resulting embeddings by index
|
||||
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
|
||||
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
|
||||
def _init_client(self):
|
||||
if JinaEmbeddings._session is None:
|
||||
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
|
||||
api_key_not_found_help("jina")
|
||||
api_key = self.api_key or os.environ.get("JINA_API_KEY")
|
||||
JinaEmbeddings._session = requests.Session()
|
||||
JinaEmbeddings._session.headers.update(
|
||||
{"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
|
||||
)
|
||||
@@ -417,6 +417,40 @@ class LanceQueryBuilder(ABC):
|
||||
self._with_row_id = with_row_id
|
||||
return self
|
||||
|
||||
def explain_plan(self, verbose: Optional[bool] = False) -> str:
|
||||
"""Return the execution plan for this query.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
|
||||
>>> query = [100, 100]
|
||||
>>> plan = table.search(query).explain_plan(True)
|
||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
Projection: fields=[vector, _distance]
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST]
|
||||
KNNVectorDistance: metric=l2
|
||||
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
|
||||
|
||||
Parameters
|
||||
----------
|
||||
verbose : bool, default False
|
||||
Use a verbose output format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
plan : str
|
||||
""" # noqa: E501
|
||||
ds = self._table.to_lance()
|
||||
return ds.scanner(
|
||||
nearest={
|
||||
"column": self._vector_column,
|
||||
"q": self._query,
|
||||
},
|
||||
).explain_plan(verbose)
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
@@ -1166,6 +1200,37 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
return (await self.to_arrow()).to_pandas()
|
||||
|
||||
async def explain_plan(self, verbose: Optional[bool] = False):
|
||||
"""Return the execution plan for this query.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import asyncio
|
||||
>>> from lancedb import connect_async
|
||||
>>> async def doctest_example():
|
||||
... conn = await connect_async("./.lancedb")
|
||||
... table = await conn.create_table("my_table", [{"vector": [99, 99]}])
|
||||
... query = [100, 100]
|
||||
... plan = await table.query().nearest_to([1, 2]).explain_plan(True)
|
||||
... print(plan)
|
||||
>>> asyncio.run(doctest_example()) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
Projection: fields=[vector, _distance]
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST]
|
||||
KNNVectorDistance: metric=l2
|
||||
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
|
||||
|
||||
Parameters
|
||||
----------
|
||||
verbose : bool, default False
|
||||
Use a verbose output format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
plan : str
|
||||
""" # noqa: E501
|
||||
return await self._inner.explain_plan(verbose)
|
||||
|
||||
|
||||
class AsyncQuery(AsyncQueryBase):
|
||||
def __init__(self, inner: LanceQuery):
|
||||
|
||||
@@ -111,6 +111,7 @@ class RemoteTable(Table):
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
replace: Optional[bool] = None,
|
||||
accelerator: Optional[str] = None,
|
||||
index_type="vector",
|
||||
):
|
||||
"""Create an index on the table.
|
||||
Currently, the only parameters that matter are
|
||||
@@ -166,7 +167,6 @@ class RemoteTable(Table):
|
||||
"replace is not supported on LanceDB cloud."
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
index_type = "vector"
|
||||
|
||||
data = {
|
||||
"column": vector_column_name,
|
||||
|
||||
@@ -4,7 +4,7 @@ from .colbert import ColbertReranker
|
||||
from .cross_encoder import CrossEncoderReranker
|
||||
from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
from .jina import JinaReranker
|
||||
from .jinaai import JinaReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class JinaReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using Jina reranker model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "jinaai/jina-reranker-v1-turbo-en"
|
||||
The name of the reranker to use. For all models, see
|
||||
https://huggingface.co/jinaai/jina-reranker-v1-turbo-en
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
device : str, default None
|
||||
The device to use for the cross encoder model. If None, will use "cuda"
|
||||
if available, otherwise "cpu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "jinaai/jina-reranker-v1-turbo-en",
|
||||
column: str = "text",
|
||||
device: Union[str, None] = None,
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
torch = attempt_import_or_raise("torch")
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.device = device
|
||||
if self.device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
self.model_name, num_labels=1, trust_remote_code=True
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
passages = result_set[self.column].to_pylist()
|
||||
cross_inp = [[query, passage] for passage in passages]
|
||||
cross_scores = self.model.compute_score(cross_inp)
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
# sort the results by _score
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for CrossEncoderReranker"
|
||||
)
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
return fts_results
|
||||
122
python/python/lancedb/rerankers/jinaai.py
Normal file
122
python/python/lancedb/rerankers/jinaai.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import os
|
||||
import requests
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
|
||||
API_URL = "https://api.jina.ai/v1/rerank"
|
||||
|
||||
|
||||
class JinaReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using the Jina Rerank API.
|
||||
https://jina.ai/rerank
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "jina-reranker-v2-base-multilingual"
|
||||
The name of the cross reanker model to use
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
top_n : str, default None
|
||||
The number of results to return. If None, will return all results.
|
||||
api_key : str, default None
|
||||
The api key to access Jina API. If you pass None, you can set JINA_API_KEY
|
||||
environment variable
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "jina-reranker-v2-base-multilingual",
|
||||
column: str = "text",
|
||||
top_n: Union[int, None] = None,
|
||||
return_score="relevance",
|
||||
api_key: Union[str, None] = None,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.top_n = top_n
|
||||
self.api_key = api_key
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"JINA_API_KEY not set. Either set it in your environment or \
|
||||
pass it as `api_key` argument to the JinaReranker."
|
||||
)
|
||||
self.api_key = self.api_key or os.environ.get("JINA_API_KEY")
|
||||
self._session = requests.Session()
|
||||
self._session.headers.update(
|
||||
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
|
||||
)
|
||||
return self._session
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.post( # type: ignore
|
||||
API_URL,
|
||||
json={
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"model": self.model_name,
|
||||
"top_n": self.top_n,
|
||||
},
|
||||
).json()
|
||||
if "results" not in response:
|
||||
raise RuntimeError(response["detail"])
|
||||
|
||||
results = response["results"]
|
||||
|
||||
indices, scores = list(
|
||||
zip(*[(result["index"], result["relevance_score"]) for result in results])
|
||||
) # tuples
|
||||
result_set = result_set.take(list(indices))
|
||||
# add the scores
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for JinaReranker"
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["score"])
|
||||
|
||||
return result_set
|
||||
@@ -1173,11 +1173,12 @@ class LanceTable(Table):
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
index_type="IVF_PQ",
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
self._dataset_mut.create_index(
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
index_type=index_type,
|
||||
metric=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
|
||||
@@ -333,3 +333,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
|
||||
|
||||
df = await table_async.query().where("id < 0").to_pandas()
|
||||
assert df.shape == (0, 4)
|
||||
|
||||
|
||||
def test_explain_plan(table):
|
||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
plan = q.explain_plan(verbose=True)
|
||||
assert "KNN" in plan
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_plan_async(table_async: AsyncTable):
|
||||
plan = await table_async.query().nearest_to(pa.array([1, 2])).explain_plan(True)
|
||||
assert "KNN" in plan
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -9,7 +11,6 @@ from lancedb.rerankers import (
|
||||
ColbertReranker,
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
JinaReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -118,18 +119,136 @@ def test_linear_combination(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"reranker",
|
||||
[
|
||||
ColbertReranker(),
|
||||
OpenaiReranker(),
|
||||
CohereReranker(),
|
||||
CrossEncoderReranker(),
|
||||
JinaReranker(),
|
||||
],
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
def test_colbert_reranker(tmp_path, reranker):
|
||||
def test_cohere_reranker(tmp_path):
|
||||
pytest.importorskip("cohere")
|
||||
reranker = CohereReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# Hybrid search setting
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
query = "Our father who art in heaven"
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_cross_encoder_reranker(tmp_path):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
reranker = ColbertReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
@@ -186,3 +305,67 @@ def test_colbert_reranker(tmp_path, reranker):
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_reranker(tmp_path):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
reranker = OpenaiReranker()
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
@@ -19,6 +19,7 @@ use lancedb::query::QueryExecutionOptions;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::pyclass;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::PyAny;
|
||||
@@ -73,6 +74,16 @@ impl Query {
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
}
|
||||
|
||||
fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.explain_plan(verbose)
|
||||
.await
|
||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -131,4 +142,14 @@ impl VectorQuery {
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
}
|
||||
|
||||
fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.explain_plan(verbose)
|
||||
.await
|
||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1191,6 +1191,7 @@ mod tests {
|
||||
.query()
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 50000,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
@@ -1211,6 +1212,7 @@ mod tests {
|
||||
.query()
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 50000,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
|
||||
@@ -374,6 +374,16 @@ pub trait QueryBase {
|
||||
/// Columns will always be returned in the order given, even if that order is different than
|
||||
/// the order used when adding the data.
|
||||
fn select(self, selection: Select) -> Self;
|
||||
|
||||
/// Only execute the query over indexed data.
|
||||
///
|
||||
/// This allows weak-consistent fast path for queries that only need to access the indexed data.
|
||||
///
|
||||
/// Users can use [`crate::Table::optimize`] to merge new data into the index, and make the
|
||||
/// new data available for fast search.
|
||||
///
|
||||
/// By default, it is false.
|
||||
fn fast_search(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait HasQuery {
|
||||
@@ -395,6 +405,11 @@ impl<T: HasQuery> QueryBase for T {
|
||||
self.mut_query().select = select;
|
||||
self
|
||||
}
|
||||
|
||||
fn fast_search(mut self) -> Self {
|
||||
self.mut_query().fast_search = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for controlling the execution of a query
|
||||
@@ -465,6 +480,8 @@ pub trait ExecutableQuery {
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> impl Future<Output = Result<SendableRecordBatchStream>> + Send;
|
||||
|
||||
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
|
||||
}
|
||||
|
||||
/// A builder for LanceDB queries.
|
||||
@@ -489,6 +506,12 @@ pub struct Query {
|
||||
pub(crate) filter: Option<String>,
|
||||
/// Select column projection.
|
||||
pub(crate) select: Select,
|
||||
|
||||
/// If set to true, the query is executed only on the indexed data,
|
||||
/// and yields faster results.
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) fast_search: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -498,6 +521,7 @@ impl Query {
|
||||
limit: None,
|
||||
filter: None,
|
||||
select: Select::All,
|
||||
fast_search: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -572,6 +596,12 @@ impl ExecutableQuery for Query {
|
||||
self.parent.clone().plain_query(self, options).await?,
|
||||
))
|
||||
}
|
||||
|
||||
async fn explain_plan(&self, verbose: bool) -> Result<String> {
|
||||
self.parent
|
||||
.explain_plan(&self.clone().into_vector(), verbose)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for vector searches
|
||||
@@ -752,6 +782,10 @@ impl ExecutableQuery for VectorQuery {
|
||||
)?),
|
||||
))
|
||||
}
|
||||
|
||||
async fn explain_plan(&self, verbose: bool) -> Result<String> {
|
||||
self.base.parent.explain_plan(self, verbose).await
|
||||
}
|
||||
}
|
||||
|
||||
impl HasQuery for VectorQuery {
|
||||
@@ -989,6 +1023,7 @@ mod tests {
|
||||
.query()
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 10,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1053,4 +1088,20 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("No vector column found to match with the query vector dimension: 3"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fast_search_plan() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_test_table(&tmp_dir).await;
|
||||
let plan = table
|
||||
.query()
|
||||
.select(Select::columns(&["_distance"]))
|
||||
.nearest_to(vec![0.1, 0.2, 0.3, 0.4])
|
||||
.unwrap()
|
||||
.fast_search()
|
||||
.explain_plan(true)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!plan.contains("Take"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::table::dataset::DatasetReadGuard;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::SchemaRef;
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::{ColumnAlteration, NewColumnTransform};
|
||||
|
||||
use crate::{
|
||||
connection::NoData,
|
||||
@@ -74,6 +76,14 @@ impl TableInternal for RemoteTable {
|
||||
) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
async fn build_plan(
|
||||
&self,
|
||||
_ds_ref: &DatasetReadGuard,
|
||||
_query: &VectorQuery,
|
||||
_options: Option<QueryExecutionOptions>,
|
||||
) -> Result<Scanner> {
|
||||
todo!()
|
||||
}
|
||||
async fn create_plan(
|
||||
&self,
|
||||
_query: &VectorQuery,
|
||||
@@ -81,6 +91,9 @@ impl TableInternal for RemoteTable {
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
unimplemented!()
|
||||
}
|
||||
async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
|
||||
todo!()
|
||||
}
|
||||
async fn plain_query(
|
||||
&self,
|
||||
_query: &Query,
|
||||
|
||||
@@ -35,6 +35,7 @@ use lance::dataset::{
|
||||
Dataset, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
|
||||
};
|
||||
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
|
||||
use lance::index::scalar::ScalarIndexType;
|
||||
use lance::io::WrappingObjectStore;
|
||||
use lance_datafusion::exec::execute_plan;
|
||||
use lance_index::vector::hnsw::builder::HnswBuildParams;
|
||||
@@ -65,7 +66,7 @@ use crate::query::{
|
||||
};
|
||||
use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
|
||||
|
||||
use self::dataset::DatasetConsistencyWrapper;
|
||||
use self::dataset::{DatasetConsistencyWrapper, DatasetReadGuard};
|
||||
use self::merge::MergeInsertBuilder;
|
||||
|
||||
pub(crate) mod dataset;
|
||||
@@ -369,6 +370,12 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
async fn schema(&self) -> Result<SchemaRef>;
|
||||
/// Count the number of rows in this table.
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
|
||||
async fn build_plan(
|
||||
&self,
|
||||
ds_ref: &DatasetReadGuard,
|
||||
query: &VectorQuery,
|
||||
options: Option<QueryExecutionOptions>,
|
||||
) -> Result<Scanner>;
|
||||
async fn create_plan(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
@@ -379,6 +386,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
query: &Query,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream>;
|
||||
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String>;
|
||||
async fn add(
|
||||
&self,
|
||||
add: AddDataBuilder<NoData>,
|
||||
@@ -1270,22 +1278,25 @@ impl NativeTable {
|
||||
|
||||
/// Get statistics about an index.
|
||||
/// Returns an error if the index does not exist.
|
||||
pub async fn index_stats<S: AsRef<str>>(
|
||||
pub async fn index_stats(
|
||||
&self,
|
||||
index_name: S,
|
||||
index_name: impl AsRef<str>,
|
||||
) -> Result<Option<IndexStatistics>> {
|
||||
self.dataset
|
||||
let stats = match self
|
||||
.dataset
|
||||
.get()
|
||||
.await?
|
||||
.index_statistics(index_name.as_ref())
|
||||
.await
|
||||
.ok()
|
||||
.map(|stats| {
|
||||
serde_json::from_str(&stats).map_err(|e| Error::InvalidInput {
|
||||
message: format!("error deserializing index statistics: {}", e),
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
{
|
||||
Ok(stats) => stats,
|
||||
Err(lance::error::Error::IndexNotFound { .. }) => return Ok(None),
|
||||
Err(e) => return Err(Error::from(e)),
|
||||
};
|
||||
|
||||
serde_json::from_str(&stats).map_err(|e| Error::InvalidInput {
|
||||
message: format!("error deserializing index statistics: {}", e),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
|
||||
@@ -1493,7 +1504,9 @@ impl NativeTable {
|
||||
}
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance::index::scalar::ScalarIndexParams {};
|
||||
let lance_idx_params = lance::index::scalar::ScalarIndexParams {
|
||||
force_index_type: Some(ScalarIndexType::BTree),
|
||||
};
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
@@ -1667,12 +1680,12 @@ impl TableInternal for NativeTable {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
async fn build_plan(
|
||||
&self,
|
||||
ds_ref: &DatasetReadGuard,
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let ds_ref = self.dataset.get().await?;
|
||||
options: Option<QueryExecutionOptions>,
|
||||
) -> Result<Scanner> {
|
||||
let mut scanner: Scanner = ds_ref.scan();
|
||||
|
||||
if let Some(query_vector) = query.query_vector.as_ref() {
|
||||
@@ -1684,9 +1697,11 @@ impl TableInternal for NativeTable {
|
||||
let arrow_schema = Schema::from(ds_ref.schema());
|
||||
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
|
||||
};
|
||||
|
||||
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
|
||||
message: format!("Column {} not found in dataset schema", column),
|
||||
})?;
|
||||
|
||||
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
|
||||
if !f.data_type().is_floating() {
|
||||
return Err(Error::InvalidInput {
|
||||
@@ -1698,16 +1713,17 @@ impl TableInternal for NativeTable {
|
||||
}
|
||||
if dim != query_vector.len() as i32 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
|
||||
query dim={}, expected vector dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
dim,
|
||||
),
|
||||
});
|
||||
message: format!(
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
|
||||
query dim={}, expected vector dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
dim,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||
scanner.nearest(
|
||||
&column,
|
||||
@@ -1718,10 +1734,38 @@ impl TableInternal for NativeTable {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
|
||||
}
|
||||
|
||||
scanner.nprobs(query.nprobes);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.prefilter);
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
match query.base.select {
|
||||
Select::Columns(ref columns) => {
|
||||
scanner.project(columns.as_slice())?;
|
||||
}
|
||||
Select::Dynamic(ref select_with_transform) => {
|
||||
scanner.project_with_transform(select_with_transform.as_slice())?;
|
||||
}
|
||||
Select::All => {}
|
||||
}
|
||||
|
||||
if let Some(opts) = options {
|
||||
scanner.batch_size(opts.max_batch_length as usize);
|
||||
}
|
||||
if query.base.fast_search {
|
||||
scanner.fast_search();
|
||||
}
|
||||
|
||||
Ok(scanner)
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let ds_ref = self.dataset.get().await?;
|
||||
|
||||
let mut scanner = self.build_plan(&ds_ref, query, Some(options)).await?;
|
||||
|
||||
match &query.base.select {
|
||||
Select::Columns(select) => {
|
||||
@@ -1744,6 +1788,7 @@ impl TableInternal for NativeTable {
|
||||
if let Some(distance_type) = query.distance_type {
|
||||
scanner.distance_metric(distance_type.into());
|
||||
}
|
||||
|
||||
Ok(scanner.create_plan().await?)
|
||||
}
|
||||
|
||||
@@ -1756,6 +1801,16 @@ impl TableInternal for NativeTable {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
|
||||
let ds_ref = self.dataset.get().await?;
|
||||
|
||||
let scanner = self.build_plan(&ds_ref, query, None).await?;
|
||||
|
||||
let plan = scanner.explain_plan(verbose).await?;
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
async fn merge_insert(
|
||||
&self,
|
||||
params: MergeInsertBuilder,
|
||||
|
||||
Reference in New Issue
Block a user