mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 22:29:58 +00:00
Compare commits
13 Commits
python-v0.
...
v0.11.0-be
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b72ac073ab | ||
|
|
3152ccd13c | ||
|
|
d5021356b4 | ||
|
|
e82f63b40a | ||
|
|
f81ce68e41 | ||
|
|
f5c25b6fff | ||
|
|
86978e7588 | ||
|
|
7c314d61cc | ||
|
|
7a8d2f37c4 | ||
|
|
11072b9edc | ||
|
|
915d828cee | ||
|
|
d9a72adc58 | ||
|
|
d6cf2dafc6 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.10.0"
|
||||
current_version = "0.11.0-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
3
.github/workflows/rust.yml
vendored
3
.github/workflows/rust.yml
vendored
@@ -30,7 +30,6 @@ jobs:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: rust
|
||||
env:
|
||||
# Need up-to-date compilers for kernels
|
||||
CC: gcc-12
|
||||
@@ -50,7 +49,7 @@ jobs:
|
||||
- name: Run format
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Run clippy
|
||||
run: cargo clippy --all --all-features -- -D warnings
|
||||
run: cargo clippy --workspace --tests --all-features -- -D warnings
|
||||
linux:
|
||||
timeout-minutes: 30
|
||||
# To build all features, we need more disk space than is available
|
||||
|
||||
@@ -34,6 +34,7 @@ theme:
|
||||
- navigation.footer
|
||||
- navigation.tracking
|
||||
- navigation.instant
|
||||
- content.footnote.tooltips
|
||||
icon:
|
||||
repo: fontawesome/brands/github
|
||||
annotation: material/arrow-right-circle
|
||||
@@ -65,6 +66,11 @@ plugins:
|
||||
markdown_extensions:
|
||||
- admonition
|
||||
- footnotes
|
||||
- pymdownx.critic
|
||||
- pymdownx.caret
|
||||
- pymdownx.keys
|
||||
- pymdownx.mark
|
||||
- pymdownx.tilde
|
||||
- pymdownx.details
|
||||
- pymdownx.highlight:
|
||||
anchor_linenums: true
|
||||
@@ -177,6 +183,7 @@ nav:
|
||||
- Voxel51: integrations/voxel51.md
|
||||
- PromptTools: integrations/prompttools.md
|
||||
- dlt: integrations/dlt.md
|
||||
- phidata: integrations/phidata.md
|
||||
- 🎯 Examples:
|
||||
- Overview: examples/index.md
|
||||
- 🐍 Python:
|
||||
@@ -299,6 +306,7 @@ nav:
|
||||
- Voxel51: integrations/voxel51.md
|
||||
- PromptTools: integrations/prompttools.md
|
||||
- dlt: integrations/dlt.md
|
||||
- phidata: integrations/phidata.md
|
||||
- Examples:
|
||||
- examples/index.md
|
||||
- 🐍 Python:
|
||||
|
||||
383
docs/src/integrations/phidata.md
Normal file
383
docs/src/integrations/phidata.md
Normal file
@@ -0,0 +1,383 @@
|
||||
**phidata** is a framework for building **AI Assistants** with long-term memory, contextual knowledge, and the ability to take actions using function calling. It helps turn general-purpose LLMs into specialized assistants tailored to your use case by extending its capabilities using **memory**, **knowledge**, and **tools**.
|
||||
|
||||
- **Memory**: Stores chat history in a **database** and enables LLMs to have long-term conversations.
|
||||
- **Knowledge**: Stores information in a **vector database** and provides LLMs with business context. (Here we will use LanceDB)
|
||||
- **Tools**: Enable LLMs to take actions like pulling data from an **API**, **sending emails** or **querying a database**, etc.
|
||||
|
||||

|
||||
|
||||
Memory & knowledge make LLMs smarter while tools make them autonomous.
|
||||
|
||||
LanceDB is a vector database and its integration into phidata makes it easy for us to provide a **knowledge base** to LLMs. It enables us to store information as [embeddings](../embeddings/understanding_embeddings.md) and search for the **results** similar to ours using **query**.
|
||||
|
||||
??? Question "What is Knowledge Base?"
|
||||
Knowledge Base is a database of information that the Assistant can search to improve its responses. This information is stored in a vector database and provides LLMs with business context, which makes them respond in a context-aware manner.
|
||||
|
||||
While any type of storage can act as a knowledge base, vector databases offer the best solution for retrieving relevant results from dense information quickly.
|
||||
|
||||
Let's see how using LanceDB inside phidata helps in making LLM more useful:
|
||||
|
||||
## Prerequisites: install and import necessary dependencies
|
||||
|
||||
**Create a virtual environment**
|
||||
|
||||
1. install virtualenv package
|
||||
```python
|
||||
pip install virtualenv
|
||||
```
|
||||
2. Create a directory for your project and go to the directory and create a virtual environment inside it.
|
||||
```python
|
||||
mkdir phi
|
||||
```
|
||||
```python
|
||||
cd phi
|
||||
```
|
||||
```python
|
||||
python -m venv phidata_
|
||||
```
|
||||
|
||||
**Activating virtual environment**
|
||||
|
||||
1. from inside the project directory, run the following command to activate the virtual environment.
|
||||
```python
|
||||
phidata_/Scripts/activate
|
||||
```
|
||||
|
||||
**Install the following packages in the virtual environment**
|
||||
```python
|
||||
pip install lancedb phidata youtube_transcript_api openai ollama pandas numpy
|
||||
```
|
||||
|
||||
**Create python files and import necessary libraries**
|
||||
|
||||
You need to create two files - `transcript.py` and `ollama_assistant.py` or `openai_assistant.py`
|
||||
|
||||
=== "openai_assistant.py"
|
||||
|
||||
```python
|
||||
import os, openai
|
||||
from rich.prompt import Prompt
|
||||
from phi.assistant import Assistant
|
||||
from phi.knowledge.text import TextKnowledgeBase
|
||||
from phi.vectordb.lancedb import LanceDb
|
||||
from phi.llm.openai import OpenAIChat
|
||||
from phi.embedder.openai import OpenAIEmbedder
|
||||
from transcript import extract_transcript
|
||||
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
|
||||
# The code below creates a file "transcript.txt" in the directory, the txt file will be used below
|
||||
youtube_url = "https://www.youtube.com/watch?v=Xs33-Gzl8Mo"
|
||||
segment_duration = 20
|
||||
transcript_text,dict_transcript = extract_transcript(youtube_url,segment_duration)
|
||||
```
|
||||
|
||||
=== "ollama_assistant.py"
|
||||
|
||||
```python
|
||||
from rich.prompt import Prompt
|
||||
from phi.assistant import Assistant
|
||||
from phi.knowledge.text import TextKnowledgeBase
|
||||
from phi.vectordb.lancedb import LanceDb
|
||||
from phi.llm.ollama import Ollama
|
||||
from phi.embedder.ollama import OllamaEmbedder
|
||||
from transcript import extract_transcript
|
||||
|
||||
# The code below creates a file "transcript.txt" in the directory, the txt file will be used below
|
||||
youtube_url = "https://www.youtube.com/watch?v=Xs33-Gzl8Mo"
|
||||
segment_duration = 20
|
||||
transcript_text,dict_transcript = extract_transcript(youtube_url,segment_duration)
|
||||
```
|
||||
|
||||
=== "transcript.py"
|
||||
|
||||
``` python
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
import re
|
||||
|
||||
def smodify(seconds):
|
||||
hours, remainder = divmod(seconds, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"
|
||||
|
||||
def extract_transcript(youtube_url,segment_duration):
|
||||
# Extract video ID from the URL
|
||||
video_id = re.search(r'(?<=v=)[\w-]+', youtube_url)
|
||||
if not video_id:
|
||||
video_id = re.search(r'(?<=be/)[\w-]+', youtube_url)
|
||||
if not video_id:
|
||||
return None
|
||||
|
||||
video_id = video_id.group(0)
|
||||
|
||||
# Attempt to fetch the transcript
|
||||
try:
|
||||
# Try to get the official transcript
|
||||
transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
|
||||
except Exception:
|
||||
# If no official transcript is found, try to get auto-generated transcript
|
||||
try:
|
||||
transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
|
||||
for transcript in transcript_list:
|
||||
transcript = transcript.translate('en').fetch()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Format the transcript into 120s chunks
|
||||
transcript_text,dict_transcript = format_transcript(transcript,segment_duration)
|
||||
# Open the file in write mode, which creates it if it doesn't exist
|
||||
with open("transcript.txt", "w",encoding="utf-8") as file:
|
||||
file.write(transcript_text)
|
||||
return transcript_text,dict_transcript
|
||||
|
||||
def format_transcript(transcript,segment_duration):
|
||||
chunked_transcript = []
|
||||
chunk_dict = []
|
||||
current_chunk = []
|
||||
current_time = 0
|
||||
# 2 minutes in seconds
|
||||
start_time_chunk = 0 # To track the start time of the current chunk
|
||||
|
||||
for segment in transcript:
|
||||
start_time = segment['start']
|
||||
end_time_x = start_time + segment['duration']
|
||||
text = segment['text']
|
||||
|
||||
# Add text to the current chunk
|
||||
current_chunk.append(text)
|
||||
|
||||
# Update the current time with the duration of the current segment
|
||||
# The duration of the current segment is given by segment['start'] - start_time_chunk
|
||||
if current_chunk:
|
||||
current_time = start_time - start_time_chunk
|
||||
|
||||
# If current chunk duration reaches or exceeds 2 minutes, save the chunk
|
||||
if current_time >= segment_duration:
|
||||
# Use the start time of the first segment in the current chunk as the timestamp
|
||||
chunked_transcript.append(f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}] " + " ".join(current_chunk))
|
||||
current_chunk = re.sub(r'[\xa0\n]', lambda x: '' if x.group() == '\xa0' else ' ', "\n".join(current_chunk))
|
||||
chunk_dict.append({"timestamp":f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}]", "text": "".join(current_chunk)})
|
||||
current_chunk = [] # Reset the chunk
|
||||
start_time_chunk = start_time + segment['duration'] # Update the start time for the next chunk
|
||||
current_time = 0 # Reset current time
|
||||
|
||||
# Add any remaining text in the last chunk
|
||||
if current_chunk:
|
||||
chunked_transcript.append(f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}] " + " ".join(current_chunk))
|
||||
current_chunk = re.sub(r'[\xa0\n]', lambda x: '' if x.group() == '\xa0' else ' ', "\n".join(current_chunk))
|
||||
chunk_dict.append({"timestamp":f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}]", "text": "".join(current_chunk)})
|
||||
|
||||
return "\n\n".join(chunked_transcript), chunk_dict
|
||||
```
|
||||
|
||||
!!! warning
|
||||
If creating Ollama assistant, download and install Ollama [from here](https://ollama.com/) and then run the Ollama instance in the background. Also, download the required models using `ollama pull <model-name>`. Check out the models [here](https://ollama.com/library)
|
||||
|
||||
|
||||
**Run the following command to deactivate the virtual environment if needed**
|
||||
```python
|
||||
deactivate
|
||||
```
|
||||
|
||||
## **Step 1** - Create a Knowledge Base for AI Assistant using LanceDB
|
||||
|
||||
=== "openai_assistant.py"
|
||||
|
||||
```python
|
||||
# Create knowledge Base with OpenAIEmbedder in LanceDB
|
||||
knowledge_base = TextKnowledgeBase(
|
||||
path="transcript.txt",
|
||||
vector_db=LanceDb(
|
||||
embedder=OpenAIEmbedder(api_key = openai.api_key),
|
||||
table_name="transcript_documents",
|
||||
uri="./t3mp/.lancedb",
|
||||
),
|
||||
num_documents = 10
|
||||
)
|
||||
```
|
||||
|
||||
=== "ollama_assistant.py"
|
||||
|
||||
```python
|
||||
# Create knowledge Base with OllamaEmbedder in LanceDB
|
||||
knowledge_base = TextKnowledgeBase(
|
||||
path="transcript.txt",
|
||||
vector_db=LanceDb(
|
||||
embedder=OllamaEmbedder(model="nomic-embed-text",dimensions=768),
|
||||
table_name="transcript_documents",
|
||||
uri="./t2mp/.lancedb",
|
||||
),
|
||||
num_documents = 10
|
||||
)
|
||||
```
|
||||
Check out the list of **embedders** supported by **phidata** and their usage [here](https://docs.phidata.com/embedder/introduction).
|
||||
|
||||
Here we have used `TextKnowledgeBase`, which loads text/docx files to the knowledge base.
|
||||
|
||||
Let's see all the parameters that `TextKnowledgeBase` takes -
|
||||
|
||||
| Name| Type | Purpose | Default |
|
||||
|:----|:-----|:--------|:--------|
|
||||
|`path`|`Union[str, Path]`| Path to text file(s). It can point to a single text file or a directory of text files.| provided by user |
|
||||
|`formats`|`List[str]`| File formats accepted by this knowledge base. |`[".txt"]`|
|
||||
|`vector_db`|`VectorDb`| Vector Database for the Knowledge Base. phidata provides a wrapper around many vector DBs, you can import it like this - `from phi.vectordb.lancedb import LanceDb` | provided by user |
|
||||
|`num_documents`|`int`| Number of results (documents/vectors) that vector search should return. |`5`|
|
||||
|`reader`|`TextReader`| phidata provides many types of reader objects which read data, clean it and create chunks of data, encapsulate each chunk inside an object of the `Document` class, and return **`List[Document]`**. | `TextReader()` |
|
||||
|`optimize_on`|`int`| It is used to specify the number of documents on which to optimize the vector database. Supposed to create an index. |`1000`|
|
||||
|
||||
??? Tip "Wonder! What is `Document` class?"
|
||||
We know that, before storing the data in vectorDB, we need to split the data into smaller chunks upon which embeddings will be created and these embeddings along with the chunks will be stored in vectorDB. When the user queries over the vectorDB, some of these embeddings will be returned as the result based on the semantic similarity with the query.
|
||||
|
||||
When the user queries over vectorDB, the queries are converted into embeddings, and a nearest neighbor search is performed over these query embeddings which returns the embeddings that correspond to most semantically similar chunks(parts of our data) present in vectorDB.
|
||||
|
||||
Here, a “Document” is a class in phidata. Since there is an option to let phidata create and manage embeddings, it splits our data into smaller chunks(as expected). It does not directly create embeddings on it. Instead, it takes each chunk and encapsulates it inside the object of the `Document` class along with various other metadata related to the chunk. Then embeddings are created on these `Document` objects and stored in vectorDB.
|
||||
|
||||
```python
|
||||
class Document(BaseModel):
|
||||
"""Model for managing a document"""
|
||||
|
||||
content: str # <--- here data of chunk is stored
|
||||
id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
meta_data: Dict[str, Any] = {}
|
||||
embedder: Optional[Embedder] = None
|
||||
embedding: Optional[List[float]] = None
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
```
|
||||
|
||||
However, using phidata you can load many other types of data in the knowledge base(other than text). Check out [phidata Knowledge Base](https://docs.phidata.com/knowledge/introduction) for more information.
|
||||
|
||||
Let's dig deeper into the `vector_db` parameter and see what parameters `LanceDb` takes -
|
||||
|
||||
| Name| Type | Purpose | Default |
|
||||
|:----|:-----|:--------|:--------|
|
||||
|`embedder`|`Embedder`| phidata provides many Embedders that abstract the interaction with embedding APIs and utilize it to generate embeddings. Check out other embedders [here](https://docs.phidata.com/embedder/introduction) | `OpenAIEmbedder` |
|
||||
|`distance`|`List[str]`| The choice of distance metric used to calculate the similarity between vectors, which directly impacts search results and performance in vector databases. |`Distance.cosine`|
|
||||
|`connection`|`lancedb.db.LanceTable`| LanceTable can be accessed through `.connection`. You can connect to an existing table of LanceDB, created outside of phidata, and utilize it. If not provided, it creates a new table using `table_name` parameter and adds it to `connection`. |`None`|
|
||||
|`uri`|`str`| It specifies the directory location of **LanceDB database** and establishes a connection that can be used to interact with the database. | `"/tmp/lancedb"` |
|
||||
|`table_name`|`str`| If `connection` is not provided, it initializes and connects to a new **LanceDB table** with a specified(or default) name in the database present at `uri`. |`"phi"`|
|
||||
|`nprobes`|`int`| It refers to the number of partitions that the search algorithm examines to find the nearest neighbors of a given query vector. Higher values will yield better recall (more likely to find vectors if they exist) at the expense of latency. |`20`|
|
||||
|
||||
|
||||
!!! note
|
||||
Since we just initialized the KnowledgeBase. The VectorDB table that corresponds to this Knowledge Base is not yet populated with our data. It will be populated in **Step 3**, once we perform the `load` operation.
|
||||
|
||||
You can check the state of the LanceDB table using - `knowledge_base.vector_db.connection.to_pandas()`
|
||||
|
||||
Now that the Knowledge Base is initialized, , we can go to **step 2**.
|
||||
|
||||
## **Step 2** - Create an assistant with our choice of LLM and reference to the knowledge base.
|
||||
|
||||
|
||||
=== "openai_assistant.py"
|
||||
|
||||
```python
|
||||
# define an assistant with gpt-4o-mini llm and reference to the knowledge base created above
|
||||
assistant = Assistant(
|
||||
llm=OpenAIChat(model="gpt-4o-mini", max_tokens=1000, temperature=0.3,api_key = openai.api_key),
|
||||
description="""You are an Expert in explaining youtube video transcripts. You are a bot that takes transcript of a video and answer the question based on it.
|
||||
|
||||
This is transcript for the above timestamp: {relevant_document}
|
||||
The user input is: {user_input}
|
||||
generate highlights only when asked.
|
||||
When asked to generate highlights from the video, understand the context for each timestamp and create key highlight points, answer in following way -
|
||||
[timestamp] - highlight 1
|
||||
[timestamp] - highlight 2
|
||||
... so on
|
||||
|
||||
Your task is to understand the user question, and provide an answer using the provided contexts. Your answers are correct, high-quality, and written by an domain expert. If the provided context does not contain the answer, simply state,'The provided context does not have the answer.'""",
|
||||
knowledge_base=knowledge_base,
|
||||
add_references_to_prompt=True,
|
||||
)
|
||||
```
|
||||
|
||||
=== "ollama_assistant.py"
|
||||
|
||||
```python
|
||||
# define an assistant with llama3.1 llm and reference to the knowledge base created above
|
||||
assistant = Assistant(
|
||||
llm=Ollama(model="llama3.1"),
|
||||
description="""You are an Expert in explaining youtube video transcripts. You are a bot that takes transcript of a video and answer the question based on it.
|
||||
|
||||
This is transcript for the above timestamp: {relevant_document}
|
||||
The user input is: {user_input}
|
||||
generate highlights only when asked.
|
||||
When asked to generate highlights from the video, understand the context for each timestamp and create key highlight points, answer in following way -
|
||||
[timestamp] - highlight 1
|
||||
[timestamp] - highlight 2
|
||||
... so on
|
||||
|
||||
Your task is to understand the user question, and provide an answer using the provided contexts. Your answers are correct, high-quality, and written by an domain expert. If the provided context does not contain the answer, simply state,'The provided context does not have the answer.'""",
|
||||
knowledge_base=knowledge_base,
|
||||
add_references_to_prompt=True,
|
||||
)
|
||||
```
|
||||
|
||||
Assistants add **memory**, **knowledge**, and **tools** to LLMs. Here we will add only **knowledge** in this example.
|
||||
|
||||
Whenever we will give a query to LLM, the assistant will retrieve relevant information from our **Knowledge Base**(table in LanceDB) and pass it to LLM along with the user query in a structured way.
|
||||
|
||||
- The `add_references_to_prompt=True` always adds information from the knowledge base to the prompt, regardless of whether it is relevant to the question.
|
||||
|
||||
To know more about an creating assistant in phidata, check out [phidata docs](https://docs.phidata.com/assistants/introduction) here.
|
||||
|
||||
## **Step 3** - Load data to Knowledge Base.
|
||||
|
||||
```python
|
||||
# load out data into the knowledge_base (populating the LanceTable)
|
||||
assistant.knowledge_base.load(recreate=False)
|
||||
```
|
||||
The above code loads the data to the Knowledge Base(LanceDB Table) and now it is ready to be used by the assistant.
|
||||
|
||||
| Name| Type | Purpose | Default |
|
||||
|:----|:-----|:--------|:--------|
|
||||
|`recreate`|`bool`| If True, it drops the existing table and recreates the table in the vectorDB. |`False`|
|
||||
|`upsert`|`bool`| If True and the vectorDB supports upsert, it will upsert documents to the vector db. | `False` |
|
||||
|`skip_existing`|`bool`| If True, skips documents that already exist in the vectorDB when inserting. |`True`|
|
||||
|
||||
??? tip "What is upsert?"
|
||||
Upsert is a database operation that combines "update" and "insert". It updates existing records if a document with the same identifier does exist, or inserts new records if no matching record exists. This is useful for maintaining the most current information without manually checking for existence.
|
||||
|
||||
During the Load operation, phidata directly interacts with the LanceDB library and performs the loading of the table with our data in the following steps -
|
||||
|
||||
1. **Creates** and **initializes** the table if it does not exist.
|
||||
|
||||
2. Then it **splits** our data into smaller **chunks**.
|
||||
|
||||
??? question "How do they create chunks?"
|
||||
**phidata** provides many types of **Knowledge Bases** based on the type of data. Most of them :material-information-outline:{ title="except LlamaIndexKnowledgeBase and LangChainKnowledgeBase"} has a property method called `document_lists` of type `Iterator[List[Document]]`. During the load operation, this property method is invoked. It traverses on the data provided by us (in this case, a text file(s)) using `reader`. Then it **reads**, **creates chunks**, and **encapsulates** each chunk inside a `Document` object and yields **lists of `Document` objects** that contain our data.
|
||||
|
||||
3. Then **embeddings** are created on these chunks are **inserted** into the LanceDB Table
|
||||
|
||||
??? question "How do they insert your data as different rows in LanceDB Table?"
|
||||
The chunks of your data are in the form - **lists of `Document` objects**. It was yielded in the step above.
|
||||
|
||||
for each `Document` in `List[Document]`, it does the following operations:
|
||||
|
||||
- Creates embedding on `Document`.
|
||||
- Cleans the **content attribute**(chunks of our data is here) of `Document`.
|
||||
- Prepares data by creating `id` and loading `payload` with the metadata related to this chunk. (1)
|
||||
{ .annotate }
|
||||
|
||||
1. Three columns will be added to the table - `"id"`, `"vector"`, and `"payload"` (payload contains various metadata including **`content`**)
|
||||
|
||||
- Then add this data to LanceTable.
|
||||
|
||||
4. Now the internal state of `knowledge_base` is changed (embeddings are created and loaded in the table ) and it **ready to be used by assistant**.
|
||||
|
||||
## **Step 4** - Start a cli chatbot with access to the Knowledge base
|
||||
|
||||
```python
|
||||
# start cli chatbot with knowledge base
|
||||
assistant.print_response("Ask me about something from the knowledge base")
|
||||
while True:
|
||||
message = Prompt.ask(f"[bold] :sunglasses: User [/bold]")
|
||||
if message in ("exit", "bye"):
|
||||
break
|
||||
assistant.print_response(message, markdown=True)
|
||||
```
|
||||
|
||||
|
||||
For more information and amazing cookbooks of phidata, read the [phidata documentation](https://docs.phidata.com/introduction) and also visit [LanceDB x phidata docmentation](https://docs.phidata.com/vectordb/lancedb).
|
||||
@@ -1,6 +1,9 @@
|
||||
# Linear Combination Reranker
|
||||
|
||||
This is the default re-ranker used by LanceDB hybrid search. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e, 70% weight for semantic search and 30% weight for full-text search.
|
||||
!!! note
|
||||
This is depricated. It is recommended to use the `RRFReranker` instead, if you want to use a score based reranker.
|
||||
|
||||
It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e, 70% weight for semantic search and 30% weight for full-text search.
|
||||
|
||||
!!! note
|
||||
Supported Query Types: Hybrid
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Reciprocal Rank Fusion Reranker
|
||||
|
||||
Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
|
||||
This is the default re-ranker used by LanceDB hybrid search. Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
|
||||
|
||||
|
||||
!!! note
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.10.0</version>
|
||||
<version>0.11.0-beta.1</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.10.0</version>
|
||||
<version>0.11.0-beta.1</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<name>LanceDB Parent</name>
|
||||
|
||||
4
node/package-lock.json
generated
4
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
|
||||
@@ -220,7 +220,8 @@ export async function connect(
|
||||
region: partOpts.region ?? defaultRegion,
|
||||
timeout: partOpts.timeout ?? defaultRequestTimeout,
|
||||
readConsistencyInterval: partOpts.readConsistencyInterval ?? undefined,
|
||||
storageOptions: partOpts.storageOptions ?? undefined
|
||||
storageOptions: partOpts.storageOptions ?? undefined,
|
||||
hostOverride: partOpts.hostOverride ?? undefined
|
||||
}
|
||||
if (opts.uri.startsWith("db://")) {
|
||||
// Remote connection
|
||||
|
||||
@@ -33,6 +33,7 @@ export class Query<T = number[]> {
|
||||
private _filter?: string
|
||||
private _metricType?: MetricType
|
||||
private _prefilter: boolean
|
||||
private _fastSearch: boolean
|
||||
protected readonly _embeddings?: EmbeddingFunction<T>
|
||||
|
||||
constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
||||
@@ -46,6 +47,7 @@ export class Query<T = number[]> {
|
||||
this._metricType = undefined
|
||||
this._embeddings = embeddings
|
||||
this._prefilter = false
|
||||
this._fastSearch = false
|
||||
}
|
||||
|
||||
/***
|
||||
@@ -110,6 +112,15 @@ export class Query<T = number[]> {
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Skip searching un-indexed data. This can make search faster, but will miss
|
||||
* any data that is not yet indexed.
|
||||
*/
|
||||
fastSearch (value: boolean): Query<T> {
|
||||
this._fastSearch = value
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the query and return the results as an Array of Objects
|
||||
*/
|
||||
|
||||
@@ -151,7 +151,8 @@ export class HttpLancedbClient {
|
||||
prefilter: boolean,
|
||||
refineFactor?: number,
|
||||
columns?: string[],
|
||||
filter?: string
|
||||
filter?: string,
|
||||
fastSearch?: boolean
|
||||
): Promise<ArrowTable<any>> {
|
||||
const result = await this.post(
|
||||
`/v1/table/${tableName}/query/`,
|
||||
@@ -162,7 +163,8 @@ export class HttpLancedbClient {
|
||||
refineFactor,
|
||||
columns,
|
||||
filter,
|
||||
prefilter
|
||||
prefilter,
|
||||
fast_search: fastSearch
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
|
||||
@@ -238,7 +238,8 @@ export class RemoteQuery<T = number[]> extends Query<T> {
|
||||
(this as any)._prefilter,
|
||||
(this as any)._refineFactor,
|
||||
(this as any)._select,
|
||||
(this as any)._filter
|
||||
(this as any)._filter,
|
||||
(this as any)._fastSearch
|
||||
)
|
||||
|
||||
return data.toArray().map((entry: Record<string, unknown>) => {
|
||||
|
||||
208
nodejs/native.d.ts
vendored
208
nodejs/native.d.ts
vendored
@@ -1,208 +0,0 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/* auto-generated by NAPI-RS */
|
||||
|
||||
/** A description of an index currently configured on a column */
|
||||
export interface IndexConfig {
|
||||
/** The name of the index */
|
||||
name: string
|
||||
/** The type of the index */
|
||||
indexType: string
|
||||
/**
|
||||
* The columns in the index
|
||||
*
|
||||
* Currently this is always an array of size 1. In the future there may
|
||||
* be more columns to represent composite indices.
|
||||
*/
|
||||
columns: Array<string>
|
||||
}
|
||||
/** Statistics about a compaction operation. */
|
||||
export interface CompactionStats {
|
||||
/** The number of fragments removed */
|
||||
fragmentsRemoved: number
|
||||
/** The number of new, compacted fragments added */
|
||||
fragmentsAdded: number
|
||||
/** The number of data files removed */
|
||||
filesRemoved: number
|
||||
/** The number of new, compacted data files added */
|
||||
filesAdded: number
|
||||
}
|
||||
/** Statistics about a cleanup operation */
|
||||
export interface RemovalStats {
|
||||
/** The number of bytes removed */
|
||||
bytesRemoved: number
|
||||
/** The number of old versions removed */
|
||||
oldVersionsRemoved: number
|
||||
}
|
||||
/** Statistics about an optimize operation */
|
||||
export interface OptimizeStats {
|
||||
/** Statistics about the compaction operation */
|
||||
compaction: CompactionStats
|
||||
/** Statistics about the removal operation */
|
||||
prune: RemovalStats
|
||||
}
|
||||
/**
|
||||
* A definition of a column alteration. The alteration changes the column at
|
||||
* `path` to have the new name `name`, to be nullable if `nullable` is true,
|
||||
* and to have the data type `data_type`. At least one of `rename` or `nullable`
|
||||
* must be provided.
|
||||
*/
|
||||
export interface ColumnAlteration {
|
||||
/**
|
||||
* The path to the column to alter. This is a dot-separated path to the column.
|
||||
* If it is a top-level column then it is just the name of the column. If it is
|
||||
* a nested column then it is the path to the column, e.g. "a.b.c" for a column
|
||||
* `c` nested inside a column `b` nested inside a column `a`.
|
||||
*/
|
||||
path: string
|
||||
/**
|
||||
* The new name of the column. If not provided then the name will not be changed.
|
||||
* This must be distinct from the names of all other columns in the table.
|
||||
*/
|
||||
rename?: string
|
||||
/** Set the new nullability. Note that a nullable column cannot be made non-nullable. */
|
||||
nullable?: boolean
|
||||
}
|
||||
/** A definition of a new column to add to a table. */
|
||||
export interface AddColumnsSql {
|
||||
/** The name of the new column. */
|
||||
name: string
|
||||
/**
|
||||
* The values to populate the new column with, as a SQL expression.
|
||||
* The expression can reference other columns in the table.
|
||||
*/
|
||||
valueSql: string
|
||||
}
|
||||
export interface IndexStatistics {
|
||||
/** The number of rows indexed by the index */
|
||||
numIndexedRows: number
|
||||
/** The number of rows not indexed */
|
||||
numUnindexedRows: number
|
||||
/** The type of the index */
|
||||
indexType?: string
|
||||
/** The metadata for each index */
|
||||
indices: Array<IndexMetadata>
|
||||
}
|
||||
export interface IndexMetadata {
|
||||
metricType?: string
|
||||
indexType?: string
|
||||
}
|
||||
export interface ConnectionOptions {
|
||||
/**
|
||||
* (For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||
* updates to the table from other processes. If None, then consistency is not
|
||||
* checked. For performance reasons, this is the default. For strong
|
||||
* consistency, set this to zero seconds. Then every read will check for
|
||||
* updates from other processes. As a compromise, you can set this to a
|
||||
* non-zero value for eventual consistency. If more than that interval
|
||||
* has passed since the last check, then the table will be checked for updates.
|
||||
* Note: this consistency only applies to read operations. Write operations are
|
||||
* always consistent.
|
||||
*/
|
||||
readConsistencyInterval?: number
|
||||
/**
|
||||
* (For LanceDB OSS only): configuration for object storage.
|
||||
*
|
||||
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
|
||||
*/
|
||||
storageOptions?: Record<string, string>
|
||||
}
|
||||
/** Write mode for writing a table. */
|
||||
export const enum WriteMode {
|
||||
Create = 'Create',
|
||||
Append = 'Append',
|
||||
Overwrite = 'Overwrite'
|
||||
}
|
||||
/** Write options when creating a Table. */
|
||||
export interface WriteOptions {
|
||||
/** Write mode for writing to a table. */
|
||||
mode?: WriteMode
|
||||
}
|
||||
export interface OpenTableOptions {
|
||||
storageOptions?: Record<string, string>
|
||||
}
|
||||
export class Connection {
|
||||
/** Create a new Connection instance from the given URI. */
|
||||
static new(uri: string, options: ConnectionOptions): Promise<Connection>
|
||||
display(): string
|
||||
isOpen(): boolean
|
||||
close(): void
|
||||
/** List all tables in the dataset. */
|
||||
tableNames(startAfter?: string | undefined | null, limit?: number | undefined | null): Promise<Array<string>>
|
||||
/**
|
||||
* Create table from a Apache Arrow IPC (file) buffer.
|
||||
*
|
||||
* Parameters:
|
||||
* - name: The name of the table.
|
||||
* - buf: The buffer containing the IPC file.
|
||||
*
|
||||
*/
|
||||
createTable(name: string, buf: Buffer, mode: string, storageOptions?: Record<string, string> | undefined | null, useLegacyFormat?: boolean | undefined | null): Promise<Table>
|
||||
createEmptyTable(name: string, schemaBuf: Buffer, mode: string, storageOptions?: Record<string, string> | undefined | null, useLegacyFormat?: boolean | undefined | null): Promise<Table>
|
||||
openTable(name: string, storageOptions?: Record<string, string> | undefined | null, indexCacheSize?: number | undefined | null): Promise<Table>
|
||||
/** Drop table with the name. Or raise an error if the table does not exist. */
|
||||
dropTable(name: string): Promise<void>
|
||||
}
|
||||
export class Index {
|
||||
static ivfPq(distanceType?: string | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): Index
|
||||
static btree(): Index
|
||||
}
|
||||
/** Typescript-style Async Iterator over RecordBatches */
|
||||
export class RecordBatchIterator {
|
||||
next(): Promise<Buffer | null>
|
||||
}
|
||||
/** A builder used to create and run a merge insert operation */
|
||||
export class NativeMergeInsertBuilder {
|
||||
whenMatchedUpdateAll(condition?: string | undefined | null): NativeMergeInsertBuilder
|
||||
whenNotMatchedInsertAll(): NativeMergeInsertBuilder
|
||||
whenNotMatchedBySourceDelete(filter?: string | undefined | null): NativeMergeInsertBuilder
|
||||
execute(buf: Buffer): Promise<void>
|
||||
}
|
||||
export class Query {
|
||||
onlyIf(predicate: string): void
|
||||
select(columns: Array<[string, string]>): void
|
||||
limit(limit: number): void
|
||||
nearestTo(vector: Float32Array): VectorQuery
|
||||
execute(maxBatchLength?: number | undefined | null): Promise<RecordBatchIterator>
|
||||
explainPlan(verbose: boolean): Promise<string>
|
||||
}
|
||||
export class VectorQuery {
|
||||
column(column: string): void
|
||||
distanceType(distanceType: string): void
|
||||
postfilter(): void
|
||||
refineFactor(refineFactor: number): void
|
||||
nprobes(nprobe: number): void
|
||||
bypassVectorIndex(): void
|
||||
onlyIf(predicate: string): void
|
||||
select(columns: Array<[string, string]>): void
|
||||
limit(limit: number): void
|
||||
execute(maxBatchLength?: number | undefined | null): Promise<RecordBatchIterator>
|
||||
explainPlan(verbose: boolean): Promise<string>
|
||||
}
|
||||
export class Table {
|
||||
name: string
|
||||
display(): string
|
||||
isOpen(): boolean
|
||||
close(): void
|
||||
/** Return Schema as empty Arrow IPC file. */
|
||||
schema(): Promise<Buffer>
|
||||
add(buf: Buffer, mode: string): Promise<void>
|
||||
countRows(filter?: string | undefined | null): Promise<number>
|
||||
delete(predicate: string): Promise<void>
|
||||
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
|
||||
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
|
||||
query(): Query
|
||||
vectorSearch(vector: Float32Array): VectorQuery
|
||||
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
|
||||
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
|
||||
dropColumns(columns: Array<string>): Promise<void>
|
||||
version(): Promise<number>
|
||||
checkout(version: number): Promise<void>
|
||||
checkoutLatest(): Promise<void>
|
||||
restore(): Promise<void>
|
||||
optimize(olderThanMs?: number | undefined | null): Promise<OptimizeStats>
|
||||
listIndices(): Promise<Array<IndexConfig>>
|
||||
indexStats(indexName: string): Promise<IndexStatistics | null>
|
||||
mergeInsert(on: Array<string>): NativeMergeInsertBuilder
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"vector database",
|
||||
"ann"
|
||||
],
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
@@ -66,8 +66,8 @@
|
||||
"os": ["darwin", "linux", "win32"],
|
||||
"scripts": {
|
||||
"artifacts": "napi artifacts",
|
||||
"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
|
||||
"build:release": "napi build --platform --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
|
||||
"build:debug": "napi build --platform --no-const-enum --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
|
||||
"build:release": "napi build --platform --no-const-enum --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
|
||||
"build": "npm run build:debug && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts && shx cp lancedb/*.node dist/",
|
||||
"build-release": "npm run build:release && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts",
|
||||
"lint-ci": "biome ci .",
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
@@ -34,7 +25,7 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
|
||||
__slots__ = ("__weakref__",) # pydantic 1.x compatibility
|
||||
max_retries: int = (
|
||||
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
|
||||
7 # Setting 0 disables retires. Maybe this should not be enabled by default,
|
||||
)
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@@ -46,22 +37,37 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
|
||||
"""Compute the embeddings for the source column in the database
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query with retries
|
||||
def compute_query_embeddings_with_retry(
|
||||
self, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""Compute the embeddings for a given user query with retries
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_query_embeddings, max_retries=self.max_retries
|
||||
@@ -70,9 +76,15 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database with retries
|
||||
def compute_source_embeddings_with_retry(
|
||||
self, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""Compute the embeddings for the source column in the database with retries.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_source_embeddings, max_retries=self.max_retries
|
||||
@@ -144,18 +156,20 @@ class TextEmbeddingFunction(EmbeddingFunction):
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
def compute_query_embeddings(
|
||||
self, query: str, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
def compute_source_embeddings(
|
||||
self, texts: TEXT, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""Generate the embeddings for the given texts"""
|
||||
pass
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
@@ -19,6 +10,7 @@ from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import ollama
|
||||
|
||||
|
||||
@register("ollama")
|
||||
@@ -39,17 +31,20 @@ class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return len(self.generate_embeddings(["foo"])[0])
|
||||
|
||||
def _compute_embedding(self, text):
|
||||
return self._ollama_client.embeddings(
|
||||
model=self.name,
|
||||
prompt=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embedding"]
|
||||
def _compute_embedding(self, text) -> Union["np.array", None]:
|
||||
return (
|
||||
self._ollama_client.embeddings(
|
||||
model=self.name,
|
||||
prompt=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embedding"]
|
||||
or None
|
||||
)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], "np.ndarray"]
|
||||
) -> List["np.array"]:
|
||||
) -> list[Union["np.array", None]]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
@@ -63,7 +58,7 @@ class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
return embeddings
|
||||
|
||||
@cached_property
|
||||
def _ollama_client(self):
|
||||
def _ollama_client(self) -> "ollama.Client":
|
||||
ollama = attempt_import_or_raise("ollama")
|
||||
# ToDo explore ollama.AsyncClient
|
||||
return ollama.Client(host=self.host, **self.ollama_client_kwargs)
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
import logging
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
@@ -89,17 +81,26 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
openai = attempt_import_or_raise("openai")
|
||||
|
||||
# TODO retry, rate limit, token limit
|
||||
if self.name == "text-embedding-ada-002":
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
else:
|
||||
kwargs = {
|
||||
"input": texts,
|
||||
"model": self.name,
|
||||
}
|
||||
if self.dim:
|
||||
kwargs["dimensions"] = self.dim
|
||||
rs = self._openai_client.embeddings.create(**kwargs)
|
||||
try:
|
||||
if self.name == "text-embedding-ada-002":
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
else:
|
||||
kwargs = {
|
||||
"input": texts,
|
||||
"model": self.name,
|
||||
}
|
||||
if self.dim:
|
||||
kwargs["dimensions"] = self.dim
|
||||
rs = self._openai_client.embeddings.create(**kwargs)
|
||||
except openai.BadRequestError:
|
||||
logging.exception("Bad request: %s", texts)
|
||||
return [None] * len(texts)
|
||||
except Exception:
|
||||
logging.exception("OpenAI embeddings error")
|
||||
raise
|
||||
return [v.embedding for v in rs.data]
|
||||
|
||||
@cached_property
|
||||
|
||||
@@ -36,6 +36,7 @@ from . import __version__
|
||||
from .arrow import AsyncRecordBatchReader
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -679,6 +680,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
if self._reranker is not None:
|
||||
rs_table = result_set.read_all()
|
||||
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
|
||||
check_reranker_result(result_set)
|
||||
# convert result_set back to RecordBatchReader
|
||||
result_set = pa.RecordBatchReader.from_batches(
|
||||
result_set.schema, result_set.to_batches()
|
||||
@@ -811,6 +813,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
results = results.read_all()
|
||||
if self._reranker is not None:
|
||||
results = self._reranker.rerank_fts(self._query, results)
|
||||
check_reranker_result(results)
|
||||
return results
|
||||
|
||||
def tantivy_to_arrow(self) -> pa.Table:
|
||||
@@ -953,8 +956,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str = None,
|
||||
vector_column: str = None,
|
||||
query: Optional[str] = None,
|
||||
vector_column: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
@@ -1060,10 +1063,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._fts_query._query, vector_results, fts_results
|
||||
)
|
||||
|
||||
if not isinstance(results, pa.Table): # Enforce type
|
||||
raise TypeError(
|
||||
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
|
||||
)
|
||||
check_reranker_result(results)
|
||||
|
||||
# apply limit after reranking
|
||||
results = results.slice(length=self._limit)
|
||||
@@ -1112,8 +1112,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
normalize="score",
|
||||
reranker: Reranker = RRFReranker(),
|
||||
normalize: str = "score",
|
||||
) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Rerank the hybrid search results using the specified reranker. The reranker
|
||||
@@ -1121,12 +1121,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker, default RRFReranker()
|
||||
The reranker to use. Must be an instance of Reranker class.
|
||||
normalize: str, default "score"
|
||||
The method to normalize the scores. Can be "rank" or "score". If "rank",
|
||||
the scores are converted to ranks and then normalized. If "score", the
|
||||
scores are normalized directly.
|
||||
reranker: Reranker, default RRFReranker()
|
||||
The reranker to use. Must be an instance of Reranker class.
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
|
||||
@@ -105,7 +105,7 @@ class Reranker(ABC):
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
) -> pa.Table:
|
||||
"""
|
||||
Rerank function receives the individual results from the vector and FTS search
|
||||
results. You can choose to use any of the results to generate the final results,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from numpy import NaN
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
@@ -58,14 +59,42 @@ class LinearCombinationReranker(Reranker):
|
||||
def merge_results(
|
||||
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
|
||||
):
|
||||
# If both are empty then just return an empty table
|
||||
if len(vector_results) == 0 and len(fts_results) == 0:
|
||||
return vector_results
|
||||
# If one is empty then return the other
|
||||
# If one is empty then return the other and add _relevance_score
|
||||
# column equal the existing vector or fts score
|
||||
if len(vector_results) == 0:
|
||||
return fts_results
|
||||
results = fts_results.append_column(
|
||||
"_relevance_score",
|
||||
pa.array(fts_results["_score"], type=pa.float32()),
|
||||
)
|
||||
if self.score == "relevance":
|
||||
results = self._keep_relevance_score(results)
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_distance",
|
||||
pa.array([NaN] * len(fts_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
if len(fts_results) == 0:
|
||||
return vector_results
|
||||
# invert the distance to relevance score
|
||||
results = vector_results.append_column(
|
||||
"_relevance_score",
|
||||
pa.array(
|
||||
[
|
||||
self._invert_score(distance)
|
||||
for distance in vector_results["_distance"].to_pylist()
|
||||
],
|
||||
type=pa.float32(),
|
||||
),
|
||||
)
|
||||
if self.score == "relevance":
|
||||
results = self._keep_relevance_score(results)
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_score",
|
||||
pa.array([NaN] * len(vector_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
# sort both input tables on _rowid
|
||||
combined_list = []
|
||||
|
||||
19
python/python/lancedb/rerankers/util.py
Normal file
19
python/python/lancedb/rerankers/util.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The Lance Authors
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def check_reranker_result(result):
|
||||
if not isinstance(result, pa.Table): # Enforce type
|
||||
raise TypeError(
|
||||
f"rerank_hybrid must return a pyarrow.Table, got {type(result)}"
|
||||
)
|
||||
|
||||
# Enforce that `_relevance_score` column is present in the result of every
|
||||
# rerank_hybrid method
|
||||
if "_relevance_score" not in result.column_names:
|
||||
raise ValueError(
|
||||
"rerank_hybrid must return a pyarrow.Table with a column"
|
||||
"named `_relevance_score`"
|
||||
)
|
||||
@@ -1630,7 +1630,9 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None and query_type != "fts":
|
||||
if (
|
||||
vector_column_name is None and query is not None and query_type != "fts"
|
||||
) or (vector_column_name is None and query_type == "hybrid"):
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
@@ -1998,22 +2000,26 @@ def _sanitize_vector_column(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
||||
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
|
||||
# kernel over f16 types.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
|
||||
if pa.types.is_float16(vec_arr.values.type):
|
||||
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
|
||||
# kernel over f16 types yet.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
else:
|
||||
if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -2057,8 +2063,15 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
def _sanitize_nans(
|
||||
data,
|
||||
fill_value,
|
||||
on_bad_vectors,
|
||||
vec_arr: pa.FixedSizeListArray,
|
||||
vector_column_name: str,
|
||||
):
|
||||
"""Sanitize NaNs in vectors"""
|
||||
assert pa.types.is_fixed_size_list(vec_arr.type)
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
@@ -2078,9 +2091,11 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False)
|
||||
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
|
||||
data = data.filter(is_full)
|
||||
# Drop is very slow to be able to filter out NaNs in a fixed size list array
|
||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
||||
not_nulls = np.any(np_arr, axis=1)
|
||||
data = data.filter(~not_nulls)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -86,6 +86,47 @@ def test_embedding_function(tmp_path):
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
def test_embedding_with_bad_results(tmp_path):
|
||||
@register("mock-embedding")
|
||||
class MockEmbeddingFunction(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 128
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> list[Union[np.array, None]]:
|
||||
return [
|
||||
None if i % 2 == 0 else np.random.randn(self.ndims())
|
||||
for i in range(len(texts))
|
||||
]
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
model = registry.get("mock-embedding").create()
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
table = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
table.add(
|
||||
[{"text": "hello world"}, {"text": "bar"}],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
df = table.to_pandas()
|
||||
assert len(table) == 1
|
||||
assert df.iloc[0]["text"] == "bar"
|
||||
|
||||
# table = db.create_table("test2", schema=Schema, mode="overwrite")
|
||||
# table.add(
|
||||
# [{"text": "hello world"}, {"text": "bar"}],
|
||||
# )
|
||||
# assert len(table) == 2
|
||||
# tbl = table.to_arrow()
|
||||
# assert tbl["vector"].null_count == 1
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
|
||||
@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
err = (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Vector search setting
|
||||
result = (
|
||||
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
result_explicit = (
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
result_deduped = reranker.rerank_multivector(
|
||||
[rs1, rs2, rs1], query, deduplicate=True
|
||||
)
|
||||
assert len(result_deduped) < 20
|
||||
assert len(result_deduped) <= 20
|
||||
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.rerank(reranker, normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").text(
|
||||
query
|
||||
).to_arrow()
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Test with empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
@@ -973,7 +973,36 @@ def test_hybrid_search(db, tmp_path):
|
||||
.where("text='Arrrrggghhhhhhh'")
|
||||
.to_list()
|
||||
)
|
||||
len(result) == 1
|
||||
assert len(result) == 1
|
||||
|
||||
# with explicit query type
|
||||
vector_query = list(range(emb.ndims()))
|
||||
result = (
|
||||
table.search(query_type="hybrid")
|
||||
.vector(vector_query)
|
||||
.text("Arrrrggghhhhhhh")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
|
||||
# with vector_column_name
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(vector_query)
|
||||
.text("Arrrrggghhhhhhh")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
|
||||
# fail if only text or vector is provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").to_list()
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").vector(vector_query).to_list()
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-node"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0-beta.1"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
@@ -46,6 +46,10 @@ impl JsQuery {
|
||||
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
|
||||
.value(&mut cx);
|
||||
|
||||
let fast_search = query_obj
|
||||
.get_opt::<JsBoolean, _, _>(&mut cx, "_fastSearch")?
|
||||
.map(|val| val.value(&mut cx));
|
||||
|
||||
let is_electron = cx
|
||||
.argument::<JsBoolean>(1)
|
||||
.or_throw(&mut cx)?
|
||||
@@ -70,6 +74,9 @@ impl JsQuery {
|
||||
if let Some(limit) = limit {
|
||||
builder = builder.limit(limit as usize);
|
||||
};
|
||||
if let Some(true) = fast_search {
|
||||
builder = builder.fast_search();
|
||||
}
|
||||
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0-beta.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -402,6 +402,9 @@ pub trait QueryBase {
|
||||
///
|
||||
/// By default, it is false.
|
||||
fn fast_search(self) -> Self;
|
||||
|
||||
/// Return the `_rowid` meta column from the Table.
|
||||
fn with_row_id(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait HasQuery {
|
||||
@@ -438,6 +441,11 @@ impl<T: HasQuery> QueryBase for T {
|
||||
self.mut_query().fast_search = true;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_row_id(mut self) -> Self {
|
||||
self.mut_query().with_row_id = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for controlling the execution of a query
|
||||
@@ -548,6 +556,11 @@ pub struct Query {
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) fast_search: bool,
|
||||
|
||||
/// If set to true, the query will return the `_rowid` meta column.
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) with_row_id: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -560,6 +573,7 @@ impl Query {
|
||||
full_text_search: None,
|
||||
select: Select::All,
|
||||
fast_search: false,
|
||||
with_row_id: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1160,4 +1174,24 @@ mod tests {
|
||||
.unwrap();
|
||||
assert!(!plan.contains("Take"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_row_id() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_test_table(&tmp_dir).await;
|
||||
let results = table
|
||||
.vector_search(&[0.1, 0.2, 0.3, 0.4])
|
||||
.unwrap()
|
||||
.with_row_id()
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
for batch in results {
|
||||
assert!(batch.column_by_name("_rowid").is_some());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1938,6 +1938,10 @@ impl TableInternal for NativeTable {
|
||||
Select::All => {}
|
||||
}
|
||||
|
||||
if query.base.with_row_id {
|
||||
scanner.with_row_id();
|
||||
}
|
||||
|
||||
if let Some(opts) = options {
|
||||
scanner.batch_size(opts.max_batch_length as usize);
|
||||
}
|
||||
@@ -3230,7 +3234,7 @@ mod tests {
|
||||
let values_builder = StringBuilder::new();
|
||||
let mut builder = ListBuilder::new(values_builder);
|
||||
for i in 0..120 {
|
||||
builder.values().append_value(TAGS[i % 3].to_string());
|
||||
builder.values().append_value(TAGS[i % 3]);
|
||||
if i % 3 == 0 {
|
||||
builder.append(true)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user