Merge branch 'main' of https://github.com/lancedb/lancedb into yang/relative-lance-dep

This commit is contained in:
BubbleCal
2024-09-27 17:00:31 +08:00
103 changed files with 4693 additions and 710 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.10.0-beta.1"
current_version = "0.11.0-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -24,34 +24,56 @@ commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# Java maven files
pre_commit_hooks = [
"""
NEW_VERSION="${BVHOOK_NEW_MAJOR}.${BVHOOK_NEW_MINOR}.${BVHOOK_NEW_PATCH}"
if [ ! -z "$BVHOOK_NEW_PRE_L" ] && [ ! -z "$BVHOOK_NEW_PRE_N" ]; then
NEW_VERSION="${NEW_VERSION}-${BVHOOK_NEW_PRE_L}.${BVHOOK_NEW_PRE_N}"
fi
echo "Constructed new version: $NEW_VERSION"
cd java && mvn versions:set -DnewVersion=$NEW_VERSION && mvn versions:commit
# Check for any modified but unstaged pom.xml files
MODIFIED_POMS=$(git ls-files -m | grep pom.xml)
if [ ! -z "$MODIFIED_POMS" ]; then
echo "The following pom.xml files were modified but not staged. Adding them now:"
echo "$MODIFIED_POMS" | while read -r file; do
git add "$file"
echo "Added: $file"
done
fi
""",
]
[tool.bumpversion.parts.pre_l]
values = ["beta", "final"]
optional_value = "final"
values = ["beta", "final"]
[[tool.bumpversion.files]]
filename = "node/package.json"
search = "\"version\": \"{current_version}\","
replace = "\"version\": \"{new_version}\","
search = "\"version\": \"{current_version}\","
[[tool.bumpversion.files]]
filename = "nodejs/package.json"
search = "\"version\": \"{current_version}\","
replace = "\"version\": \"{new_version}\","
search = "\"version\": \"{current_version}\","
# nodejs binary packages
[[tool.bumpversion.files]]
glob = "nodejs/npm/*/package.json"
search = "\"version\": \"{current_version}\","
replace = "\"version\": \"{new_version}\","
search = "\"version\": \"{current_version}\","
# Cargo files
# ------------
[[tool.bumpversion.files]]
filename = "rust/ffi/node/Cargo.toml"
search = "\nversion = \"{current_version}\""
replace = "\nversion = \"{new_version}\""
search = "\nversion = \"{current_version}\""
[[tool.bumpversion.files]]
filename = "rust/lancedb/Cargo.toml"
search = "\nversion = \"{current_version}\""
replace = "\nversion = \"{new_version}\""
search = "\nversion = \"{current_version}\""

View File

@@ -94,11 +94,16 @@ jobs:
mkdir -p ./core/target/classes/nativelib/darwin-aarch64 ./core/target/classes/nativelib/linux-aarch64
cp ../liblancedb_jni_darwin_aarch64.zip/liblancedb_jni.dylib ./core/target/classes/nativelib/darwin-aarch64/liblancedb_jni.dylib
cp ../liblancedb_jni_linux_aarch64.zip/liblancedb_jni.so ./core/target/classes/nativelib/linux-aarch64/liblancedb_jni.so
- name: Dry run
if: github.event_name == 'pull_request'
run: |
mvn --batch-mode -DskipTests package
- name: Set github
run: |
git config --global user.email "dev+gha@lancedb.com"
git config --global user.name "LanceDB Github Runner"
- name: Publish with Java 8
if: github.event_name == 'release'
run: |
echo "use-agent" >> ~/.gnupg/gpg.conf
echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf

View File

@@ -30,7 +30,7 @@ on:
default: true
type: boolean
other:
description: 'Make a Node/Rust release'
description: 'Make a Node/Rust/Java release'
required: true
default: true
type: boolean

View File

@@ -30,7 +30,6 @@ jobs:
defaults:
run:
shell: bash
working-directory: rust
env:
# Need up-to-date compilers for kernels
CC: gcc-12
@@ -50,7 +49,7 @@ jobs:
- name: Run format
run: cargo fmt --all -- --check
- name: Run clippy
run: cargo clippy --all --all-features -- -D warnings
run: cargo clippy --workspace --tests --all-features -- -D warnings
linux:
timeout-minutes: 30
# To build all features, we need more disk space than is available

View File

@@ -20,12 +20,12 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.17.0", "features" = ["dynamodb"], path = "../lance/rust/lance"}
lance-index = { "version" = "=0.17.0", path = "../lance/rust/lance-index"}
lance-linalg = { "version" = "=0.17.0", path = "../lance/rust/lance-linalg"}
lance-testing = { "version" = "=0.17.0", path = "../lance/rust/lance-testing"}
lance-datafusion = { "version" = "=0.17.0", path = "../lance/rust/lance-datafusion"}
lance-encoding = { "version" = "=0.17.0", path = "../lance/rust/lance-encoding"}
lance = { "version" = "=0.18.0", "features" = ["dynamodb"], path = "../lance/rust/lance"}
lance-index = { "version" = "=0.18.0", path = "../lance/rust/lance-index"}
lance-linalg = { "version" = "=0.18.0", path = "../lance/rust/lance-linalg"}
lance-testing = { "version" = "=0.18.0", path = "../lance/rust/lance-testing"}
lance-datafusion = { "version" = "=0.18.0", path = "../lance/rust/lance-datafusion"}
lance-encoding = { "version" = "=0.18.0", path = "../lance/rust/lance-encoding"}
# Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false }
arrow-array = "52.2"

View File

@@ -34,6 +34,7 @@ theme:
- navigation.footer
- navigation.tracking
- navigation.instant
- content.footnote.tooltips
icon:
repo: fontawesome/brands/github
annotation: material/arrow-right-circle
@@ -65,6 +66,11 @@ plugins:
markdown_extensions:
- admonition
- footnotes
- pymdownx.critic
- pymdownx.caret
- pymdownx.keys
- pymdownx.mark
- pymdownx.tilde
- pymdownx.details
- pymdownx.highlight:
anchor_linenums: true
@@ -106,6 +112,17 @@ nav:
- Overview: hybrid_search/hybrid_search.md
- Comparing Rerankers: hybrid_search/eval.md
- Airbnb financial data example: notebooks/hybrid_search.ipynb
- RAG:
- Vanilla RAG: rag/vanilla_rag.md
- Multi-head RAG: rag/multi_head_rag.md
- Corrective RAG: rag/corrective_rag.md
- Agentic RAG: rag/agentic_rag.md
- Graph RAG: rag/graph_rag.md
- Self RAG: rag/self_rag.md
- Adaptive RAG: rag/adaptive_rag.md
- Advanced Techniques:
- HyDE: rag/advanced_techniques/hyde.md
- FLARE: rag/advanced_techniques/flare.md
- Reranking:
- Quickstart: reranking/index.md
- Cohere Reranker: reranking/cohere.md
@@ -127,7 +144,8 @@ nav:
- Reranking: guides/tuning_retrievers/2_reranking.md
- Embedding fine-tuning: guides/tuning_retrievers/3_embed_tuning.md
- 🧬 Managing embeddings:
- Overview: embeddings/index.md
- Understand Embeddings: embeddings/understanding_embeddings.md
- Get Started: embeddings/index.md
- Embedding functions: embeddings/embedding_functions.md
- Available models:
- Overview: embeddings/default_embedding_functions.md
@@ -165,6 +183,7 @@ nav:
- Voxel51: integrations/voxel51.md
- PromptTools: integrations/prompttools.md
- dlt: integrations/dlt.md
- phidata: integrations/phidata.md
- 🎯 Examples:
- Overview: examples/index.md
- 🐍 Python:
@@ -220,6 +239,17 @@ nav:
- Overview: hybrid_search/hybrid_search.md
- Comparing Rerankers: hybrid_search/eval.md
- Airbnb financial data example: notebooks/hybrid_search.ipynb
- RAG:
- Vanilla RAG: rag/vanilla_rag.md
- Multi-head RAG: rag/multi_head_rag.md
- Corrective RAG: rag/corrective_rag.md
- Agentic RAG: rag/agentic_rag.md
- Graph RAG: rag/graph_rag.md
- Self RAG: rag/self_rag.md
- Adaptive RAG: rag/adaptive_rag.md
- Advanced Techniques:
- HyDE: rag/advanced_techniques/hyde.md
- FLARE: rag/advanced_techniques/flare.md
- Reranking:
- Quickstart: reranking/index.md
- Cohere Reranker: reranking/cohere.md
@@ -241,7 +271,8 @@ nav:
- Reranking: guides/tuning_retrievers/2_reranking.md
- Embedding fine-tuning: guides/tuning_retrievers/3_embed_tuning.md
- Managing Embeddings:
- Overview: embeddings/index.md
- Understand Embeddings: embeddings/understanding_embeddings.md
- Get Started: embeddings/index.md
- Embedding functions: embeddings/embedding_functions.md
- Available models:
- Overview: embeddings/default_embedding_functions.md
@@ -275,6 +306,7 @@ nav:
- Voxel51: integrations/voxel51.md
- PromptTools: integrations/prompttools.md
- dlt: integrations/dlt.md
- phidata: integrations/phidata.md
- Examples:
- examples/index.md
- 🐍 Python:

View File

@@ -0,0 +1,133 @@
# Understand Embeddings
The term **dimension** is a synonym for the number of elements in a feature vector. Each feature can be thought of as a different axis in a geometric space.
High-dimensional data means there are many features (or attributes) in the data.
!!! example
1. An image is a data point, and it might have thousands of dimensions because each pixel could be considered a feature.
2. Text data, when represented by each word or character, can also lead to high dimensions, especially when considering all possible words in a language.
Embeddings capture **meaning and relationships** within data by mapping high-dimensional data into a lower-dimensional space, placing inputs that are more **similar in meaning** closer together in the **embedding space**.
## What are Vector Embeddings?
Vector embeddings are a way to convert complex data, like text, images, or audio, into numerical coordinates (called vectors) that can be plotted in an n-dimensional space (the embedding space).
The closer these data points are related in the real world, the closer their corresponding numerical coordinates (vectors) will be to each other in the embedding space. This proximity in the embedding space reflects their semantic similarities, allowing machines to intuitively understand and process the data in a way that mirrors human perception of relationships and meaning.
In effect, an embedding captures the most important aspects of the data while ignoring the less important ones. As a result, tasks like searching for related content or identifying patterns become more efficient and accurate: the embeddings make it possible to quantify how **closely related** different **data points** are and to **reduce** the **computational complexity**.
??? question "Are vectors and embeddings the same thing?"
When we say “vectors” we mean a **list of numbers** that **represents the data**.
When we say “embeddings” we mean a **list of numbers** that **captures important details and relationships**.
Although the terms are often used interchangeably, “embeddings” highlight how the data is represented with meaning and structure, while “vector” simply refers to the numerical form of that representation.
## Embedding vs Indexing
We already saw that creating **embeddings** for data is a method of producing **vectors** in an **n-dimensional embedding space** that capture the meaning and relationships inherent in the data.
Once we have these **vectors**, indexing comes into play. Indexing is a method of organizing the vector embeddings so that we can quickly and efficiently locate and retrieve them from the entire dataset of vector embeddings.
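As a rough sketch of how the two steps look in LanceDB (a minimal example assuming a local database directory `./.lancedb` and toy 2-dimensional vectors; the index parameters are illustrative):

```python
import lancedb

# connect to a local database (the directory is created if missing)
db = lancedb.connect("./.lancedb")

# embedding: store precomputed vectors alongside the original data
table = db.create_table(
    "demo",
    data=[
        {"text": "a photo of a cat", "vector": [0.9, 0.1]},
        {"text": "a picture of a kitten", "vector": [0.85, 0.15]},
        {"text": "stock market report", "vector": [0.05, 0.95]},
    ],
)

# indexing: organize the vectors for fast approximate nearest-neighbor lookup
# (commented out here because index training needs far more vectors than this toy table)
# table.create_index(metric="cosine", num_partitions=2, num_sub_vectors=2)

# retrieval: nearest neighbors of a query vector
print(table.search([0.88, 0.12]).limit(2).to_list())
```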
## What types of data/objects can be embedded?
The following are common types of data that can be embedded:
1. **Text**: Text data includes sentences, paragraphs, documents, or any written content.
2. **Images**: Image data encompasses photographs, illustrations, or any visual content.
3. **Audio**: Audio data includes sounds, music, speech, or any auditory content.
4. **Video**: Video data consists of moving images and sound, which can convey complex information.
Large datasets of multi-modal data (text, audio, images, etc.) can be converted into embeddings with the appropriate model.
!!! tip "LanceDB vs Other traditional Vector DBs"
While many vector databases primarily focus on the storage and retrieval of vector embeddings, **LanceDB** uses the **Lance file format**, which operates on a disk-based architecture and allows for the storage and management of not just embeddings but also **raw file data (bytes)**. This means users can keep various types of data, including images and text, alongside their vector embeddings in a unified system.
With the ability to store both vectors and associated file data, LanceDB enhances the querying process. Users can perform semantic searches that not only retrieve similar embeddings but also access related files and metadata, thus streamlining the workflow.
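A sketch of that unified storage (toy vector and inline bytes; the field names here are illustrative):

```python
import lancedb

db = lancedb.connect("./.lancedb")

# store raw file bytes and metadata next to the embedding in one row
table = db.create_table(
    "images",
    data=[
        {
            "vector": [0.1, 0.2],          # the embedding
            "image_bytes": b"\x89PNG...",  # raw file data (truncated placeholder)
            "caption": "a cat",            # metadata
        }
    ],
)

# a semantic search returns the bytes and metadata together with the match
print(table.search([0.1, 0.2]).limit(1).to_list())
```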
## How do embeddings work?
As mentioned, after embedding, each data point is represented as a vector in an n-dimensional space (the embedding space). The dimensionality of this space can vary depending on the complexity of the data and the specific embedding technique used.
Points that are close to each other in vector space are considered similar (or appear in similar contexts), and points that are far away are considered dissimilar. To quantify this closeness, we use a distance metric, which can be computed in the following ways (see the sketch after this list):
1. **Euclidean Distance (L2)**: It calculates the straight-line distance between two points (vectors) in a multidimensional space.
2. **Cosine Similarity**: It measures the cosine of the angle between two vectors, providing a normalized measure of similarity based on their direction.
3. **Dot product**: It is calculated as the sum of the products of their corresponding components. To measure relatedness it considers both the magnitude and direction of the vectors.
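A minimal NumPy sketch of the three measures on two toy vectors:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])

euclidean = np.linalg.norm(a - b)  # straight-line (L2) distance: ~3.742
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))  # direction only: 1.0 (parallel)
dot = a @ b  # magnitude and direction: 28.0

print(euclidean, cosine, dot)
```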
## How do you create and store vector embeddings for your data?
1. **Creating embeddings**: Choose an embedding model: it can be a pre-trained model (open-source or commercial), or you can train a custom embedding model for your scenario. Then feed your preprocessed data into the chosen model to obtain embeddings.
??? question "Popular choices for embedding models"
For text data, popular choices are OpenAI's text-embedding models, Google Gemini text-embedding models, Cohere's Embed models, SentenceTransformers, etc.
For image data, popular choices are CLIP (Contrastive Language-Image Pretraining), ImageBind by Meta (supports audio, video, and image), Jina multi-modal embeddings, etc.
2. **Storing vector embeddings**: This effectively requires **specialized databases** that can handle the complexity of vector data, as traditional databases often struggle with this task. Vector databases are designed specifically for storing and querying vector embeddings. They optimize for efficient nearest-neighbor searches and provide built-in indexing mechanisms.
!!! tip "Why LanceDB"
LanceDB **automates** the entire process of creating and storing embeddings for your data. LanceDB allows you to define and use **embedding functions**, which can be **pre-trained models** or **custom models**.
This enables you to **generate** embeddings tailored to the nature of your data (e.g., text, images) and **store** both the **original data** and **embeddings** in a **structured schema** thus providing efficient querying capabilities for similarity searches.
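As a minimal sketch of that workflow (assuming the `sentence-transformers` package and the model below are installed; the table and column names are illustrative):

```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

db = lancedb.connect("/tmp/db")
model = get_registry().get("sentence-transformers").create(
    name="BAAI/bge-small-en-v1.5", device="cpu"
)

class Words(LanceModel):
    text: str = model.SourceField()  # raw data to embed
    vector: Vector(model.ndims()) = model.VectorField()  # filled in automatically

table = db.create_table("words", schema=Words)
table.add([{"text": "hello world"}, {"text": "goodbye world"}])

# the query text is embedded with the same function before the search
print(table.search("greeting").limit(1).to_list())
```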
Let's quickly [get started](./index.md) and learn how to manage embeddings in LanceDB.
## Bonus: As a developer, what can you create using embeddings?
As a developer, you can create a variety of innovative applications using vector embeddings. Check out the following:
<div class="grid cards" markdown>
- __Chatbots__
---
Develop chatbots that utilize embeddings to retrieve relevant context and generate coherent, contextually aware responses to user queries.
[:octicons-arrow-right-24: Check out examples](../examples/python_examples/chatbot.md)
- __Recommendation Systems__
---
Develop systems that recommend content (such as articles, movies, or products) based on the similarity of keywords and descriptions, enhancing user experience.
[:octicons-arrow-right-24: Check out examples](../examples/python_examples/recommendersystem.md)
- __Vector Search__
---
Build powerful applications that harness the full potential of semantic search, enabling them to retrieve relevant data quickly and effectively.
[:octicons-arrow-right-24: Check out examples](../examples/python_examples/vector_search.md)
- __RAG Applications__
---
Combine the strengths of large language models (LLMs) with retrieval-based approaches to create more useful applications.
[:octicons-arrow-right-24: Check out examples](../examples/python_examples/rag.md)
- __Many more examples__
---
Explore applied examples available as Colab notebooks or Python scripts to integrate into your applications.
[:octicons-arrow-right-24: More](../examples/examples_python.md)
</div>

View File

@@ -2,7 +2,7 @@
LanceDB provides support for full-text search via Lance (before via [Tantivy](https://github.com/quickwit-oss/tantivy) (Python only)), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions.
Currently, the Lance full text search is missing some features that are in the Tantivy full text search. This includes phrase queries, re-ranking, and customizing the tokenizer. Thus, in Python, Tantivy is still the default way to do full text search and many of the instructions below apply just to Tantivy-based indices.
Currently, the Lance full text search is missing some features that are in the Tantivy full text search. This includes the query parser and customizing the tokenizer. Thus, in Python, Tantivy is still the default way to do full text search and many of the instructions below apply just to Tantivy-based indices.
## Installation (Only for Tantivy-based FTS)
@@ -62,7 +62,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `tex
});
await tbl
.search("puppy")
.search("puppy", queryType="fts")
.select(["text"])
.limit(10)
.toArray();
@@ -205,7 +205,7 @@ table.create_fts_index(["text_field"], use_tantivy=True, ordering_field_names=["
## Phrase queries vs. terms queries
!!! warning "Warn"
Phrase queries are available for only Tantivy-based FTS
Lance-based FTS doesn't support queries using boolean operators `OR`, `AND`.
For full-text search you can specify either a **phrase** query like `"the old man and the sea"`,
or a **terms** search query like `"(Old AND Man) AND Sea"`. For more details on the terms

View File

@@ -0,0 +1,383 @@
**phidata** is a framework for building **AI Assistants** with long-term memory, contextual knowledge, and the ability to take actions using function calling. It helps turn general-purpose LLMs into specialized assistants tailored to your use case by extending their capabilities with **memory**, **knowledge**, and **tools**.
- **Memory**: Stores chat history in a **database** and enables LLMs to have long-term conversations.
- **Knowledge**: Stores information in a **vector database** and provides LLMs with business context. (Here we will use LanceDB)
- **Tools**: Enable LLMs to take actions like pulling data from an **API**, **sending emails** or **querying a database**, etc.
![example](https://raw.githubusercontent.com/lancedb/assets/refs/heads/main/docs/assets/integration/phidata_assistant.png)
Memory & knowledge make LLMs smarter while tools make them autonomous.
LanceDB is a vector database, and its integration into phidata makes it easy to provide a **knowledge base** to LLMs. It enables us to store information as [embeddings](../embeddings/understanding_embeddings.md) and search for **results** similar to our **query**.
??? Question "What is a Knowledge Base?"
Knowledge Base is a database of information that the Assistant can search to improve its responses. This information is stored in a vector database and provides LLMs with business context, which makes them respond in a context-aware manner.
While any type of storage can act as a knowledge base, vector databases offer the best solution for retrieving relevant results from dense information quickly.
Let's see how using LanceDB inside phidata helps make LLMs more useful:
## Prerequisites: install and import necessary dependencies
**Create a virtual environment**
1. Install the `virtualenv` package:
```bash
pip install virtualenv
```
2. Create a directory for your project, go into it, and create a virtual environment inside:
```bash
mkdir phi
```
```bash
cd phi
```
```bash
python -m venv phidata_
```
**Activating the virtual environment**
1. From inside the project directory, run the following command to activate the virtual environment:
```bash
# on Windows; on macOS/Linux use: source phidata_/bin/activate
phidata_/Scripts/activate
```
**Install the following packages in the virtual environment**
```bash
pip install lancedb phidata youtube_transcript_api openai ollama pandas numpy
```
**Create python files and import necessary libraries**
You need to create two files: `transcript.py`, plus either `ollama_assistant.py` or `openai_assistant.py`.
=== "openai_assistant.py"
```python
import os, openai
from rich.prompt import Prompt
from phi.assistant import Assistant
from phi.knowledge.text import TextKnowledgeBase
from phi.vectordb.lancedb import LanceDb
from phi.llm.openai import OpenAIChat
from phi.embedder.openai import OpenAIEmbedder
from transcript import extract_transcript
if "OPENAI_API_KEY" not in os.environ:
# OR set the key here as a variable
openai.api_key = "sk-..."
# The code below creates a file "transcript.txt" in the directory, the txt file will be used below
youtube_url = "https://www.youtube.com/watch?v=Xs33-Gzl8Mo"
segment_duration = 20
transcript_text,dict_transcript = extract_transcript(youtube_url,segment_duration)
```
=== "ollama_assistant.py"
```python
from rich.prompt import Prompt
from phi.assistant import Assistant
from phi.knowledge.text import TextKnowledgeBase
from phi.vectordb.lancedb import LanceDb
from phi.llm.ollama import Ollama
from phi.embedder.ollama import OllamaEmbedder
from transcript import extract_transcript
# The code below creates a file "transcript.txt" in the directory, the txt file will be used below
youtube_url = "https://www.youtube.com/watch?v=Xs33-Gzl8Mo"
segment_duration = 20
transcript_text,dict_transcript = extract_transcript(youtube_url,segment_duration)
```
=== "transcript.py"
``` python
from youtube_transcript_api import YouTubeTranscriptApi
import re

def smodify(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"

def extract_transcript(youtube_url, segment_duration):
    # Extract video ID from the URL
    video_id = re.search(r'(?<=v=)[\w-]+', youtube_url)
    if not video_id:
        video_id = re.search(r'(?<=be/)[\w-]+', youtube_url)
    if not video_id:
        return None
    video_id = video_id.group(0)
    # Attempt to fetch the transcript
    try:
        # Try to get the official transcript
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
    except Exception:
        # If no official transcript is found, try to get an auto-generated transcript
        try:
            transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
            for transcript in transcript_list:
                transcript = transcript.translate('en').fetch()
        except Exception:
            return None
    # Format the transcript into segment_duration-second chunks
    transcript_text, dict_transcript = format_transcript(transcript, segment_duration)
    # Open the file in write mode, which creates it if it doesn't exist
    with open("transcript.txt", "w", encoding="utf-8") as file:
        file.write(transcript_text)
    return transcript_text, dict_transcript

def format_transcript(transcript, segment_duration):
    chunked_transcript = []
    chunk_dict = []
    current_chunk = []
    current_time = 0
    start_time_chunk = 0  # To track the start time of the current chunk
    for segment in transcript:
        start_time = segment['start']
        end_time_x = start_time + segment['duration']
        text = segment['text']
        # Add text to the current chunk
        current_chunk.append(text)
        # Elapsed time in the current chunk, measured from the chunk's start
        if current_chunk:
            current_time = start_time - start_time_chunk
        # If the current chunk duration reaches or exceeds segment_duration, save the chunk
        if current_time >= segment_duration:
            # Use the start time of the first segment in the current chunk as the timestamp
            chunked_transcript.append(f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}] " + " ".join(current_chunk))
            current_chunk = re.sub(r'[\xa0\n]', lambda x: '' if x.group() == '\xa0' else ' ', "\n".join(current_chunk))
            chunk_dict.append({"timestamp": f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}]", "text": "".join(current_chunk)})
            current_chunk = []  # Reset the chunk
            start_time_chunk = start_time + segment['duration']  # Update the start time for the next chunk
            current_time = 0  # Reset current time
    # Add any remaining text in the last chunk
    if current_chunk:
        chunked_transcript.append(f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}] " + " ".join(current_chunk))
        current_chunk = re.sub(r'[\xa0\n]', lambda x: '' if x.group() == '\xa0' else ' ', "\n".join(current_chunk))
        chunk_dict.append({"timestamp": f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}]", "text": "".join(current_chunk)})
    return "\n\n".join(chunked_transcript), chunk_dict
```
!!! warning
If you are creating the Ollama assistant, download and install Ollama [from here](https://ollama.com/), then run the Ollama instance in the background. Also download the required models using `ollama pull <model-name>`. Check out the available models [here](https://ollama.com/library).
**Run the following command to deactivate the virtual environment if needed**
```bash
deactivate
```
## **Step 1** - Create a knowledge base for the AI assistant using LanceDB
=== "openai_assistant.py"
```python
# Create the knowledge base with OpenAIEmbedder in LanceDB
knowledge_base = TextKnowledgeBase(
    path="transcript.txt",
    vector_db=LanceDb(
        embedder=OpenAIEmbedder(api_key=openai.api_key),
        table_name="transcript_documents",
        uri="./t3mp/.lancedb",
    ),
    num_documents=10,  # number of search results returned by the knowledge base
)
```
=== "ollama_assistant.py"
```python
# Create the knowledge base with OllamaEmbedder in LanceDB
knowledge_base = TextKnowledgeBase(
    path="transcript.txt",
    vector_db=LanceDb(
        embedder=OllamaEmbedder(model="nomic-embed-text", dimensions=768),
        table_name="transcript_documents",
        uri="./t2mp/.lancedb",
    ),
    num_documents=10,  # number of search results returned by the knowledge base
)
```
Check out the list of **embedders** supported by **phidata** and their usage [here](https://docs.phidata.com/embedder/introduction).
Here we have used `TextKnowledgeBase`, which loads text/docx files to the knowledge base.
Let's see all the parameters that `TextKnowledgeBase` takes -
| Name| Type | Purpose | Default |
|:----|:-----|:--------|:--------|
|`path`|`Union[str, Path]`| Path to text file(s). It can point to a single text file or a directory of text files.| provided by user |
|`formats`|`List[str]`| File formats accepted by this knowledge base. |`[".txt"]`|
|`vector_db`|`VectorDb`| Vector database for the knowledge base. phidata provides a wrapper around many vector DBs; you can import it like this: `from phi.vectordb.lancedb import LanceDb` | provided by user |
|`num_documents`|`int`| Number of results (documents/vectors) that a vector search should return. |`5`|
|`reader`|`TextReader`| phidata provides many types of reader objects which read data, clean it, create chunks, encapsulate each chunk inside an object of the `Document` class, and return **`List[Document]`**. | `TextReader()` |
|`optimize_on`|`int`| Specifies the number of documents after which the vector database should be optimized (e.g., by creating an index). |`1000`|
??? Tip "What is the `Document` class?"
Before storing data in a vector DB, we need to split it into smaller chunks; embeddings are created on these chunks, and the embeddings along with the chunks are stored in the vector DB.
When a user queries the vector DB, the query is converted into an embedding, and a nearest-neighbor search over the stored embeddings returns those that correspond to the most semantically similar chunks (parts of our data).
Here, a “Document” is a class in phidata. Since phidata can create and manage embeddings for you, it splits our data into smaller chunks (as expected). It does not directly create embeddings on them. Instead, it takes each chunk and encapsulates it inside an object of the `Document` class, along with various other metadata related to the chunk. Embeddings are then created on these `Document` objects and stored in the vector DB.
```python
class Document(BaseModel):
    """Model for managing a document"""

    content: str  # <--- here the data of the chunk is stored
    id: Optional[str] = None
    name: Optional[str] = None
    meta_data: Dict[str, Any] = {}
    embedder: Optional[Embedder] = None
    embedding: Optional[List[float]] = None
    usage: Optional[Dict[str, Any]] = None
```
However, using phidata you can load many other types of data (other than text) into the knowledge base. Check out [phidata Knowledge Base](https://docs.phidata.com/knowledge/introduction) for more information.
Let's dig deeper into the `vector_db` parameter and see what parameters `LanceDb` takes -
| Name| Type | Purpose | Default |
|:----|:-----|:--------|:--------|
|`embedder`|`Embedder`| phidata provides many Embedders that abstract the interaction with embedding APIs and utilize it to generate embeddings. Check out other embedders [here](https://docs.phidata.com/embedder/introduction) | `OpenAIEmbedder` |
|`distance`|`Distance`| The choice of distance metric used to calculate the similarity between vectors, which directly impacts search results and performance in vector databases. |`Distance.cosine`|
|`connection`|`lancedb.db.LanceTable`| The LanceTable can be accessed through `.connection`. You can connect to an existing LanceDB table, created outside of phidata, and utilize it. If not provided, a new table is created using the `table_name` parameter and added to `connection`. |`None`|
|`uri`|`str`| It specifies the directory location of **LanceDB database** and establishes a connection that can be used to interact with the database. | `"/tmp/lancedb"` |
|`table_name`|`str`| If `connection` is not provided, it initializes and connects to a new **LanceDB table** with a specified(or default) name in the database present at `uri`. |`"phi"`|
|`nprobes`|`int`| It refers to the number of partitions that the search algorithm examines to find the nearest neighbors of a given query vector. Higher values will yield better recall (more likely to find vectors if they exist) at the expense of latency. |`20`|
!!! note
Since we have just initialized the knowledge base, the vector DB table that corresponds to it is not yet populated with our data. It will be populated in **Step 3**, once we perform the `load` operation.
You can check the state of the LanceDB table using `knowledge_base.vector_db.connection.to_pandas()`.
Now that the knowledge base is initialized, we can go to **Step 2**.
## **Step 2** - Create an assistant with our choice of LLM and a reference to the knowledge base.
=== "openai_assistant.py"
```python
# define an assistant with the gpt-4o-mini llm and a reference to the knowledge base created above
assistant = Assistant(
    llm=OpenAIChat(model="gpt-4o-mini", max_tokens=1000, temperature=0.3, api_key=openai.api_key),
    description="""You are an expert in explaining YouTube video transcripts. You are a bot that takes the transcript of a video and answers questions based on it.
    This is the transcript for the above timestamp: {relevant_document}
    The user input is: {user_input}
    Generate highlights only when asked.
    When asked to generate highlights from the video, understand the context for each timestamp and create key highlight points, answering in the following way -
    [timestamp] - highlight 1
    [timestamp] - highlight 2
    ... and so on
    Your task is to understand the user question and provide an answer using the provided contexts. Your answers are correct, high-quality, and written by a domain expert. If the provided context does not contain the answer, simply state, 'The provided context does not have the answer.'""",
    knowledge_base=knowledge_base,
    add_references_to_prompt=True,
)
```
=== "ollama_assistant.py"
```python
# define an assistant with the llama3.1 llm and a reference to the knowledge base created above
assistant = Assistant(
    llm=Ollama(model="llama3.1"),
    description="""You are an expert in explaining YouTube video transcripts. You are a bot that takes the transcript of a video and answers questions based on it.
    This is the transcript for the above timestamp: {relevant_document}
    The user input is: {user_input}
    Generate highlights only when asked.
    When asked to generate highlights from the video, understand the context for each timestamp and create key highlight points, answering in the following way -
    [timestamp] - highlight 1
    [timestamp] - highlight 2
    ... and so on
    Your task is to understand the user question and provide an answer using the provided contexts. Your answers are correct, high-quality, and written by a domain expert. If the provided context does not contain the answer, simply state, 'The provided context does not have the answer.'""",
    knowledge_base=knowledge_base,
    add_references_to_prompt=True,
)
```
Assistants add **memory**, **knowledge**, and **tools** to LLMs. In this example, we will add only **knowledge**.
Whenever we give a query to the LLM, the assistant retrieves relevant information from our **knowledge base** (the table in LanceDB) and passes it to the LLM along with the user query in a structured way.
- Setting `add_references_to_prompt=True` always adds information from the knowledge base to the prompt, regardless of whether it is relevant to the question.
To learn more about creating an assistant in phidata, check out the [phidata docs](https://docs.phidata.com/assistants/introduction).
## **Step 3** - Load data into the knowledge base.
```python
# load our data into the knowledge_base (populating the LanceTable)
assistant.knowledge_base.load(recreate=False)
```
The above code loads the data into the knowledge base (the LanceDB table); now it is ready to be used by the assistant. The `load` operation accepts the following arguments:
| Name| Type | Purpose | Default |
|:----|:-----|:--------|:--------|
|`recreate`|`bool`| If True, it drops the existing table and recreates the table in the vectorDB. |`False`|
|`upsert`|`bool`| If True and the vectorDB supports upsert, it will upsert documents to the vector db. | `False` |
|`skip_existing`|`bool`| If True, skips documents that already exist in the vectorDB when inserting. |`True`|
??? tip "What is upsert?"
Upsert is a database operation that combines "update" and "insert". It updates existing records if a document with the same identifier exists, or inserts new records if no matching record exists. This is useful for maintaining the most current information without manually checking for existence.
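For instance, a sketch of the same `load` call with all three flags spelled out:

```python
# keep the existing table, upsert where the vector DB supports it, skip known chunks
assistant.knowledge_base.load(recreate=False, upsert=True, skip_existing=True)
```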
During the load operation, phidata interacts directly with the LanceDB library and loads the table with our data in the following steps -
1. **Creates** and **initializes** the table if it does not exist.
2. Then it **splits** our data into smaller **chunks**.
??? question "How do they create chunks?"
**phidata** provides many types of **Knowledge Bases** based on the type of data. Most of them :material-information-outline:{ title="except LlamaIndexKnowledgeBase and LangChainKnowledgeBase"} have a property method called `document_lists` of type `Iterator[List[Document]]`. During the load operation, this property method is invoked. It traverses the data provided by us (in this case, text file(s)) using the `reader`. It then **reads** the data, **creates chunks**, **encapsulates** each chunk inside a `Document` object, and yields **lists of `Document` objects** containing our data.
3. Then **embeddings** are created on these chunks, and they are **inserted** into the LanceDB table.
??? question "How do they insert your data as different rows in LanceDB Table?"
The chunks of your data are in the form of **lists of `Document` objects**, yielded in the step above.
For each `Document` in `List[Document]`, it performs the following operations:
- Creates an embedding for the `Document`.
- Cleans the **content attribute** of the `Document` (this is where the chunk of our data is stored).
- Prepares data by creating `id` and loading `payload` with the metadata related to this chunk. (1)
{ .annotate }
1. Three columns will be added to the table - `"id"`, `"vector"`, and `"payload"` (payload contains various metadata including **`content`**)
- Then adds this data to the LanceTable.
4. Now the internal state of `knowledge_base` has changed (embeddings are created and loaded in the table), and it is **ready to be used by the assistant**.
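To verify this, you can inspect the populated table through the wrapper's connection (a sketch using the `to_pandas()` call mentioned earlier):

```python
# peek at the LanceDB table behind the knowledge base
df = knowledge_base.vector_db.connection.to_pandas()
print(df.columns.tolist())  # expected columns: id, vector, payload
print(len(df))              # number of chunks inserted
```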
## **Step 4** - Start a CLI chatbot with access to the knowledge base
```python
# start a cli chatbot with access to the knowledge base
assistant.print_response("Ask me about something from the knowledge base")
while True:
    message = Prompt.ask(f"[bold] :sunglasses: User [/bold]")
    if message in ("exit", "bye"):
        break
    assistant.print_response(message, markdown=True)
```
For more information and phidata cookbooks, read the [phidata documentation](https://docs.phidata.com/introduction) and the [LanceDB x phidata documentation](https://docs.phidata.com/vectordb/lancedb).

View File

@@ -68,3 +68,25 @@ currently is also a memory intensive operation.
#### Returns
[`Index`](Index.md)
### fts()
> `static` **fts**(`options`?): [`Index`](Index.md)
Create a full text search index
This index is used to search for text data. The index is created by tokenizing the text
into words and then storing occurrences of these words in a data structure called an inverted
index, which allows for fast search.
During a search, the query is tokenized and the inverted index is used to find the rows that
contain the query words. The rows are then scored with BM25, and the top-scoring rows are
sorted and returned.
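As a toy sketch of the idea (plain Python, not the library's internals):

```python
from collections import defaultdict

docs = ["frodo takes the ring", "the ring is heavy", "sam follows frodo"]

# build: token -> set of row ids that contain it
inverted = defaultdict(set)
for row_id, doc in enumerate(docs):
    for token in doc.split():
        inverted[token].add(row_id)

# search: tokenize the query and collect candidate rows
query = "frodo ring"
candidates = set().union(*(inverted[t] for t in query.split() if t in inverted))
print(sorted(candidates))  # rows 0, 1, and 2 each match a term; BM25 would rank them
```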
#### Parameters
**options?**: `Partial`&lt;[`FtsOptions`](../interfaces/FtsOptions.md)&gt;
#### Returns
[`Index`](Index.md)

View File

@@ -501,16 +501,28 @@ Get the schema of the table.
#### search(query)
> `abstract` **search**(`query`): [`VectorQuery`](VectorQuery.md)
> `abstract` **search**(`query`, `queryType`, `ftsColumns`): [`VectorQuery`](VectorQuery.md)
Create a search query to find the nearest neighbors
of the given query vector
of the given query vector, or the documents
with the highest relevance to the query string.
##### Parameters
• **query**: `string`
the query. This will be converted to a vector using the table's provided embedding function
the query. This will be converted to a vector using the table's provided embedding function,
or the query string for full-text search if `queryType` is "fts".
• **queryType**: `string` = `"auto"` \| `"fts"`
the type of query to run. If "auto", the query type will be determined based on the query.
• **ftsColumns**: `string[] | string` = undefined
the columns to search in. If not provided, all indexed columns will be searched.
For now, only a single column can be searched.
##### Returns

View File

@@ -37,6 +37,7 @@
- [IndexOptions](interfaces/IndexOptions.md)
- [IndexStatistics](interfaces/IndexStatistics.md)
- [IvfPqOptions](interfaces/IvfPqOptions.md)
- [FtsOptions](interfaces/FtsOptions.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [UpdateOptions](interfaces/UpdateOptions.md)
- [WriteOptions](interfaces/WriteOptions.md)

View File

@@ -0,0 +1,51 @@
**Adaptive RAG 🤹‍♂️**
====================================================================
Adaptive RAG introduces a RAG technique that combines query analysis with self-corrective RAG.
For query analysis, it uses a small classifier (an LLM) to decide the query's complexity. Query analysis helps route smoothly between different retrieval strategies: no retrieval, single-shot RAG, or iterative RAG.
**[Official Paper](https://arxiv.org/pdf/2403.14403)**
<figure markdown="span">
![agent-based-rag](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/adaptive_rag.png)
<figcaption>Adaptive-RAG: <a href="https://github.com/starsuzi/Adaptive-RAG">Source</a>
</figcaption>
</figure>
**[Official Implementation](https://github.com/starsuzi/Adaptive-RAG)**
Here's a code snippet for query analysis:
```python
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )

# LLM with function call
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)
```
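A hedged usage sketch: pair the router with a prompt and invoke it on a question (the system prompt below is illustrative):

```python
system = (
    "You are an expert at routing a user question to a vectorstore or web search. "
    "Use the vectorstore for questions about the indexed documents; otherwise use web search."
)
route_prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{question}")]
)
question_router = route_prompt | structured_llm_router

# returns a RouteQuery, e.g. RouteQuery(datasource="vectorstore")
print(question_router.invoke({"question": "How does adaptive RAG work?"}))
```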
For defining and querying the retriever:
```python
from langchain_community.vectorstores import LanceDB
from langchain_openai import OpenAIEmbeddings

# add documents to LanceDB (doc_splits: pre-chunked documents, prepared as in the agentic RAG example)
vectorstore = LanceDB.from_documents(
    documents=doc_splits,
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

# query using the defined retriever
question = "How does adaptive RAG work?"
docs = retriever.get_relevant_documents(question)
```

View File

@@ -0,0 +1,38 @@
**FLARE 💥**
====================================================================
FLARE (Forward-Looking Active REtrieval augmented generation) is a generic retrieval-augmented generation method that actively decides when and what to retrieve: it predicts the upcoming sentence to anticipate future content and, if that sentence contains low-confidence tokens, uses it as a query to retrieve relevant documents.
**[Official Paper](https://arxiv.org/abs/2305.06983)**
<figure markdown="span">
![flare](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/flare.gif)
<figcaption>FLARE: <a href="https://github.com/jzbjyb/FLARE">Source</a></figcaption>
</figure>
[![Open In Colab](../../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/examples/better-rag-FLAIR/main.ipynb)
Here's a code snippet for using FLARE with Langchain:
```python
from langchain.vectorstores import LanceDB
from langchain.document_loaders import ArxivLoader
from langchain.chains import FlareChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI

llm = OpenAI()

# load dataset (doc_chunks, embeddings, table, and input_text are assumed to be prepared beforehand)

# LanceDB retriever
vector_store = LanceDB.from_documents(doc_chunks, embeddings, connection=table)
retriever = vector_store.as_retriever()

# define flare chain
flare = FlareChain.from_llm(
    llm=llm, retriever=retriever, max_generation_len=300, min_prob=0.45
)
result = flare.run(input_text)
```
[![Open In Colab](../../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/examples/better-rag-FLAIR/main.ipynb)

View File

@@ -0,0 +1,55 @@
**HyDE: Hypothetical Document Embeddings 🤹‍♂️**
====================================================================
HyDE (Hypothetical Document Embeddings) is an approach for precise zero-shot dense retrieval without relevance labels. It focuses on augmenting and improving similarity searches, often intertwined with vector stores in information retrieval. The method generates a hypothetical document for an incoming query, which is then embedded and used to look up real documents that are similar to the hypothetical one.
**[Official Paper](https://arxiv.org/pdf/2212.10496)**
<figure markdown="span">
![hyde](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/hyde.png)
<figcaption>HyDE: <a href="https://arxiv.org/pdf/2212.10496">Source</a></figcaption>
</figure>
[![Open In Colab](../../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/examples/Advance-RAG-with-HyDE/main.ipynb)
Here's a code snippet for using HyDE with Langchain:
```python
from langchain.llms import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, HypotheticalDocumentEmbedder
from langchain.vectorstores import LanceDB

# set OPENAI_API_KEY as an env variable before this step

# initialize the LLM and the base embedding function
llm = OpenAI()
base_embeddings = OpenAIEmbeddings()

# prompt template
prompt_template = """
As a knowledgeable and helpful research assistant, your task is to provide informative answers based on the given context. Use your extensive knowledge base to offer clear, concise, and accurate responses to the user's inquiries.
If the question is not related to the documents, simply say you don't know.
Question: {question}
Answer:
"""
prompt = PromptTemplate(input_variables=["question"], template=prompt_template)

# LLM Chain that drafts the hypothetical answer
llm_chain = LLMChain(llm=llm, prompt=prompt)

# HyDE embedding: hypothetical answers from llm_chain are embedded with base_embeddings
embeddings = HypotheticalDocumentEmbedder(llm_chain=llm_chain, base_embeddings=base_embeddings)

# load dataset, then build the LanceDB vector store (documents and table are assumed to be prepared)
vector_store = LanceDB.from_documents(documents, embeddings, connection=table)

# vector search with the hypothetical-document embedding, then answer with the chain
query = "..."  # your question here
vector_store.similarity_search(query)
llm_chain.run(query)
```
[![Open In Colab](../../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/examples/Advance-RAG-with-HyDE/main.ipynb)

docs/src/rag/agentic_rag.md (new file, 101 lines)
View File

@@ -0,0 +1,101 @@
**Agentic RAG 🤖**
====================================================================
Agentic RAG (agent-based RAG) introduces an advanced framework for answering questions by using intelligent agents instead of relying on large language models alone. These agents act like expert researchers, handling complex tasks such as detailed planning, multi-step reasoning, and using external tools. They navigate multiple documents, compare information, and generate accurate answers. The system scales easily, with each new document set managed by a sub-agent, making it a powerful tool for tackling a wide range of information needs.
<figure markdown="span">
![agent-based-rag](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/agentic_rag.png)
<figcaption>Agent-based RAG</figcaption>
</figure>
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/tutorials/Agentic_RAG/main.ipynb)
Here's a code snippet for defining a retriever using Langchain:
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import LanceDB
from langchain_openai import OpenAIEmbeddings
urls = [
"https://content.dgft.gov.in/Website/CIEP.pdf",
"https://content.dgft.gov.in/Website/GAE.pdf",
"https://content.dgft.gov.in/Website/HTE.pdf",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)
# add documents in LanceDB
vectorstore = LanceDB.from_documents(
documents=doc_splits,
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
```
Below are agent functions that grade the retrieved documents and, where needed, formulate an improved query for better retrieval results:
```python
from typing import Literal

from langchain_core.messages import HumanMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

def grade_documents(state) -> Literal["generate", "rewrite"]:
    class grade(BaseModel):
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    llm_with_tool = model.with_structured_output(grade)

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )
    chain = prompt | llm_with_tool

    messages = state["messages"]
    last_message = messages[-1]
    question = messages[0].content
    docs = last_message.content

    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score
    return "generate" if score == "yes" else "rewrite"

def agent(state):
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    model = model.bind_tools(tools)  # tools: retriever tools defined elsewhere in the notebook
    response = model.invoke(messages)
    return {"messages": [response]}

def rewrite(state):
    messages = state["messages"]
    question = messages[0].content
    msg = [
        HumanMessage(
            content=f""" \n
    Look at the input and try to reason about the underlying semantic intent / meaning. \n
    Here is the initial question:
    \n ------- \n
    {question}
    \n ------- \n
    Formulate an improved question: """,
        )
    ]
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    response = model.invoke(msg)
    return {"messages": [response]}
```
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/tutorials/Agentic_RAG/main.ipynb)

View File

@@ -0,0 +1,120 @@
**Corrective RAG ✅**
====================================================================
Corrective-RAG (CRAG) is a strategy for Retrieval-Augmented Generation (RAG) that includes self-reflection and self-grading of retrieved documents. Here's a simplified breakdown of the steps involved:
1. **Relevance Check**: If at least one document meets the relevance threshold, the process moves forward to the generation phase.
2. **Knowledge Refinement**: Before generating an answer, the process refines the knowledge by dividing the document into smaller segments called "knowledge strips."
3. **Grading and Filtering**: Each "knowledge strip" is graded, and irrelevant ones are filtered out.
4. **Additional Data Source**: If all documents are below the relevance threshold, or if the system is unsure about their relevance, it will seek additional information by performing a web search to supplement the retrieved data.
The steps above are described in the **[Official Paper](https://arxiv.org/abs/2401.15884)**.
<figure markdown="span">
![agent-based-rag](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/crag_paper.png)
<figcaption>Corrective RAG: <a href="https://github.com/HuskyInSalt/CRAG">Source</a>
</figcaption>
</figure>
Corrective Retrieval-Augmented Generation (CRAG) is a method that works like a **built-in fact-checker**.
**[Official Implementation](https://github.com/HuskyInSalt/CRAG)**
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/tutorials/Corrective-RAG-with_Langgraph/CRAG_with_Langgraph.ipynb)
Here's a code snippet that defines a table with the [Embedding API](https://lancedb.github.io/lancedb/embeddings/embedding_functions/) and retrieves the relevant documents:
```python
import pandas as pd
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

db = lancedb.connect("/tmp/db")
model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")

class Docs(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()

table = db.create_table("docs", schema=Docs)

# considering chunks are in list format
df = pd.DataFrame({'text': chunks})
table.add(data=df)

# as per the document fed in
query = "How do Transformers work?"
actual = table.search(query).limit(1).to_list()[0]
print(actual["text"])
```
Code snippet for grading retrieved documents, filtering out irrelevant ones, and performing a web search if necessary:
```python
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_openai import ChatOpenAI

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    class grade(BaseModel):
        """Binary score for relevance check"""

        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # grading using openai
    grade_tool_oai = convert_to_openai_tool(grade)
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )
    parser_tool = PydanticToolsParser(tools=[grade])

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )
    chain = prompt | llm_with_tool | parser_tool

    filtered_docs = []
    search = "No"
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        grade_result = score[0].binary_score
        if grade_result == "yes":
            filtered_docs.append(d)
        else:
            search = "Yes"

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }
```
Check the Colab for the implementation of CRAG with Langgraph:
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/tutorials/Corrective-RAG-with_Langgraph/CRAG_with_Langgraph.ipynb)

docs/src/rag/graph_rag.md (new file, 54 lines)
View File

@@ -0,0 +1,54 @@
**Graph RAG 📊**
====================================================================
Graph RAG uses knowledge graphs together with large language models (LLMs) to improve how information is retrieved and generated. It overcomes the limits of traditional search methods by using knowledge graphs, which organize data as connected entities and relationships.
One of the main benefits of Graph RAG is its ability to capture and represent complex relationships between entities, something that traditional text-based retrieval systems struggle with. By using this structured knowledge, LLMs can better grasp the context and details of a query, resulting in more accurate and insightful answers.
**[Official Paper](https://arxiv.org/pdf/2404.16130)**
**[Official Implementation](https://github.com/microsoft/graphrag)**
[Microsoft Research Blog](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/)
!!! note "Default VectorDB"
Graph RAG uses LanceDB as the default vector database for performing vector search to retrieve relevant entities.
Working with Graph RAG is quite straightforward:
- **Installation and API KEY as env variable**
Set `OPENAI_API_KEY` as `GRAPHRAG_API_KEY`
```bash
pip install graphrag
export GRAPHRAG_API_KEY="sk-..."
```
- **Initial structure for indexing dataset**
```bash
python3 -m graphrag.index --init --root dataset-dir
```
- **Index Dataset**
```bash
python3 -m graphrag.index --root dataset-dir
```
- **Execute Query**
Global query execution gives a broad overview of the dataset
```bash
python3 -m graphrag.query --root dataset-dir --method global "query-question"
```
Local query execution gives detailed, specific answers based on the context of the entities
```bash
python3 -m graphrag.query --root dataset-dir --method local "query-question"
```
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/examples/Graphrag/main.ipynb)

View File

@@ -0,0 +1,49 @@
**Multi-Head RAG 📃**
====================================================================
Multi-head RAG (MRAG) is designed to handle queries that need multiple documents with diverse content. These queries are tough because the documents' embeddings can be far apart, making retrieval difficult. MRAG simplifies this by using the activations from a Transformer's multi-head attention layer, rather than the decoder layer, to fetch these varied documents. Different attention heads capture different aspects of the data, so using these activations helps create embeddings that better represent various data facets and improves retrieval accuracy for complex queries.
**[Official Paper](https://arxiv.org/pdf/2406.05085)**
<figure markdown="span">
![agent-based-rag](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/mrag-paper.png)
<figcaption>Multi-Head RAG: <a href="https://github.com/spcl/MRAG">Source</a>
</figcaption>
</figure>
MRAG is cost-effective and energy-efficient because it avoids extra LLM queries, multiple model instances, increased storage, and additional inference passes.
**[Official Implementation](https://github.com/spcl/MRAG)**
Here's a code snippet for defining different embedding spaces with the [Embedding API](https://lancedb.github.io/lancedb/embeddings/embedding_functions/):
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

# model definitions using the LanceDB Embedding API
model1 = get_registry().get("openai").create()
model2 = get_registry().get("ollama").create(name="llama3")
model3 = get_registry().get("ollama").create(name="mistral")

# define schemas for creating embedding spaces with the Embedding API
class Space1(LanceModel):
    text: str = model1.SourceField()
    vector: Vector(model1.ndims()) = model1.VectorField()

class Space2(LanceModel):
    text: str = model2.SourceField()
    vector: Vector(model2.ndims()) = model2.VectorField()

class Space3(LanceModel):
    text: str = model3.SourceField()
    vector: Vector(model3.ndims()) = model3.VectorField()
```
Create different tables using the defined embedding spaces, then query each embedding space. Use the closest documents returned from each space to generate answers.
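A minimal sketch of that flow, assuming `chunks` is a list of text strings (the table names and query are illustrative):
```python
import lancedb
db = lancedb.connect("/tmp/mrag")
# one table per embedding space, using the schemas defined above
tables = {
    "space1": db.create_table("space1", schema=Space1),
    "space2": db.create_table("space2", schema=Space2),
    "space3": db.create_table("space3", schema=Space3),
}
for table in tables.values():
    # vectors are computed automatically by the Embedding API
    table.add([{"text": chunk} for chunk in chunks])
query = "How does multi-head attention work?"
# retrieve the closest documents from every embedding space, then pool them
retrieved = {
    name: table.search(query).limit(3).to_list()
    for name, table in tables.items()
}
```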

96
docs/src/rag/self_rag.md Normal file
View File

@@ -0,0 +1,96 @@
**Self RAG 🤳**
====================================================================
Self-RAG is a Retrieval-Augmented Generation (RAG) strategy that improves retrieved information and generated text, and has the model check its own work, all without losing flexibility. Unlike the traditional RAG method, Self-RAG retrieves information only as needed, can skip retrieval when it isn't necessary, and evaluates its own output while generating text. It also uses a process to pick the best output based on different preferences.
**[Official Paper](https://arxiv.org/pdf/2310.11511)**
<figure markdown="span">
![self-rag](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/self_rag.png)
<figcaption>Self RAG: <a href="https://github.com/AkariAsai/self-rag">Source</a>
</figcaption>
</figure>
**[Official Implementation](https://github.com/AkariAsai/self-rag)**
Self-RAG starts by generating a response without retrieving extra information if none is needed. For questions that need more detail, it first retrieves the necessary information.
Here's a code snippet for defining a retriever using LangChain:
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import LanceDB
from langchain_openai import OpenAIEmbeddings
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)
# add documents in LanceDB
vectorstore = LanceDB.from_documents(
documents=doc_splits,
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
```
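For a quick sanity check, the retriever can be invoked directly (a minimal sketch; the question is illustrative):
```python
docs = retriever.invoke("What are adversarial attacks on LLMs?")
print(docs[0].page_content[:200])
```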
Functions that grade the retrieved documents and, if required, formulate an improved query for better retrieval results:
```python
# imports assumed by this snippet
from typing import Literal
from langchain_core.messages import HumanMessage
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
def grade_documents(state) -> Literal["generate", "rewrite"]:
class grade(BaseModel):
binary_score: str = Field(description="Relevance score 'yes' or 'no'")
model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
llm_with_tool = model.with_structured_output(grade)
prompt = PromptTemplate(
template="""You are a grader assessing relevance of a retrieved document to a user question. \n
Here is the retrieved document: \n\n {context} \n\n
Here is the user question: {question} \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
input_variables=["context", "question"],
)
chain = prompt | llm_with_tool
messages = state["messages"]
last_message = messages[-1]
question = messages[0].content
docs = last_message.content
scored_result = chain.invoke({"question": question, "context": docs})
score = scored_result.binary_score
return "generate" if score == "yes" else "rewrite"
def rewrite(state):
messages = state["messages"]
question = messages[0].content
msg = [
HumanMessage(
content=f""" \n
Look at the input and try to reason about the underlying semantic intent / meaning. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate an improved question: """,
)
]
model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
response = model.invoke(msg)
return {"messages": [response]}
```
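Here's a minimal sketch of wiring these together, assuming the retriever defined above; the `{"messages": [...]}` state shape follows the convention the snippet relies on:
```python
from langchain_core.messages import AIMessage, HumanMessage
question = "What is prompt engineering?"
docs = retriever.invoke(question)
state = {
    "messages": [
        HumanMessage(content=question),
        # the last message holds the retrieved context, as grade_documents expects
        AIMessage(content="\n\n".join(d.page_content for d in docs)),
    ]
}
# route to answer generation, or rewrite the query and retrieve again
if grade_documents(state) == "rewrite":
    state["messages"].extend(rewrite(state)["messages"])
```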

View File

@@ -0,0 +1,54 @@
**Vanilla RAG 🌱**
====================================================================
RAG (Retrieval-Augmented Generation) works by finding documents related to the user's question, combining them with a prompt for a large language model (LLM), and then using the LLM to create more accurate and relevant answers.
Here's a simple guide to building a RAG pipeline from scratch:
1. **Data Loading**: Gather and load the documents you want to use for answering questions.
2. **Chunking and Embedding**: Split the documents into smaller chunks and convert them into numerical vectors (embeddings) that capture their meaning.
3. **Vector Store**: Create a LanceDB table to store and manage these vectors for quick access during retrieval.
4. **Retrieval & Prompt Preparation**: When a question is asked, find the most relevant document chunks from the table and prepare a prompt combining these chunks with the question.
5. **Answer Generation**: Send the prepared prompt to an LLM to generate a detailed and accurate answer.
<figure markdown="span">
![vanilla-rag](https://raw.githubusercontent.com/lancedb/assets/main/docs/assets/rag/rag_from_scratch.png)
<figcaption>Vanilla RAG
</figcaption>
</figure>
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/tutorials/RAG-from-Scratch/RAG_from_Scratch.ipynb)
Here's a code snippet for defining a table with the [Embedding API](https://lancedb.github.io/lancedb/embeddings/embedding_functions/), which simplifies the process by handling embedding extraction and querying in one step:
```python
import pandas as pd
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
db = lancedb.connect("/tmp/db")
model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
class Docs(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
table = db.create_table("docs", schema=Docs)
# assuming `chunks` is a list of text strings
df = pd.DataFrame({'text': chunks})
table.add(data=df)
query = "What is the issue date of the lease?"
actual = table.search(query).limit(1).to_list()[0]
print(actual["text"])
```
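Steps 4 and 5 then come down to stuffing the retrieved chunks into a prompt and sending it to an LLM. A minimal sketch, assuming an OpenAI API key is configured (the model name is illustrative):
```python
from openai import OpenAI
client = OpenAI()
# gather the top chunks for the query and prepare the prompt
context = "\n\n".join(r["text"] for r in table.search(query).limit(3).to_list())
prompt = (
    "Answer the question using only the context below.\n\n"
    f"Context:\n{context}\n\n"
    f"Question: {query}"
)
response = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative; any chat model works
    messages=[{"role": "user", "content": prompt}],
)
print(response.choices[0].message.content)
```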
Check the Colab notebook for the complete code:
[![Open In Colab](../assets/colab.svg)](https://colab.research.google.com/github/lancedb/vectordb-recipes/blob/main/tutorials/RAG-from-Scratch/RAG_from_Scratch.ipynb)

View File

@@ -1,6 +1,9 @@
# Linear Combination Reranker
This is the default re-ranker used by LanceDB hybrid search. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.
!!! note
This is deprecated. It is recommended to use the `RRFReranker` instead if you want to use a score-based reranker.
It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.
!!! note
Supported Query Types: Hybrid
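Here's a minimal sketch of passing a custom weight, assuming a table that has both a vector index and a full-text search index (the query string is illustrative):
```python
from lancedb.rerankers import LinearCombinationReranker
reranker = LinearCombinationReranker(weight=0.7)  # 70% semantic, 30% full-text
results = (
    table.search("rocking chairs", query_type="hybrid")
    .rerank(reranker=reranker)
    .limit(5)
    .to_pandas()
)
```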

View File

@@ -1,6 +1,6 @@
# Reciprocal Rank Fusion Reranker
Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
This is the default re-ranker used by LanceDB hybrid search. Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
!!! note
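Here's a minimal sketch of passing it explicitly, assuming a table that has both a vector index and a full-text search index (the query string is illustrative):
```python
from lancedb.rerankers import RRFReranker
reranker = RRFReranker()  # the smoothing constant K defaults to 60
results = (
    table.search("rocking chairs", query_type="hybrid")
    .rerank(reranker=reranker)
    .limit(5)
    .to_pandas()
)
```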

View File

@@ -20,7 +20,11 @@ excluded_globs = [
"../src/reranking/*.md",
"../src/guides/tuning_retrievers/*.md",
"../src/embeddings/available_embedding_models/text_embedding_functions/*.md",
"../src/embeddings/available_embedding_models/multimodal_embedding_functions/*.md"
"../src/embeddings/available_embedding_models/multimodal_embedding_functions/*.md",
"../src/rag/*.md",
"../src/rag/advanced_techniques/*.md"
]
python_prefix = "py"

View File

@@ -2,7 +2,7 @@
name = "lancedb-jni"
description = "JNI bindings for LanceDB"
# TODO modify lancedb/Cargo.toml for version and dependencies
version = "0.4.18"
version = "0.10.0"
edition.workspace = true
repository.workspace = true
readme.workspace = true

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.0.3</version>
<version>0.11.0-beta.1</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -44,7 +44,7 @@
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
<scope>test</scope>
</dependency>
</dependencies>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.0.3</version>
<version>0.11.0-beta.1</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>
@@ -92,7 +92,7 @@
</repository>
</distributionManagement>
<build>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
@@ -167,7 +167,8 @@
<version>3.2.5</version>
<configuration>
<argLine>--add-opens=java.base/java.nio=ALL-UNNAMED</argLine>
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
<forkNode
implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory" />
<useSystemClassLoader>false</useSystemClassLoader>
</configuration>
</plugin>
@@ -183,7 +184,7 @@
</pluginManagement>
</build>
<profiles>
<profiles>
<profile>
<id>jdk8</id>
<activation>
@@ -210,7 +211,8 @@
<version>3.2.5</version>
<configuration>
<argLine>--add-opens=java.base/java.nio=ALL-UNNAMED</argLine>
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory" />
<forkNode
implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory" />
<useSystemClassLoader>false</useSystemClassLoader>
</configuration>
</plugin>

View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"cpu": [
"x64",
"arm64"

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",

View File

@@ -60,7 +60,7 @@ export {
type MakeArrowTableOptions
} from "./arrow";
const defaultAwsRegion = "us-west-2";
const defaultAwsRegion = "us-east-1";
const defaultRequestTimeout = 10_000
@@ -111,7 +111,7 @@ export interface ConnectionOptions {
*/
apiKey?: string
/** Region to connect */
/** Region to connect. Default is 'us-east-1' */
region?: string
/**
@@ -197,28 +197,32 @@ export async function connect(
export async function connect(
arg: string | Partial<ConnectionOptions>
): Promise<Connection> {
let opts: ConnectionOptions;
let partOpts: Partial<ConnectionOptions>;
if (typeof arg === "string") {
opts = { uri: arg };
partOpts = { uri: arg };
} else {
const keys = Object.keys(arg);
if (keys.length === 1 && keys[0] === "uri" && typeof arg.uri === "string") {
opts = { uri: arg.uri };
partOpts = { uri: arg.uri };
} else {
opts = Object.assign(
{
uri: "",
awsCredentials: undefined,
awsRegion: defaultAwsRegion,
apiKey: undefined,
region: defaultAwsRegion,
timeout: defaultRequestTimeout
},
arg
);
partOpts = arg;
}
}
let defaultRegion = process.env.AWS_REGION ?? process.env.AWS_DEFAULT_REGION;
defaultRegion = (defaultRegion ?? "").trim() !== "" ? defaultRegion : defaultAwsRegion;
const opts: ConnectionOptions = {
uri: partOpts.uri ?? "",
awsCredentials: partOpts.awsCredentials ?? undefined,
awsRegion: partOpts.awsRegion ?? defaultRegion,
apiKey: partOpts.apiKey ?? undefined,
region: partOpts.region ?? defaultRegion,
timeout: partOpts.timeout ?? defaultRequestTimeout,
readConsistencyInterval: partOpts.readConsistencyInterval ?? undefined,
storageOptions: partOpts.storageOptions ?? undefined,
hostOverride: partOpts.hostOverride ?? undefined
}
if (opts.uri.startsWith("db://")) {
// Remote connection
return new RemoteConnection(opts);

View File

@@ -33,6 +33,7 @@ export class Query<T = number[]> {
private _filter?: string
private _metricType?: MetricType
private _prefilter: boolean
private _fastSearch: boolean
protected readonly _embeddings?: EmbeddingFunction<T>
constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
@@ -46,6 +47,7 @@ export class Query<T = number[]> {
this._metricType = undefined
this._embeddings = embeddings
this._prefilter = false
this._fastSearch = false
}
/***
@@ -110,6 +112,15 @@ export class Query<T = number[]> {
return this
}
/**
* Skip searching un-indexed data. This can make search faster, but will miss
* any data that is not yet indexed.
*/
fastSearch (value: boolean): Query<T> {
this._fastSearch = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/

View File

@@ -17,6 +17,7 @@ import axios, { type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
import { type RemoteResponse, type RemoteRequest, Method } from '../middleware'
import { MetricType } from '..'
interface HttpLancedbClientMiddleware {
onRemoteRequest(
@@ -82,7 +83,7 @@ async function callWithMiddlewares (
interface MiddlewareInvocationOptions {
responseType?: ResponseType
timeout?: number,
timeout?: number
}
/**
@@ -130,8 +131,8 @@ export class HttpLancedbClient {
url: string,
apiKey: string,
timeout?: number,
private readonly _dbName?: string,
private readonly _dbName?: string
) {
this._url = url
this._apiKey = () => apiKey
@@ -151,7 +152,9 @@ export class HttpLancedbClient {
prefilter: boolean,
refineFactor?: number,
columns?: string[],
filter?: string
filter?: string,
metricType?: MetricType,
fastSearch?: boolean
): Promise<ArrowTable<any>> {
const result = await this.post(
`/v1/table/${tableName}/query/`,
@@ -159,10 +162,12 @@ export class HttpLancedbClient {
vector,
k,
nprobes,
refineFactor,
refine_factor: refineFactor,
columns,
filter,
prefilter
prefilter,
metric: metricType,
fast_search: fastSearch
},
undefined,
undefined,
@@ -237,7 +242,7 @@ export class HttpLancedbClient {
try {
response = await callWithMiddlewares(req, this._middlewares, {
responseType,
timeout: this._timeout,
timeout: this._timeout
})
// return response

View File

@@ -238,7 +238,9 @@ export class RemoteQuery<T = number[]> extends Query<T> {
(this as any)._prefilter,
(this as any)._refineFactor,
(this as any)._select,
(this as any)._filter
(this as any)._filter,
(this as any)._metricType,
(this as any)._fastSearch
)
return data.toArray().map((entry: Record<string, unknown>) => {

View File

@@ -112,8 +112,8 @@ describe("LanceDB client", function () {
name: 'name_2',
price: 10,
is_active: true,
vector: [ 0, 0.1 ]
},
vector: [0, 0.1]
}
]);
assert.equal(await table2.countRows(), 3);
});

View File

@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { readdirSync } from "fs";
import { Field, Float64, Schema } from "apache-arrow";
import * as tmp from "tmp";
import { Connection, Table, connect } from "../lancedb";
import { LocalTable } from "../lancedb/table";
describe("when connecting", () => {
let tmpDir: tmp.DirResult;
@@ -105,7 +107,7 @@ describe("given a connection", () => {
const data = [...Array(10000).keys()].map((i) => ({ id: i }));
// Create in v1 mode
let table = await db.createTable("test", data);
let table = await db.createTable("test", data, { useLegacyFormat: true });
const isV2 = async (table: Table) => {
const data = await table.query().toArrow({ maxBatchLength: 100000 });
@@ -116,7 +118,7 @@ describe("given a connection", () => {
await expect(isV2(table)).resolves.toBe(false);
// Create in v2 mode
table = await db.createTable("test_v2", data, { useLegacyFormat: false });
table = await db.createTable("test_v2", data);
await expect(isV2(table)).resolves.toBe(true);
@@ -134,4 +136,57 @@ describe("given a connection", () => {
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
});
it("should be able to create tables with V2 manifest paths", async () => {
const db = await connect(tmpDir.name);
let table = (await db.createEmptyTable(
"test_manifest_paths_v2_empty",
new Schema([new Field("id", new Float64(), true)]),
{
enableV2ManifestPaths: true,
},
)) as LocalTable;
expect(await table.usesV2ManifestPaths()).toBe(true);
let manifestDir =
tmpDir.name + "/test_manifest_paths_v2_empty.lance/_versions";
readdirSync(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d{20}\.manifest$/);
});
table = (await db.createTable("test_manifest_paths_v2", [{ id: 1 }], {
enableV2ManifestPaths: true,
})) as LocalTable;
expect(await table.usesV2ManifestPaths()).toBe(true);
manifestDir = tmpDir.name + "/test_manifest_paths_v2.lance/_versions";
readdirSync(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d{20}\.manifest$/);
});
});
it("should be able to migrate tables to the V2 manifest paths", async () => {
const db = await connect(tmpDir.name);
const table = (await db.createEmptyTable(
"test_manifest_path_migration",
new Schema([new Field("id", new Float64(), true)]),
{
enableV2ManifestPaths: false,
},
)) as LocalTable;
expect(await table.usesV2ManifestPaths()).toBe(false);
const manifestDir =
tmpDir.name + "/test_manifest_path_migration.lance/_versions";
readdirSync(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d\.manifest$/);
});
await table.migrateManifestPathsV2();
expect(await table.usesV2ManifestPaths()).toBe(true);
readdirSync(manifestDir).forEach((file) => {
expect(file).toMatch(/^\d{20}\.manifest$/);
});
});
});

View File

@@ -396,6 +396,10 @@ describe("When creating an index", () => {
.toArrow();
expect(rst2.numRows).toBe(2);
expect(rst.toString()).toEqual(rst2.toString());
// test offset
rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
expect(rst.numRows).toBe(1);
});
it("should allow parameters to be specified", async () => {
@@ -440,6 +444,26 @@ describe("When creating an index", () => {
expect(fs.readdirSync(indexDir)).toHaveLength(1);
});
test("create a hnswPq index", async () => {
await tbl.createIndex("vec", {
config: Index.hnswPq({
numPartitions: 10,
}),
});
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
});
test("create a HnswSq index", async () => {
await tbl.createIndex("vec", {
config: Index.hnswSq({
numPartitions: 10,
}),
});
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
});
test("create a label list index", async () => {
await tbl.createIndex("tags", {
config: Index.labelList(),
@@ -840,6 +864,38 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])(
expect(results[0].text).toBe(data[0].text);
});
test("full text search without positions", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
await table.createIndex("text", {
config: Index.fts({ withPosition: false }),
});
const results = await table.search("hello").toArray();
expect(results[0].text).toBe(data[0].text);
});
test("full text search phrase query", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
await table.createIndex("text", {
config: Index.fts(),
});
const results = await table.search("world").toArray();
expect(results.length).toBe(2);
const phraseResults = await table.search('"hello world"').toArray();
expect(phraseResults.length).toBe(1);
});
test.each([
[0.4, 0.5, 0.599], // number[]
Float32Array.of(0.4, 0.5, 0.599), // Float32Array

View File

@@ -44,20 +44,30 @@ export interface CreateTableOptions {
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
/**
* The version of the data storage format to use.
*
* The default is `legacy`, which is Lance format v1.
* `stable` is the new format, which is Lance format v2.
* The default is `stable`.
* Set to "legacy" to use the old format.
*/
dataStorageVersion?: string;
/**
* Use the new V2 manifest paths. These paths provide more efficient
* opening of datasets with many versions on object stores. WARNING:
* turning this on will make the dataset unreadable for older versions
* of LanceDB (prior to 0.10.0). To migrate an existing dataset, instead
* use the {@link LocalTable#migrateManifestPathsV2} method.
*/
enableV2ManifestPaths?: boolean;
/**
* If true then data files will be written with the legacy format
*
* The default is true while the new format is in beta
* The default is false.
*
* Deprecated.
* Deprecated. Use data storage version instead.
*/
useLegacyFormat?: boolean;
schema?: SchemaLike;
@@ -257,7 +267,7 @@ export class LocalConnection extends Connection {
throw new Error("data is required");
}
const { buf, mode } = await Table.parseTableData(data, options);
let dataStorageVersion = "legacy";
let dataStorageVersion = "stable";
if (options?.dataStorageVersion !== undefined) {
dataStorageVersion = options.dataStorageVersion;
} else if (options?.useLegacyFormat !== undefined) {
@@ -270,6 +280,7 @@ export class LocalConnection extends Connection {
mode,
cleanseStorageOptions(options?.storageOptions),
dataStorageVersion,
options?.enableV2ManifestPaths,
);
return new LocalTable(innerTable);
@@ -293,7 +304,7 @@ export class LocalConnection extends Connection {
metadata = registry.getTableMetadata([embeddingFunction]);
}
let dataStorageVersion = "legacy";
let dataStorageVersion = "stable";
if (options?.dataStorageVersion !== undefined) {
dataStorageVersion = options.dataStorageVersion;
} else if (options?.useLegacyFormat !== undefined) {
@@ -308,6 +319,7 @@ export class LocalConnection extends Connection {
mode,
cleanseStorageOptions(options?.storageOptions),
dataStorageVersion,
options?.enableV2ManifestPaths,
);
return new LocalTable(innerTable);
}

View File

@@ -113,6 +113,234 @@ export interface IvfPqOptions {
sampleRate?: number;
}
/**
* Options to create an `HNSW_PQ` index
*/
export interface HnswPqOptions {
/**
* The distance metric used to train the index.
*
* Default value is "l2".
*
* The following distance types are available:
*
* "l2" - Euclidean distance. This is a very common distance metric that
* accounts for both magnitude and direction when determining the distance
* between vectors. L2 distance has a range of [0, ∞).
*
* "cosine" - Cosine distance. Cosine distance is a distance metric
* calculated from the cosine similarity between two vectors. Cosine
* similarity is a measure of similarity between two non-zero vectors of an
* inner product space. It is defined to equal the cosine of the angle
* between them. Unlike L2, the cosine distance is not affected by the
* magnitude of the vectors. Cosine distance has a range of [0, 2].
*
* "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
* distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
* L2 norm is 1), then dot distance is equivalent to the cosine distance.
*/
distanceType?: "l2" | "cosine" | "dot";
/**
* The number of IVF partitions to create.
*
* For HNSW, we recommend a small number of partitions. Setting this to 1 works
* well for most tables. For very large tables, training just one HNSW graph
* will require too much memory. Each partition becomes its own HNSW graph, so
* setting this value higher reduces the peak memory use of training.
*
*/
numPartitions?: number;
/**
* Number of sub-vectors of PQ.
*
* This value controls how much the vector is compressed during the quantization step.
* The more sub-vectors there are, the less the vector is compressed. The default is
* the dimension of the vector divided by 16. If the dimension is not evenly divisible
* by 16, we use the dimension divided by 8.
*
* The above two cases are highly preferred. Having 8 or 16 values per subvector allows
* us to use efficient SIMD instructions.
*
* If the dimension is not divisible by 8, then we use 1 sub-vector. This is not ideal and
* will likely result in poor performance.
*
*/
numSubVectors?: number;
/**
* Max iterations to train kmeans.
*
* The default value is 50.
*
* When training an IVF index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
*
* Increasing this might improve the quality of the index but in most cases the parameter
* is unused because kmeans will converge with fewer iterations. The parameter is only
* used in cases where kmeans does not appear to converge. In those cases it is unlikely
* that setting this larger will lead to the index converging anyways.
*
*/
maxIterations?: number;
/**
* The rate used to calculate the number of training vectors for kmeans.
*
* Default value is 256.
*
* When an IVF index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
*
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
* random sample of the data. This parameter controls the size of the sample. The total
* number of vectors used to train the index is `sample_rate * num_partitions`.
*
* Increasing this value might improve the quality of the index but in most cases the
* default should be sufficient.
*
*/
sampleRate?: number;
/**
* The number of neighbors to select for each vector in the HNSW graph.
*
* The default value is 20.
*
* This value controls the tradeoff between search speed and accuracy.
* The higher the value the more accurate the search but the slower it will be.
*
*/
m?: number;
/**
* The number of candidates to evaluate during the construction of the HNSW graph.
*
* The default value is 300.
*
* This value controls the tradeoff between build speed and accuracy.
* The higher the value the more accurate the build but the slower it will be.
* 150 to 300 is the typical range. 100 is a minimum for good quality search
* results. In most cases, there is no benefit to setting this higher than 500.
* This value should be set to a value that is not less than `ef` in the search phase.
*
*/
efConstruction?: number;
}
/**
* Options to create an `HNSW_SQ` index
*/
export interface HnswSqOptions {
/**
* The distance metric used to train the index.
*
* Default value is "l2".
*
* The following distance types are available:
*
* "l2" - Euclidean distance. This is a very common distance metric that
* accounts for both magnitude and direction when determining the distance
* between vectors. L2 distance has a range of [0, ∞).
*
* "cosine" - Cosine distance. Cosine distance is a distance metric
* calculated from the cosine similarity between two vectors. Cosine
* similarity is a measure of similarity between two non-zero vectors of an
* inner product space. It is defined to equal the cosine of the angle
* between them. Unlike L2, the cosine distance is not affected by the
* magnitude of the vectors. Cosine distance has a range of [0, 2].
*
* "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
* distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
* L2 norm is 1), then dot distance is equivalent to the cosine distance.
*/
distanceType?: "l2" | "cosine" | "dot";
/**
* The number of IVF partitions to create.
*
* For HNSW, we recommend a small number of partitions. Setting this to 1 works
* well for most tables. For very large tables, training just one HNSW graph
* will require too much memory. Each partition becomes its own HNSW graph, so
* setting this value higher reduces the peak memory use of training.
*
*/
numPartitions?: number;
/**
* Max iterations to train kmeans.
*
* The default value is 50.
*
* When training an IVF index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
*
* Increasing this might improve the quality of the index but in most cases the parameter
* is unused because kmeans will converge with fewer iterations. The parameter is only
* used in cases where kmeans does not appear to converge. In those cases it is unlikely
* that setting this larger will lead to the index converging anyways.
*
*/
maxIterations?: number;
/**
* The rate used to calculate the number of training vectors for kmeans.
*
* Default value is 256.
*
* When an IVF index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
*
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
* random sample of the data. This parameter controls the size of the sample. The total
* number of vectors used to train the index is `sample_rate * num_partitions`.
*
* Increasing this value might improve the quality of the index but in most cases the
* default should be sufficient.
*
*/
sampleRate?: number;
/**
* The number of neighbors to select for each vector in the HNSW graph.
*
* The default value is 20.
*
* This value controls the tradeoff between search speed and accuracy.
* The higher the value the more accurate the search but the slower it will be.
*
*/
m?: number;
/**
* The number of candidates to evaluate during the construction of the HNSW graph.
*
* The default value is 300.
*
* This value controls the tradeoff between build speed and accuracy.
* The higher the value the more accurate the build but the slower it will be.
* 150 to 300 is the typical range. 100 is a minimum for good quality search
* results. In most cases, there is no benefit to setting this higher than 500.
* This value should be set to a value that is not less than `ef` in the search phase.
*
*/
efConstruction?: number;
}
/**
* Options to create a full text search index
*/
export interface FtsOptions {
/**
* Whether to build the index with positions.
* True by default.
* If set to false, the index will not store the positions of the tokens in the text,
* which will make the index smaller and faster to build, but will not support phrase queries.
*/
withPosition?: boolean;
}
export class Index {
private readonly inner: LanceDbIndex;
private constructor(inner: LanceDbIndex) {
@@ -211,8 +439,53 @@ export class Index {
*
* For now, the full text search index only supports English, and doesn't support phrase search.
*/
static fts() {
return new Index(LanceDbIndex.fts());
static fts(options?: Partial<FtsOptions>) {
return new Index(LanceDbIndex.fts(options?.withPosition));
}
/**
*
* Create a hnswPq index
*
* HNSW-PQ stands for Hierarchical Navigable Small World - Product Quantization.
* It is a variant of the HNSW algorithm that uses product quantization to compress
* the vectors.
*
*/
static hnswPq(options?: Partial<HnswPqOptions>) {
return new Index(
LanceDbIndex.hnswPq(
options?.distanceType,
options?.numPartitions,
options?.numSubVectors,
options?.maxIterations,
options?.sampleRate,
options?.m,
options?.efConstruction,
),
);
}
/**
*
* Create a hnswSq index
*
* HNSW-SQ stands for Hierarchical Navigable Small World - Scalar Quantization.
* It is a variant of the HNSW algorithm that uses scalar quantization to compress
* the vectors.
*
*/
static hnswSq(options?: Partial<HnswSqOptions>) {
return new Index(
LanceDbIndex.hnswSq(
options?.distanceType,
options?.numPartitions,
options?.maxIterations,
options?.sampleRate,
options?.m,
options?.efConstruction,
),
);
}
}

View File

@@ -234,6 +234,11 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
return this;
}
offset(offset: number): this {
this.doCall((inner: NativeQueryType) => inner.offset(offset));
return this;
}
protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {

View File

@@ -697,4 +697,31 @@ export class LocalTable extends Table {
on = Array.isArray(on) ? on : [on];
return new MergeInsertBuilder(this.inner.mergeInsert(on));
}
/**
* Check if the table uses the new manifest path scheme.
*
* This function will return true if the table uses the V2 manifest
* path scheme.
*/
async usesV2ManifestPaths(): Promise<boolean> {
return await this.inner.usesV2ManifestPaths();
}
/**
* Migrate the table to use the new manifest path scheme.
*
* This function will rename all V1 manifests to V2 manifest paths.
* These paths provide more efficient opening of datasets with many versions
* on object stores.
*
* This function is idempotent, and can be run multiple times without
* changing the state of the object store.
*
* However, it should not be run while other concurrent operations are happening,
* and it should be allowed to run to completion before other operations resume.
*/
async migrateManifestPathsV2(): Promise<void> {
await this.inner.migrateManifestPathsV2();
}
}

208
nodejs/native.d.ts vendored
View File

@@ -1,208 +0,0 @@
/* tslint:disable */
/* eslint-disable */
/* auto-generated by NAPI-RS */
/** A description of an index currently configured on a column */
export interface IndexConfig {
/** The name of the index */
name: string
/** The type of the index */
indexType: string
/**
* The columns in the index
*
* Currently this is always an array of size 1. In the future there may
* be more columns to represent composite indices.
*/
columns: Array<string>
}
/** Statistics about a compaction operation. */
export interface CompactionStats {
/** The number of fragments removed */
fragmentsRemoved: number
/** The number of new, compacted fragments added */
fragmentsAdded: number
/** The number of data files removed */
filesRemoved: number
/** The number of new, compacted data files added */
filesAdded: number
}
/** Statistics about a cleanup operation */
export interface RemovalStats {
/** The number of bytes removed */
bytesRemoved: number
/** The number of old versions removed */
oldVersionsRemoved: number
}
/** Statistics about an optimize operation */
export interface OptimizeStats {
/** Statistics about the compaction operation */
compaction: CompactionStats
/** Statistics about the removal operation */
prune: RemovalStats
}
/**
* A definition of a column alteration. The alteration changes the column at
* `path` to have the new name `name`, to be nullable if `nullable` is true,
* and to have the data type `data_type`. At least one of `rename` or `nullable`
* must be provided.
*/
export interface ColumnAlteration {
/**
* The path to the column to alter. This is a dot-separated path to the column.
* If it is a top-level column then it is just the name of the column. If it is
* a nested column then it is the path to the column, e.g. "a.b.c" for a column
* `c` nested inside a column `b` nested inside a column `a`.
*/
path: string
/**
* The new name of the column. If not provided then the name will not be changed.
* This must be distinct from the names of all other columns in the table.
*/
rename?: string
/** Set the new nullability. Note that a nullable column cannot be made non-nullable. */
nullable?: boolean
}
/** A definition of a new column to add to a table. */
export interface AddColumnsSql {
/** The name of the new column. */
name: string
/**
* The values to populate the new column with, as a SQL expression.
* The expression can reference other columns in the table.
*/
valueSql: string
}
export interface IndexStatistics {
/** The number of rows indexed by the index */
numIndexedRows: number
/** The number of rows not indexed */
numUnindexedRows: number
/** The type of the index */
indexType?: string
/** The metadata for each index */
indices: Array<IndexMetadata>
}
export interface IndexMetadata {
metricType?: string
indexType?: string
}
export interface ConnectionOptions {
/**
* (For LanceDB OSS only): The interval, in seconds, at which to check for
* updates to the table from other processes. If None, then consistency is not
* checked. For performance reasons, this is the default. For strong
* consistency, set this to zero seconds. Then every read will check for
* updates from other processes. As a compromise, you can set this to a
* non-zero value for eventual consistency. If more than that interval
* has passed since the last check, then the table will be checked for updates.
* Note: this consistency only applies to read operations. Write operations are
* always consistent.
*/
readConsistencyInterval?: number
/**
* (For LanceDB OSS only): configuration for object storage.
*
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>
}
/** Write mode for writing a table. */
export const enum WriteMode {
Create = 'Create',
Append = 'Append',
Overwrite = 'Overwrite'
}
/** Write options when creating a Table. */
export interface WriteOptions {
/** Write mode for writing to a table. */
mode?: WriteMode
}
export interface OpenTableOptions {
storageOptions?: Record<string, string>
}
export class Connection {
/** Create a new Connection instance from the given URI. */
static new(uri: string, options: ConnectionOptions): Promise<Connection>
display(): string
isOpen(): boolean
close(): void
/** List all tables in the dataset. */
tableNames(startAfter?: string | undefined | null, limit?: number | undefined | null): Promise<Array<string>>
/**
* Create table from a Apache Arrow IPC (file) buffer.
*
* Parameters:
* - name: The name of the table.
* - buf: The buffer containing the IPC file.
*
*/
createTable(name: string, buf: Buffer, mode: string, storageOptions?: Record<string, string> | undefined | null, useLegacyFormat?: boolean | undefined | null): Promise<Table>
createEmptyTable(name: string, schemaBuf: Buffer, mode: string, storageOptions?: Record<string, string> | undefined | null, useLegacyFormat?: boolean | undefined | null): Promise<Table>
openTable(name: string, storageOptions?: Record<string, string> | undefined | null, indexCacheSize?: number | undefined | null): Promise<Table>
/** Drop table with the name. Or raise an error if the table does not exist. */
dropTable(name: string): Promise<void>
}
export class Index {
static ivfPq(distanceType?: string | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): Index
static btree(): Index
}
/** Typescript-style Async Iterator over RecordBatches */
export class RecordBatchIterator {
next(): Promise<Buffer | null>
}
/** A builder used to create and run a merge insert operation */
export class NativeMergeInsertBuilder {
whenMatchedUpdateAll(condition?: string | undefined | null): NativeMergeInsertBuilder
whenNotMatchedInsertAll(): NativeMergeInsertBuilder
whenNotMatchedBySourceDelete(filter?: string | undefined | null): NativeMergeInsertBuilder
execute(buf: Buffer): Promise<void>
}
export class Query {
onlyIf(predicate: string): void
select(columns: Array<[string, string]>): void
limit(limit: number): void
nearestTo(vector: Float32Array): VectorQuery
execute(maxBatchLength?: number | undefined | null): Promise<RecordBatchIterator>
explainPlan(verbose: boolean): Promise<string>
}
export class VectorQuery {
column(column: string): void
distanceType(distanceType: string): void
postfilter(): void
refineFactor(refineFactor: number): void
nprobes(nprobe: number): void
bypassVectorIndex(): void
onlyIf(predicate: string): void
select(columns: Array<[string, string]>): void
limit(limit: number): void
execute(maxBatchLength?: number | undefined | null): Promise<RecordBatchIterator>
explainPlan(verbose: boolean): Promise<string>
}
export class Table {
name: string
display(): string
isOpen(): boolean
close(): void
/** Return Schema as empty Arrow IPC file. */
schema(): Promise<Buffer>
add(buf: Buffer, mode: string): Promise<void>
countRows(filter?: string | undefined | null): Promise<number>
delete(predicate: string): Promise<void>
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
query(): Query
vectorSearch(vector: Float32Array): VectorQuery
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
dropColumns(columns: Array<string>): Promise<void>
version(): Promise<number>
checkout(version: number): Promise<void>
checkoutLatest(): Promise<void>
restore(): Promise<void>
optimize(olderThanMs?: number | undefined | null): Promise<OptimizeStats>
listIndices(): Promise<Array<IndexConfig>>
indexStats(indexName: string): Promise<IndexStatistics | null>
mergeInsert(on: Array<string>): NativeMergeInsertBuilder
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.8.0",
"version": "0.10.0-beta.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.8.0",
"version": "0.10.0-beta.1",
"cpu": [
"x64",
"arm64"

View File

@@ -10,7 +10,7 @@
"vector database",
"ann"
],
"version": "0.10.0-beta.1",
"version": "0.11.0-beta.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",
@@ -66,8 +66,8 @@
"os": ["darwin", "linux", "win32"],
"scripts": {
"artifacts": "napi artifacts",
"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
"build:release": "napi build --platform --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
"build:debug": "napi build --platform --no-const-enum --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
"build:release": "napi build --platform --no-const-enum --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
"build": "npm run build:debug && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts && shx cp lancedb/*.node dist/",
"build-release": "npm run build:release && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts",
"lint-ci": "biome ci .",

View File

@@ -124,11 +124,13 @@ impl Connection {
mode: String,
storage_options: Option<HashMap<String, String>>,
data_storage_options: Option<String>,
enable_v2_manifest_paths: Option<bool>,
) -> napi::Result<Table> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let mode = Self::parse_create_mode_str(&mode)?;
let mut builder = self.get_inner()?.create_table(&name, batches).mode(mode);
if let Some(storage_options) = storage_options {
for (key, value) in storage_options {
builder = builder.storage_option(key, value);
@@ -140,6 +142,9 @@ impl Connection {
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?,
);
}
if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
}
let tbl = builder
.execute()
.await
@@ -155,6 +160,7 @@ impl Connection {
mode: String,
storage_options: Option<HashMap<String, String>>,
data_storage_options: Option<String>,
enable_v2_manifest_paths: Option<bool>,
) -> napi::Result<Table> {
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
@@ -175,6 +181,9 @@ impl Connection {
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?,
);
}
if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
}
let tbl = builder
.execute()
.await

View File

@@ -15,7 +15,7 @@
use std::sync::Mutex;
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder};
use lancedb::index::Index as LanceDbIndex;
use napi_derive::napi;
@@ -92,9 +92,85 @@ impl Index {
}
#[napi(factory)]
pub fn fts() -> Self {
pub fn fts(with_position: Option<bool>) -> Self {
let mut opts = FtsIndexBuilder::default();
if let Some(with_position) = with_position {
opts = opts.with_position(with_position);
}
Self {
inner: Mutex::new(Some(LanceDbIndex::FTS(FtsIndexBuilder::default()))),
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
}
}
#[napi(factory)]
pub fn hnsw_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> napi::Result<Self> {
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
})
}
#[napi(factory)]
pub fn hnsw_sq(
distance_type: Option<String>,
num_partitions: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> napi::Result<Self> {
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
}
if let Some(max_iterations) = max_iterations {
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
})
}
}

View File

@@ -64,6 +64,11 @@ impl Query {
self.inner = self.inner.clone().limit(limit as usize);
}
#[napi]
pub fn offset(&mut self, offset: u32) {
self.inner = self.inner.clone().offset(offset as usize);
}
#[napi]
pub fn nearest_to(&mut self, vector: Float32Array) -> Result<VectorQuery> {
let inner = self
@@ -166,6 +171,11 @@ impl VectorQuery {
self.inner = self.inner.clone().limit(limit as usize);
}
#[napi]
pub fn offset(&mut self, offset: u32) {
self.inner = self.inner.clone().offset(offset as usize);
}
#[napi(catch_unwind)]
pub async fn execute(
&self,

View File

@@ -156,7 +156,7 @@ impl Table {
&self,
only_if: Option<String>,
columns: Vec<(String, String)>,
) -> napi::Result<()> {
) -> napi::Result<u64> {
let mut op = self.inner_ref()?.update();
if let Some(only_if) = only_if {
op = op.only_if(only_if);
@@ -347,6 +347,26 @@ impl Table {
let on: Vec<_> = on.iter().map(String::as_str).collect();
Ok(self.inner_ref()?.merge_insert(on.as_slice()).into())
}
#[napi(catch_unwind)]
pub async fn uses_v2_manifest_paths(&self) -> napi::Result<bool> {
self.inner_ref()?
.as_native()
.ok_or_else(|| napi::Error::from_reason("This cannot be run on a remote table"))?
.uses_v2_manifest_paths()
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn migrate_manifest_paths_v2(&self) -> napi::Result<()> {
self.inner_ref()?
.as_native()
.ok_or_else(|| napi::Error::from_reason("This cannot be run on a remote table"))?
.migrate_manifest_paths_v2()
.await
.default_error()
}
}
#[napi(object)]

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.13.0-beta.1"
current_version = "0.14.0-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.13.0-beta.1"
version = "0.14.0-beta.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -3,8 +3,7 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
"deprecation",
"pylance==0.17.0",
"ratelimiter~=1.0",
"pylance==0.18.0",
"requests>=2.31.0",
"retry>=0.9.2",
"tqdm>=4.27.0",

View File

@@ -25,6 +25,7 @@ class Connection(object):
data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> Table: ...
async def create_empty_table(
self,
@@ -33,6 +34,7 @@ class Connection(object):
schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> Table: ...
class Table:

View File

@@ -567,6 +567,7 @@ class AsyncConnection(object):
*,
data_storage_version: Optional[str] = None,
use_legacy_format: Optional[bool] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -609,15 +610,22 @@ class AsyncConnection(object):
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
data_storage_version: optional, str, default "legacy"
data_storage_version: optional, str, default "stable"
The version of the data storage format to use. Newer versions are more
efficient but require newer versions of lance to read. The default is
"legacy" which will use the legacy v1 version. See the user guide
"stable" which will use the legacy v2 version. See the user guide
for more details.
use_legacy_format: bool, optional, default True. (Deprecated)
use_legacy_format: bool, optional, default False. (Deprecated)
If True, use the legacy format for the table. If False, use the new format.
The default is True while the new format is in beta.
This method is deprecated, use `data_storage_version` instead.
enable_v2_manifest_paths: bool, optional, default False
Use the new V2 manifest paths. These paths provide more efficient
opening of datasets with many versions on object stores. WARNING:
turning this on will make the dataset unreadable for older versions
of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
use the
[AsyncTable.migrate_manifest_paths_v2][lancedb.table.AsyncTable.migrate_manifest_paths_v2]
method.
Returns
@@ -750,9 +758,7 @@ class AsyncConnection(object):
mode = "exist_ok"
if not data_storage_version:
data_storage_version = (
"legacy" if use_legacy_format is None or use_legacy_format else "stable"
)
data_storage_version = "legacy" if use_legacy_format else "stable"
if data is None:
new_table = await self._inner.create_empty_table(
@@ -761,6 +767,7 @@ class AsyncConnection(object):
schema,
storage_options=storage_options,
data_storage_version=data_storage_version,
enable_v2_manifest_paths=enable_v2_manifest_paths,
)
else:
data = data_to_reader(data, schema)
@@ -770,6 +777,7 @@ class AsyncConnection(object):
data,
storage_options=storage_options,
data_storage_version=data_storage_version,
enable_v2_manifest_paths=enable_v2_manifest_paths,
)
return AsyncTable(new_table)

View File

@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
#
# The following code is originally from https://github.com/pola-rs/polars/blob/ea4389c31b0e87ddf20a85e4c3797b285966edb6/py-polars/polars/dependencies.py
# and is licensed under the MIT license:
#
# License: MIT, Copyright (c) 2020 Ritchie Vink
# https://github.com/pola-rs/polars/blob/main/LICENSE
#
# It has been modified by the LanceDB developers
# to fit the needs of the LanceDB project.
from __future__ import annotations
import re
import sys
from functools import lru_cache
from importlib import import_module
from importlib.util import find_spec
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, Hashable, cast
_NUMPY_AVAILABLE = True
_PANDAS_AVAILABLE = True
_POLARS_AVAILABLE = True
_TORCH_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True
_TENSORFLOW_AVAILABLE = True
_RAY_AVAILABLE = True
class _LazyModule(ModuleType):
"""
Module that can act both as a lazy-loader and as a proxy.
Notes
-----
We do NOT register this module with `sys.modules` so as not to cause
confusion in the global environment. This way we have a valid proxy
module for our own use, but it lives _exclusively_ within lance.
"""
__lazy__ = True
_mod_pfx: ClassVar[dict[str, str]] = {
"numpy": "np.",
"pandas": "pd.",
"polars": "pl.",
"torch": "torch.",
"tensorflow": "tf.",
"ray": "ray.",
}
def __init__(
self,
module_name: str,
*,
module_available: bool,
) -> None:
"""
Initialise lazy-loading proxy module.
Parameters
----------
module_name : str
the name of the module to lazy-load (if available).
module_available : bool
indicate if the referenced module is actually available (we will proxy it
in both cases, but raise a helpful error when invoked if it doesn't exist).
"""
self._module_available = module_available
self._module_name = module_name
self._globals = globals()
super().__init__(module_name)
def _import(self) -> ModuleType:
# import the referenced module, replacing the proxy in this module's globals
module = import_module(self.__name__)
self._globals[self._module_name] = module
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, attr: Any) -> Any:
# have "hasattr('__wrapped__')" return False without triggering import
# (it's for decorators, not modules, but keeps "make doctest" happy)
if attr == "__wrapped__":
raise AttributeError(
f"{self._module_name!r} object has no attribute {attr!r}"
)
# accessing the proxy module's attributes triggers import of the real thing
if self._module_available:
# import the module and return the requested attribute
module = self._import()
return getattr(module, attr)
# user has not installed the proxied/lazy module
elif attr == "__name__":
return self._module_name
elif re.match(r"^__\w+__$", attr) and attr != "__version__":
# allow some minimal introspection on private module
# attrs to avoid unnecessary error-handling elsewhere
return None
else:
# all other attribute access raises a helpful exception
pfx = self._mod_pfx.get(self._module_name, "")
raise ModuleNotFoundError(
f"{pfx}{attr} requires {self._module_name!r} module to be installed"
) from None
def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
"""
Lazy import the given module; avoids up-front import costs.
Parameters
----------
module_name : str
name of the module to import, eg: "polars".
Notes
-----
If the requested module is not available (eg: has not been installed), a proxy
module is created in its place, which raises an exception on any attribute
access. This allows for import and use as normal, without requiring explicit
guard conditions - if the module is never used, no exception occurs; if it
is, then a helpful exception is raised.
Returns
-------
tuple of (Module, bool)
A lazy-loading module and a boolean indicating if the requested/underlying
module exists (if not, the returned module is a proxy).
"""
# check if module is LOADED
if module_name in sys.modules:
return sys.modules[module_name], True
# check if module is AVAILABLE
try:
module_spec = find_spec(module_name)
module_available = not (module_spec is None or module_spec.loader is None)
except ModuleNotFoundError:
module_available = False
# create lazy/proxy module that imports the real one on first use
# (or raises an explanatory ModuleNotFoundError if not available)
return (
_LazyModule(
module_name=module_name,
module_available=module_available,
),
module_available,
)
if TYPE_CHECKING:
import datasets
import numpy
import pandas
import polars
import ray
import tensorflow
import torch
else:
# heavy/optional third party libs
numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
polars, _POLARS_AVAILABLE = _lazy_import("polars")
torch, _TORCH_AVAILABLE = _lazy_import("torch")
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
ray, _RAY_AVAILABLE = _lazy_import("ray")
@lru_cache(maxsize=None)
def _might_be(cls: type, type_: str) -> bool:
# infer whether the given class "might" be associated with the given
# module (in which case it's reasonable to do a real isinstance check)
try:
return any(f"{type_}." in str(o) for o in cls.mro())
except TypeError:
return False
def _check_for_numpy(obj: Any, *, check_type: bool = True) -> bool:
return _NUMPY_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "numpy"
)
def _check_for_pandas(obj: Any, *, check_type: bool = True) -> bool:
return _PANDAS_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "pandas"
)
def _check_for_polars(obj: Any, *, check_type: bool = True) -> bool:
return _POLARS_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "polars"
)
def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
return _TORCH_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "torch"
)
def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool:
return _HUGGING_FACE_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "datasets"
)
def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool:
return _TENSORFLOW_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "tensorflow"
)
def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
return _RAY_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "ray"
)
__all__ = [
# lazy-load third party libs
"datasets",
"numpy",
"pandas",
"polars",
"ray",
"tensorflow",
"torch",
# lazy utilities
"_check_for_hugging_face",
"_check_for_numpy",
"_check_for_pandas",
"_check_for_polars",
"_check_for_tensorflow",
"_check_for_torch",
"_check_for_ray",
"_LazyModule",
# exported flags/guards
"_NUMPY_AVAILABLE",
"_PANDAS_AVAILABLE",
"_POLARS_AVAILABLE",
"_TORCH_AVAILABLE",
"_HUGGING_FACE_AVAILABLE",
"_TENSORFLOW_AVAILABLE",
"_RAY_AVAILABLE",
]
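# --- Illustrative usage sketch (not part of the module above) ---
# `_lazy_import` hands back a proxy immediately; the real module is imported
# on first attribute access, and a helpful ModuleNotFoundError is raised
# instead if the package is not installed.
if __name__ == "__main__":
    pandas, available = _lazy_import("pandas")
    print(available)  # True if pandas is installed; nothing imported yet
    if available:
        df = pandas.DataFrame({"x": [1, 2]})  # real import happens here
        print(_check_for_pandas(df))  # True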

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from abc import ABC, abstractmethod
from typing import List, Union
@@ -34,7 +25,7 @@ class EmbeddingFunction(BaseModel, ABC):
__slots__ = ("__weakref__",) # pydantic 1.x compatibility
max_retries: int = (
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
7 # Setting 0 disables retries. Maybe this should not be enabled by default.
)
_ndims: int = PrivateAttr()
@@ -46,22 +37,37 @@ class EmbeddingFunction(BaseModel, ABC):
return cls(**kwargs)
@abstractmethod
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
"""
Compute the embeddings for a given user query
Returns
-------
A list of embeddings for each input. The embedding of each input can be None
when the embedding is not valid.
"""
pass
@abstractmethod
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database
def compute_source_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
"""Compute the embeddings for the source column in the database
Returns
-------
A list of embeddings for each input. The embedding of each input can be None
when the embedding is not valid.
"""
pass
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query with retries
def compute_query_embeddings_with_retry(
self, *args, **kwargs
) -> list[Union[np.array, None]]:
"""Compute the embeddings for a given user query with retries
Returns
-------
A list of embeddings for each input. The embedding of each input can be None
when the embedding is not valid.
"""
return retry_with_exponential_backoff(
self.compute_query_embeddings, max_retries=self.max_retries
@@ -70,9 +76,15 @@ class EmbeddingFunction(BaseModel, ABC):
**kwargs,
)
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database with retries
def compute_source_embeddings_with_retry(
self, *args, **kwargs
) -> list[Union[np.array, None]]:
"""Compute the embeddings for the source column in the database with retries.
Returns
-------
A list of embeddings for each input. The embedding of each input can be None
when the embedding is not valid.
"""
return retry_with_exponential_backoff(
self.compute_source_embeddings, max_retries=self.max_retries
@@ -94,8 +106,14 @@ class EmbeddingFunction(BaseModel, ABC):
from ..pydantic import PYDANTIC_VERSION
if PYDANTIC_VERSION.major < 2:
return dict(self)
return self.model_dump()
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
return self.model_dump(
exclude={
field_name
for field_name in self.model_fields
if field_name.startswith("_")
}
)
@abstractmethod
def ndims(self):
@@ -144,18 +162,20 @@ class TextEmbeddingFunction(EmbeddingFunction):
A callable ABC for embedding functions that take text as input
"""
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
def compute_query_embeddings(
self, query: str, *args, **kwargs
) -> list[Union[np.array, None]]:
return self.compute_source_embeddings(query, *args, **kwargs)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs
) -> list[Union[np.array, None]]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
) -> list[Union[np.array, None]]:
"""Generate the embeddings for the given texts"""
pass
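# --- Sketch: a custom TextEmbeddingFunction under the new contract ---
# Embeddings may now be None for inputs that fail to embed; this mirrors the
# MockEmbeddingFunction test added later in this commit. Import paths are
# assumed, and "sketch-embedding" is a hypothetical registry name.
import numpy as np
from lancedb.embeddings import TextEmbeddingFunction, register

@register("sketch-embedding")
class SketchEmbeddingFunction(TextEmbeddingFunction):
    def ndims(self):
        return 4

    def generate_embeddings(self, texts, *args, **kwargs):
        # Return None for blank inputs instead of raising; callers can then
        # pass on_bad_vectors="drop" or "fill" when adding data to a table.
        return [None if not t else np.random.randn(self.ndims()) for t in texts]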

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union
@@ -19,6 +10,7 @@ from .registry import register
if TYPE_CHECKING:
import numpy as np
import ollama
@register("ollama")
@@ -39,17 +31,20 @@ class OllamaEmbeddings(TextEmbeddingFunction):
def ndims(self):
return len(self.generate_embeddings(["foo"])[0])
def _compute_embedding(self, text):
return self._ollama_client.embeddings(
model=self.name,
prompt=text,
options=self.options,
keep_alive=self.keep_alive,
)["embedding"]
def _compute_embedding(self, text) -> Union["np.array", None]:
return (
self._ollama_client.embeddings(
model=self.name,
prompt=text,
options=self.options,
keep_alive=self.keep_alive,
)["embedding"]
or None
)
def generate_embeddings(
self, texts: Union[List[str], "np.ndarray"]
) -> List["np.array"]:
) -> list[Union["np.array", None]]:
"""
Get the embeddings for the given texts
@@ -63,7 +58,7 @@ class OllamaEmbeddings(TextEmbeddingFunction):
return embeddings
@cached_property
def _ollama_client(self):
def _ollama_client(self) -> "ollama.Client":
ollama = attempt_import_or_raise("ollama")
# ToDo explore ollama.AsyncClient
return ollama.Client(host=self.host, **self.ollama_client_kwargs)
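# --- Usage sketch (mirrors the ollama test added in this commit) ---
# Requires a running Ollama server; the `get_registry` import path is assumed.
from lancedb.embeddings import get_registry

model = get_registry().get("ollama").create(max_retries=0)
print(model.ndims())  # probes the model once with ["foo"] to infer the dimension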

View File

@@ -1,17 +1,9 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union
import logging
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
@@ -89,17 +81,26 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
openai = attempt_import_or_raise("openai")
# TODO retry, rate limit, token limit
if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
kwargs = {
"input": texts,
"model": self.name,
}
if self.dim:
kwargs["dimensions"] = self.dim
rs = self._openai_client.embeddings.create(**kwargs)
try:
if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
kwargs = {
"input": texts,
"model": self.name,
}
if self.dim:
kwargs["dimensions"] = self.dim
rs = self._openai_client.embeddings.create(**kwargs)
except openai.BadRequestError:
logging.exception("Bad request: %s", texts)
return [None] * len(texts)
except Exception:
logging.exception("OpenAI embeddings error")
raise
return [v.embedding for v in rs.data]
@cached_property

View File

@@ -16,6 +16,7 @@ import math
import random
import socket
import sys
import threading
import time
import urllib.error
import weakref
@@ -38,6 +39,42 @@ IMAGES = Union[
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
class RateLimiter:
def __init__(self, max_calls: int = 1, period: float = 1.0):
self.period = period
self.max_calls = max(1, min(sys.maxsize, math.floor(max_calls)))
self._last_reset = time.time()
self._num_calls = 0
self._lock = threading.RLock()
def _check_sleep(self) -> float:
current_time = time.time()
elapsed = current_time - self._last_reset
period_remaining = self.period - elapsed
# If the time window has elapsed then reset.
if period_remaining <= 0:
self._num_calls = 0
self._last_reset = current_time
self._num_calls += 1
if self._num_calls > self.max_calls:
return period_remaining
return 0.0
def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with self._lock:
time.sleep(self._check_sleep())
return func(*args, **kwargs)
return wrapper
@deprecated
def with_embeddings(
func: Callable,
@@ -109,21 +146,12 @@ class FunctionWrapper:
def embed_func(c):
return self.func(c.tolist())
if len(self.rate_limiter_kwargs) > 0:
v = int(sys.version_info.minor)
if v >= 11:
print(
"WARNING: rate limit only support up to 3.10, proceeding "
"without rate limiter"
)
else:
import ratelimiter
max_calls = self.rate_limiter_kwargs["max_calls"]
limiter = ratelimiter.RateLimiter(
max_calls, period=self.rate_limiter_kwargs["period"]
)
embed_func = limiter(embed_func)
if self.rate_limiter_kwargs:
limiter = RateLimiter(
max_calls=self.rate_limiter_kwargs["max_calls"],
period=self.rate_limiter_kwargs["period"],
)
embed_func = limiter(embed_func)
batches = self.to_batches(text)
embeds = [emb for c in batches for emb in embed_func(c)]
return embeds
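# --- Sketch: the in-repo RateLimiter that replaces the `ratelimiter` package ---
# At most `max_calls` invocations per `period` seconds; excess calls sleep out
# the remainder of the window before running (and it works on Python 3.11+).
import time

limiter = RateLimiter(max_calls=2, period=1.0)

@limiter
def fake_embed(batch):
    return [len(t) for t in batch]  # stand-in for a real embedding call

start = time.time()
for _ in range(4):
    fake_embed(["hello"])
print(f"4 calls took {time.time() - start:.1f}s")  # roughly 1s: the 3rd call slept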

View File

@@ -78,8 +78,243 @@ class FTS:
For example, it works with `title`, `description`, `content`, etc.
"""
def __init__(self):
self._inner = LanceDbIndex.fts()
def __init__(self, with_position: bool = True):
self._inner = LanceDbIndex.fts(with_position=with_position)
class HnswPq:
"""Describe a HNSW-PQ index configuration.
HNSW-PQ stands for Hierarchical Navigable Small World - Product Quantization.
It is a variant of the HNSW algorithm that uses product quantization to compress
the vectors. To create an HNSW-PQ index, you can specify the following parameters:
Parameters
----------
distance_type: str, default "L2"
The distance metric used to train the index.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
num_partitions, default sqrt(num_rows)
The number of IVF partitions to create.
For HNSW, we recommend a small number of partitions. Setting this to 1 works
well for most tables. For very large tables, training just one HNSW graph
will require too much memory. Each partition becomes its own HNSW graph, so
setting this value higher reduces the peak memory use of training.
num_sub_vectors, default is vector dimension / 16
Number of sub-vectors of PQ.
This value controls how much the vector is compressed during the
quantization step. The more sub vectors there are the less the vector is
compressed. The default is the dimension of the vector divided by 16.
If the dimension is not evenly divisible by 16 we use the dimension
divided by 8.
The above two cases are highly preferred. Having 8 or 16 values per
subvector allows us to use efficient SIMD instructions.
If the dimension is not divisible by 8 then we use 1 subvector. This is not
ideal and will likely result in poor performance.
max_iterations, default 50
Max iterations to train kmeans.
When training an IVF index we use kmeans to calculate the partitions. This
parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases the
parameter is unused because kmeans will converge with fewer iterations. The
parameter is only used in cases where kmeans does not appear to converge. In
those cases it is unlikely that setting this larger will lead to the index
converging anyways.
sample_rate, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF index is trained, we need to calculate partitions. These are
groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we
run kmeans on a random sample of the data. This parameter controls the
size of the sample. The total number of vectors used to train the index
is `sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in
most cases the default should be sufficient.
m, default 20
The number of neighbors to select for each vector in the HNSW graph.
This value controls the tradeoff between search speed and accuracy.
The higher the value the more accurate the search but the slower it will be.
ef_construction, default 300
The number of candidates to evaluate during the construction of the HNSW graph.
This value controls the tradeoff between build speed and accuracy.
The higher the value the more accurate the build but the slower it will be.
150 to 300 is the typical range. 100 is a minimum for good quality search
results. In most cases, there is no benefit to setting this higher than 500.
This value should be set to a value that is not less than `ef` in the
search phase.
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
):
self._inner = LanceDbIndex.hnsw_pq(
distance_type=distance_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
)
class HnswSq:
"""Describe a HNSW-SQ index configuration.
HNSW-SQ stands for Hierarchical Navigable Small World - Scalar Quantization.
It is a variant of the HNSW algorithm that uses scalar quantization to compress
the vectors.
Parameters
----------
distance_type: str, default "L2"
The distance metric used to train the index.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
num_partitions, default sqrt(num_rows)
The number of IVF partitions to create.
For HNSW, we recommend a small number of partitions. Setting this to 1 works
well for most tables. For very large tables, training just one HNSW graph
will require too much memory. Each partition becomes its own HNSW graph, so
setting this value higher reduces the peak memory use of training.
max_iterations, default 50
Max iterations to train kmeans.
When training an IVF index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
the parameter is unused because kmeans will converge with fewer iterations.
The parameter is only used in cases where kmeans does not appear to converge.
In those cases it is unlikely that setting this larger will lead to
the index converging anyways.
sample_rate, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this
we use an algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we
run kmeans on a random sample of the data. This parameter controls the
size of the sample. The total number of vectors used to train the index
is `sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in
most cases the default should be sufficient.
m, default 20
The number of neighbors to select for each vector in the HNSW graph.
This value controls the tradeoff between search speed and accuracy.
The higher the value the more accurate the search but the slower it will be.
ef_construction, default 300
The number of candidates to evaluate during the construction of the HNSW graph.
This value controls the tradeoff between build speed and accuracy.
The higher the value the more accurate the build but the slower it will be.
150 to 300 is the typical range. 100 is a minimum for good quality search
results. In most cases, there is no benefit to setting this higher than 500.
This value should be set to a value that is not less than `ef` in the search
phase.
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
):
self._inner = LanceDbIndex.hnsw_sq(
distance_type=distance_type,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
)
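# --- Usage sketch (async API, as exercised by the index tests in this commit) ---
# "/tmp/hnsw_demo" and "my_table" are hypothetical; the table is assumed to
# already contain a "vector" column.
import asyncio
import lancedb
from lancedb.index import HnswPq

async def main():
    db = await lancedb.connect_async("/tmp/hnsw_demo")
    tbl = await db.open_table("my_table")
    # A single partition means one HNSW graph; see the docstring above.
    await tbl.create_index("vector", config=HnswPq(num_partitions=1, m=20))

asyncio.run(main())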
class IvfPq:

View File

@@ -36,6 +36,7 @@ from . import __version__
from .arrow import AsyncRecordBatchReader
from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import safe_import_pandas
if TYPE_CHECKING:
@@ -575,12 +576,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._reranker = None
self._str_query = str_query
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
Parameters
----------
metric: "L2" or "cosine"
metric: "L2" or "cosine" or "dot"
The distance metric to use. By default "L2" is used.
Returns
@@ -588,7 +589,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._metric = metric
self._metric = metric.lower()
return self
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
@@ -679,6 +680,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
check_reranker_result(result_set)
# convert result_set back to RecordBatchReader
result_set = pa.RecordBatchReader.from_batches(
result_set.schema, result_set.to_batches()
@@ -811,6 +813,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
results = results.read_all()
if self._reranker is not None:
results = self._reranker.rerank_fts(self._query, results)
check_reranker_result(results)
return results
def tantivy_to_arrow(self) -> pa.Table:
@@ -953,8 +956,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(
self,
table: "Table",
query: str = None,
vector_column: str = None,
query: Optional[str] = None,
vector_column: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
@@ -1060,10 +1063,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._fts_query._query, vector_results, fts_results
)
if not isinstance(results, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
)
check_reranker_result(results)
# apply limit after reranking
results = results.slice(length=self._limit)
@@ -1112,8 +1112,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def rerank(
self,
normalize="score",
reranker: Reranker = RRFReranker(),
normalize: str = "score",
) -> LanceHybridQueryBuilder:
"""
Rerank the hybrid search results using the specified reranker. The reranker
@@ -1121,12 +1121,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
Parameters
----------
reranker: Reranker, default RRFReranker()
The reranker to use. Must be an instance of Reranker class.
normalize: str, default "score"
The method to normalize the scores. Can be "rank" or "score". If "rank",
the scores are converted to ranks and then normalized. If "score", the
scores are normalized directly.
reranker: Reranker, default RRFReranker()
The reranker to use. Must be an instance of Reranker class.
Returns
-------
LanceHybridQueryBuilder
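# --- Usage sketch (mirrors the reranker tests in this commit) ---
# `table` and `query_vector` are assumed to exist, and the RRFReranker import
# path is assumed. Note the new keyword order: `reranker` before `normalize`.
from lancedb.rerankers import RRFReranker

results = (
    table.search(query_type="hybrid", vector_column_name="vector")
    .vector(query_vector)
    .text("puppy")
    .limit(30)
    .rerank(reranker=RRFReranker())
    .to_arrow()
)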

View File

@@ -79,6 +79,13 @@ class RestfulLanceDBClient:
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
return False # Do not suppress exceptions
def close(self):
self.session.close()
self.closed = True
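# --- Usage sketch: the client is now a context manager ---
# `close()` runs automatically on exit and exceptions are not suppressed.
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
    ...  # issue queries here
assert client.closed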

View File

@@ -26,7 +26,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data
from ..util import inf_vector_column_query, value_to_sql
from ..util import value_to_sql, infer_vector_column_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection
@@ -126,6 +126,7 @@ class RemoteTable(Table):
column: str,
*,
replace: bool = False,
with_position: bool = True,
):
data = {
"column": column,
@@ -265,7 +266,7 @@ class RemoteTable(Table):
def search(
self,
query: Union[VEC, str],
query: Union[VEC, str] = None,
vector_column_name: Optional[str] = None,
query_type="auto",
fts_columns: Optional[Union[str, List[str]]] = None,
@@ -304,8 +305,6 @@ class RemoteTable(Table):
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, optional
The name of the vector column to search.
@@ -328,11 +327,15 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None and query_type != "fts":
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
raise e
# empty query builder is not supported in saas, raise error
if query is None and query_type != "hybrid":
raise ValueError("Empty query is not supported")
vector_column_name = infer_vector_column_name(
schema=self.schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
return LanceQueryBuilder.create(
self,

View File

@@ -32,6 +32,9 @@ class AnswerdotaiRerankers(Reranker):
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
**kwargs
Additional keyword arguments to pass to the model. For example, 'device'.
See AnswerDotAI/rerankers for more information.
"""
def __init__(
@@ -40,13 +43,14 @@ class AnswerdotaiRerankers(Reranker):
model_name: str = "answerdotai/answerai-colbert-small-v1",
column: str = "text",
return_score="relevance",
**kwargs,
):
super().__init__(return_score)
self.column = column
rerankers = attempt_import_or_raise(
"rerankers"
) # import here for faster ops later
self.reranker = rerankers.Reranker(model_name, model_type)
self.reranker = rerankers.Reranker(model_name, model_type, **kwargs)
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()

View File

@@ -105,7 +105,7 @@ class Reranker(ABC):
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
) -> pa.Table:
"""
Rerank function receives the individual results from the vector and FTS search
results. You can choose to use any of the results to generate the final results,

View File

@@ -26,6 +26,9 @@ class ColbertReranker(AnswerdotaiRerankers):
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
**kwargs
Additional keyword arguments to pass to the model, for example, 'device'.
See AnswerDotAI/rerankers for more information.
"""
def __init__(
@@ -33,10 +36,12 @@ class ColbertReranker(AnswerdotaiRerankers):
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
**kwargs,
):
super().__init__(
model_type="colbert",
model_name=model_name,
column=column,
return_score=return_score,
**kwargs,
)

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from numpy import NaN
import pyarrow as pa
from .base import Reranker
@@ -58,14 +59,42 @@ class LinearCombinationReranker(Reranker):
def merge_results(
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
):
# If both are empty then just return an empty table
if len(vector_results) == 0 and len(fts_results) == 0:
return vector_results
# If one is empty then return the other
# If one is empty then return the other, adding a _relevance_score
# column equal to the existing vector or fts score
if len(vector_results) == 0:
return fts_results
results = fts_results.append_column(
"_relevance_score",
pa.array(fts_results["_score"], type=pa.float32()),
)
if self.score == "relevance":
results = self._keep_relevance_score(results)
elif self.score == "all":
results = results.append_column(
"_distance",
pa.array([NaN] * len(fts_results), type=pa.float32()),
)
return results
if len(fts_results) == 0:
return vector_results
# invert the distance to relevance score
results = vector_results.append_column(
"_relevance_score",
pa.array(
[
self._invert_score(distance)
for distance in vector_results["_distance"].to_pylist()
],
type=pa.float32(),
),
)
if self.score == "relevance":
results = self._keep_relevance_score(results)
elif self.score == "all":
results = results.append_column(
"_score",
pa.array([NaN] * len(vector_results), type=pa.float32()),
)
return results
# sort both input tables on _rowid
combined_list = []

View File

@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
import pyarrow as pa
def check_reranker_result(result):
if not isinstance(result, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(result)}"
)
# Enforce that `_relevance_score` column is present in the result of every
# rerank_hybrid method
if "_relevance_score" not in result.column_names:
raise ValueError(
"rerank_hybrid must return a pyarrow.Table with a column"
"named `_relevance_score`"
)
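# --- Quick illustration of what the helper enforces ---
import pyarrow as pa

ok = pa.table(
    {"text": ["a", "b"], "_relevance_score": pa.array([0.9, 0.1], pa.float32())}
)
check_reranker_result(ok)  # passes: a pa.Table with `_relevance_score`

try:
    check_reranker_result(ok.select(["text"]))  # `_relevance_score` missing
except ValueError:
    pass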

View File

@@ -19,10 +19,12 @@ from typing import (
Optional,
Tuple,
Union,
overload,
)
from urllib.parse import urlparse
import lance
from .dependencies import _check_for_pandas
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
@@ -35,11 +37,20 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
from .query import (
AsyncQuery,
AsyncVectorQuery,
LanceEmptyQueryBuilder,
LanceFtsQueryBuilder,
LanceHybridQueryBuilder,
LanceQueryBuilder,
LanceVectorQueryBuilder,
Query,
)
from .util import (
fs_from_uri,
get_uri_scheme,
inf_vector_column_query,
infer_vector_column_name,
join_uri,
safe_import_pandas,
safe_import_polars,
@@ -53,38 +64,25 @@ if TYPE_CHECKING:
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
pd = safe_import_pandas()
pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _sanitize_data(
data,
schema: Optional[pa.Schema],
metadata: Optional[dict],
on_bad_vectors: str,
fill_value: Any,
):
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if _check_for_hugging_face(data):
# Huggingface datasets
from lance.dependencies import datasets
if isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
data = _to_record_batch_generator(
_to_batches_with_split(data),
schema,
metadata,
on_bad_vectors,
fill_value,
)
elif isinstance(data, datasets.Dataset):
if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
data = _to_record_batch_generator(
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
)
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
elif isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -95,40 +93,68 @@ def _sanitize_data(
if schema is None:
schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data]
data = pa.Table.from_pylist(data, schema=schema)
return pa.Table.from_pylist(data, schema=schema)
elif isinstance(data[0], pa.RecordBatch):
return pa.Table.from_batches(data, schema=schema)
else:
data = pa.Table.from_pylist(data)
return pa.Table.from_pylist(data)
elif isinstance(data, dict):
data = vec_to_table(data)
elif pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False)
return vec_to_table(data)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
# Do not add schema here, since schema may contains the vector column
table = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
elif pl is not None and isinstance(data, pl.DataFrame):
data = data.to_arrow()
if isinstance(data, pa.Table):
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if schema is None:
schema = data.schema
return table.replace_schema_metadata(meta)
elif isinstance(data, pa.Table):
return data
elif isinstance(data, pa.RecordBatch):
return pa.Table.from_batches([data])
elif isinstance(data, LanceDataset):
return data.scanner().to_table()
elif isinstance(data, pa.dataset.Dataset):
return data.to_table()
elif isinstance(data, pa.dataset.Scanner):
return data.to_table()
elif isinstance(data, pa.RecordBatchReader):
return data.read_all()
elif (
type(data).__module__.startswith("polars")
and data.__class__.__name__ == "DataFrame"
):
return data.to_arrow()
elif isinstance(data, Iterable):
data = _to_record_batch_generator(
data, schema, metadata, on_bad_vectors, fill_value
)
if schema is None:
data, schema = _generator_to_data_and_schema(data)
if schema is None:
raise ValueError("Cannot infer schema from generator data")
return _process_iterator(data, schema)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
raise TypeError(
f"Unknown data type {type(data)}. "
"Please check "
"https://lancedb.github.io/lancedb/python/python/ "
"to see supported types."
)
def _sanitize_data(
data: Any,
schema: Optional[pa.Schema] = None,
metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
data = _coerce_to_table(data, schema)
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
# TODO improve the logics in _sanitize_schema
data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
if schema is None:
schema = data.schema
_validate_schema(schema)
return data, schema
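# --- Sketch: `_sanitize_data` now funnels every input through `_coerce_to_table`,
# so pandas/polars frames, HF datasets, record batches, readers, and iterables
# all normalize to a (pa.Table, schema) pair in one place. pandas assumed here.
import pandas as pd

tbl, schema = _sanitize_data(pd.DataFrame({"id": [1, 2]}))
assert tbl.num_rows == 2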
@@ -149,6 +175,9 @@ def sanitize_create_table(
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
else:
if schema is not None:
data = pa.Table.from_pylist([], schema)
if schema is None:
if data is None:
raise ValueError("Either data or schema must be provided")
@@ -468,6 +497,7 @@ class Table(ABC):
ordering_field_names: Union[str, List[str]] = None,
*,
replace: bool = False,
with_position: bool = True,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
tokenizer_name: str = "default",
use_tantivy: bool = True,
@@ -500,6 +530,12 @@ class Table(ABC):
use_tantivy: bool, default True
If True, use the legacy full-text search implementation based on tantivy.
If False, use the new full-text search implementation based on lance-index.
with_position: bool, default True
Only available with use_tantivy=False
If False, do not store the positions of the terms in the text.
This can reduce the size of the index and improve indexing speed.
But it will raise an exception for phrase queries.
"""
raise NotImplementedError
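# --- Usage sketch for the new `with_position` flag (lance-index path only) ---
# "/tmp/fts_demo" is a hypothetical path.
import lancedb

db = lancedb.connect("/tmp/fts_demo")
tbl = db.create_table("docs", data=[{"text": "a puppy runs"}], mode="overwrite")
tbl.create_fts_index("text", use_tantivy=False, with_position=False)
print(tbl.search("puppy").limit(1).to_list())  # term queries still work
# tbl.search('"puppy runs"')  # a phrase query would now raise an exception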
@@ -600,7 +636,7 @@ class Table(ABC):
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
@@ -1305,6 +1341,7 @@ class LanceTable(Table):
ordering_field_names: Union[str, List[str]] = None,
*,
replace: bool = False,
with_position: bool = True,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
tokenizer_name: str = "default",
use_tantivy: bool = True,
@@ -1318,7 +1355,10 @@ class LanceTable(Table):
if exist:
fs.delete_dir(path)
self._dataset_mut.create_scalar_index(
field_names, index_type="INVERTED", replace=replace
field_names,
index_type="INVERTED",
replace=replace,
with_position=with_position,
)
return
@@ -1476,11 +1516,51 @@ class LanceTable(Table):
self.schema.metadata
)
@overload
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
query_type: Literal["vector"] = "vector",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceVectorQueryBuilder: ...
@overload
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["fts"] = "fts",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceFtsQueryBuilder: ...
@overload
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["hybrid"] = "hybrid",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceHybridQueryBuilder: ...
@overload
def search(
self,
query: None = None,
vector_column_name: Optional[str] = None,
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceEmptyQueryBuilder: ...
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> LanceQueryBuilder:
@@ -1550,11 +1630,12 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None and query_type != "fts":
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
raise e
vector_column_name = infer_vector_column_name(
schema=self.schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
return LanceQueryBuilder.create(
self,
@@ -1918,22 +1999,26 @@ def _sanitize_vector_column(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
vec_arr = data[vector_column_name].combine_chunks()
vec_arr = ensure_fixed_size_list(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif not pa.types.is_fixed_size_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
vec_arr = ensure_fixed_size_list(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
# kernel over f16 types.
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
if np.isnan(values_np).any():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
if pa.types.is_float16(vec_arr.values.type):
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
# kernel over f16 types yet.
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
if np.isnan(values_np).any():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
else:
if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
return data
@@ -1977,8 +2062,15 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
return data
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
def _sanitize_nans(
data,
fill_value,
on_bad_vectors,
vec_arr: pa.FixedSizeListArray,
vector_column_name: str,
):
"""Sanitize NaNs in vectors"""
assert pa.types.is_fixed_size_list(vec_arr.type)
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
@@ -1998,12 +2090,63 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif on_bad_vectors == "drop":
is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False)
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
data = data.filter(is_full)
# "drop" is slow here: round-trip through numpy to find rows whose
# fixed-size-list vector contains any NaN, then filter those rows out.
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
has_nan = np.any(np_arr, axis=1)
data = data.filter(~has_nan)
return data
def _validate_schema(schema: pa.Schema):
"""
Make sure the metadata is valid utf8
"""
if schema.metadata is not None:
_validate_metadata(schema.metadata)
def _validate_metadata(metadata: dict):
"""
Make sure the metadata values are valid utf8 (can be nested)
Raises ValueError if not valid utf8
"""
for k, v in metadata.items():
if isinstance(v, bytes):
try:
v.decode("utf8")
except UnicodeDecodeError:
raise ValueError(
f"Metadata key {k} is not valid utf8. "
"Consider base64 encode for generic binary metadata."
)
elif isinstance(v, dict):
_validate_metadata(v)
def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
batches = []
for batch in data:
batch_table = _coerce_to_table(batch, schema)
if schema is not None:
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid:
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
f"Got:\n{batch_table.schema}"
)
batches.append(batch_table)
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
class AsyncTable:
"""
An AsyncTable is a collection of Records in a LanceDB Database.
@@ -2539,3 +2682,34 @@ class AsyncTable:
List all indices that have been created with Self::create_index
"""
return await self._inner.list_indices()
async def uses_v2_manifest_paths(self) -> bool:
"""
Check if the table is using the new v2 manifest paths.
Returns
-------
bool
True if the table is using the new v2 manifest paths, False otherwise.
"""
return await self._inner.uses_v2_manifest_paths()
async def migrate_manifest_paths_v2(self):
"""
Migrate the manifest paths to the new format.
This will update the manifest to use the new v2 format for paths.
This function is idempotent, and can be run multiple times without
changing the state of the object store.
!!! danger
This should not be run while other concurrent operations are happening.
And it should also run until completion before resuming other operations.
You can use
[AsyncTable.uses_v2_manifest_paths][lancedb.table.AsyncTable.uses_v2_manifest_paths]
to check if the table is already using the new path style.
"""
await self._inner.migrate_manifest_paths_v2()
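# --- Usage sketch (mirrors the async manifest-path test in this commit) ---
# "/tmp/demo_db" and "my_table" are hypothetical.
import asyncio
import lancedb

async def migrate(uri: str, name: str):
    db = await lancedb.connect_async(uri)
    tbl = await db.open_table(name)
    if not await tbl.uses_v2_manifest_paths():
        # Idempotent, but must not run concurrently with other writers.
        await tbl.migrate_manifest_paths_v2()

asyncio.run(migrate("/tmp/demo_db", "my_table"))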

View File

@@ -9,7 +9,7 @@ import pathlib
import warnings
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple, Union
from typing import Tuple, Union, Optional, Any
from urllib.parse import urlparse
import numpy as np
@@ -212,6 +212,23 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
return vector_col_name
def infer_vector_column_name(
schema: pa.Schema,
query_type: str,
query: Optional[Any], # inferred later in query builder
vector_column_name: Optional[str],
):
if (vector_column_name is None and query is not None and query_type != "fts") or (
vector_column_name is None and query_type == "hybrid"
):
vector_column_name = inf_vector_column_query(schema)
return vector_column_name
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")
@@ -219,6 +236,7 @@ def value_to_sql(value):
@value_to_sql.register(str)
def _(value: str):
value = value.replace("'", "''")
return f"'{value}'"

View File

@@ -13,6 +13,7 @@
import re
from datetime import timedelta
import os
import lancedb
import numpy as np
@@ -413,6 +414,40 @@ async def test_create_exist_ok_async(tmp_path):
# await db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
async def test_create_table_v2_manifest_paths_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
# Create table in v2 mode with v2 manifest paths enabled
tbl = await db.create_table(
"test_v2_manifest_paths",
data=[{"id": 0}],
use_legacy_format=False,
enable_v2_manifest_paths=True,
)
assert await tbl.uses_v2_manifest_paths()
manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"
for manifest in os.listdir(manifests_dir):
assert re.match(r"\d{20}\.manifest", manifest)
# Start a table in V1 mode then migrate
tbl = await db.create_table(
"test_v2_migration",
data=[{"id": 0}],
use_legacy_format=False,
enable_v2_manifest_paths=False,
)
assert not await tbl.uses_v2_manifest_paths()
manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"
for manifest in os.listdir(manifests_dir):
assert re.match(r"\d\.manifest", manifest)
await tbl.migrate_manifest_paths_v2()
assert await tbl.uses_v2_manifest_paths()
for manifest in os.listdir(manifests_dir):
assert re.match(r"\d{20}\.manifest", manifest)
def test_open_table_sync(tmp_path):
db = lancedb.connect(tmp_path)
db.create_table("test", data=[{"id": 0}])
@@ -559,7 +594,9 @@ async def test_create_in_v2_mode(tmp_path):
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table("test", data=make_data(), schema=schema)
tbl = await db.create_table(
"test", data=make_data(), schema=schema, data_storage_version="legacy"
)
async def is_in_v2_mode(tbl):
batches = await tbl.query().to_batches(max_batch_length=1024 * 10)
@@ -591,7 +628,9 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table uses v1 mode by default
tbl = await db.create_table("test_empty_v2_default", data=None, schema=schema)
tbl = await db.create_table(
"test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
)
await tbl.add(make_table())
assert not await is_in_v2_mode(tbl)

View File

@@ -10,7 +10,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import List, Union
import lance
@@ -35,9 +34,6 @@ def mock_embed_func(input_data):
def test_with_embeddings():
for wrap_api in [True, False]:
if wrap_api and sys.version_info.minor >= 11:
# ratelimiter package doesn't work on 3.11
continue
data = pa.Table.from_arrays(
[
pa.array(["foo", "bar"]),
@@ -90,6 +86,47 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected)
def test_embedding_with_bad_results(tmp_path):
@register("mock-embedding")
class MockEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> list[Union[np.array, None]]:
return [
None if i % 2 == 0 else np.random.randn(self.ndims())
for i in range(len(texts))
]
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("mock-embedding").create()
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
table = db.create_table("test", schema=Schema, mode="overwrite")
table.add(
[{"text": "hello world"}, {"text": "bar"}],
on_bad_vectors="drop",
)
df = table.to_pandas()
assert len(table) == 1
assert df.iloc[0]["text"] == "bar"
# table = db.create_table("test2", schema=Schema, mode="overwrite")
# table.add(
# [{"text": "hello world"}, {"text": "bar"}],
# )
# assert len(table) == 2
# tbl = table.to_arrow()
# assert tbl["vector"].null_count == 1
@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):
@@ -146,3 +183,45 @@ def test_add_optional_vector(tmp_path):
expected = LanceSchema(id="id", text="text")
tbl.add([expected])
assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()
@pytest.mark.parametrize(
"embedding_type",
[
"openai",
"sentence-transformers",
"huggingface",
"ollama",
"cohere",
"instructor",
],
)
def test_embedding_function_safe_model_dump(embedding_type):
registry = get_registry()
# Note: Some embedding types might require specific parameters
try:
model = registry.get(embedding_type).create()
except Exception as e:
pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")
dumped_model = model.safe_model_dump()
assert all(
not k.startswith("_") for k in dumped_model.keys()
), f"{embedding_type}: Dumped model contains keys starting with underscore"
assert (
"max_retries" in dumped_model
), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
assert isinstance(
dumped_model, dict
), f"{embedding_type}: Dumped model is not a dictionary"
for key in model.__dict__:
if key.startswith("_"):
assert key not in dumped_model, (
f"{embedding_type}: Private attribute '{key}' "
f"is present in dumped model"
)

View File

@@ -442,3 +442,42 @@ def test_watsonx_embedding(tmp_path):
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("ollama") is None, reason="Ollama not installed"
)
def test_ollama_embedding(tmp_path):
model = get_registry().get("ollama").create(max_retries=0)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
result = tbl.search("hello").limit(1).to_pandas()
assert result["text"][0] == "hello world"
# Test safe_model_dump
dumped_model = model.safe_model_dump()
assert isinstance(dumped_model, dict)
assert "name" in dumped_model
assert "max_retries" in dumped_model
assert dumped_model["max_retries"] == 0
assert all(not k.startswith("_") for k in dumped_model.keys())
# Test serialization of the dumped model
import json
try:
json.dumps(dumped_model)
except TypeError:
pytest.fail("Failed to JSON serialize the dumped model")

View File

@@ -140,8 +140,11 @@ def test_create_index_with_stemming(tmp_path, table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_create_inverted_index(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
@pytest.mark.parametrize("with_position", [True, False])
def test_create_inverted_index(table, use_tantivy, with_position):
if use_tantivy and not with_position:
pytest.skip("we don't support building a tantivy index without position")
table.create_fts_index("text", use_tantivy=use_tantivy, with_position=with_position)
def test_populate_index(tmp_path, table):
@@ -166,6 +169,40 @@ def test_search_fts(table, use_tantivy):
assert len(results) == 5
def test_search_fts_phrase_query(table):
table.create_fts_index("text", use_tantivy=False, with_position=False)
with pytest.raises(Exception):
    table.search('"puppy runs"').limit(100).to_list()
table.create_fts_index("text", use_tantivy=False, replace=True)
results = table.search("puppy").limit(100).to_list()
phrase_results = table.search('"puppy runs"').limit(100).to_list()
assert len(results) > len(phrase_results)
assert len(phrase_results) > 0
@pytest.mark.asyncio
async def test_search_fts_phrase_query_async(async_table):
async_table = await async_table
await async_table.create_index("text", config=FTS(with_position=False))
with pytest.raises(Exception):
    await async_table.query().nearest_to_text("puppy runs").limit(100).to_list()
await async_table.create_index("text", config=FTS())
results = await async_table.query().nearest_to_text("puppy").limit(100).to_list()
phrase_results = (
await async_table.query().nearest_to_text('"puppy runs"').limit(100).to_list()
)
assert len(results) > len(phrase_results)
assert len(phrase_results) > 0
def test_search_fts_specify_column(table):
table.create_fts_index("text", use_tantivy=False)
table.create_fts_index("text2", use_tantivy=False)

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture
@@ -91,3 +91,17 @@ async def test_create_vector_index(some_table: AsyncTable):
assert len(indices) == 1
assert indices[0].index_type == "IvfPq"
assert indices[0].columns == ["vector"]
@pytest.mark.asyncio
async def test_create_hnswpq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswPq(num_partitions=10))
indices = await some_table.list_indices()
assert len(indices) == 1
@pytest.mark.asyncio
async def test_create_hnswsq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
indices = await some_table.list_indices()
assert len(indices) == 1

View File

@@ -74,21 +74,23 @@ async def test_e2e_with_mock_server():
await mock_server.start()
try:
client = RestfulLanceDBClient("lancedb+http://localhost:8111")
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns
assert "vector" in df.columns
assert "id" in df.columns
assert client.closed
finally:
# make sure we don't leak resources
await mock_server.stop()

View File

@@ -1,15 +1,7 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from unittest.mock import MagicMock
import lancedb
import pyarrow as pa
@@ -39,3 +31,53 @@ def test_remote_db():
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas()
def test_create_empty_table():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", schema=schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
json_schema = {
"fields": [
{
"name": "vector",
"nullable": True,
"type": {
"type": "fixed_size_list",
"fields": [
{"name": "item", "nullable": True, "type": {"type": "float"}}
],
"length": 2,
},
},
]
}
client.post.return_value = {"schema": json_schema}
assert table.schema == schema
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
client.post.return_value = 0
assert table.count_rows(None) == 0
def test_create_table_with_recordbatches():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", [batch], schema=batch.schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"

View File

@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
)
assert len(result) == 30
err = (
ascending_relevance_err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# Vector search setting
result = (
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
result_explicit = (
table.search(query_vector, vector_column_name="vector")
.rerank(reranker=reranker, query_string=query)
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# empty FTS results
query = "abcxyz" * 100
result = (
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
# should return _relevance_score column
assert "_relevance_score" in result.column_names
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# Multi-vector search setting
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
result_deduped = reranker.rerank_multivector(
[rs1, rs2, rs1], query, deduplicate=True
)
assert len(result_deduped) < 20
assert len(result_deduped) <= 20
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
assert len(result) == 20 and result == result_arrow
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
.vector(query_vector)
.text(query)
.limit(30)
.rerank(normalize="score")
.rerank(reranker, normalize="score")
.to_arrow()
)
assert len(result) == 30
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
table.search(query, query_type="hybrid", vector_column_name="vector").text(
query
).to_arrow()
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
ascending_relevance_err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
# Test with empty FTS results
query = "abcxyz" * 100
result = (
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
# should return _relevance_score column
assert "_relevance_score" in result.column_names
assert np.all(
np.diff(result.column("_relevance_score").to_numpy()) <= 0
), ascending_relevance_err
@pytest.mark.parametrize("use_tantivy", [True, False])

View File

@@ -64,6 +64,55 @@ def test_basic(db):
assert table.to_lance().to_table() == ds.to_table()
def test_input_data_type(db, tmp_path):
schema = pa.schema(
[
pa.field("id", pa.int64()),
pa.field("name", pa.string()),
pa.field("age", pa.int32()),
]
)
data = {
"id": [1, 2, 3, 4, 5],
"name": ["Alice", "Bob", "Charlie", "David", "Eve"],
"age": [25, 30, 35, 40, 45],
}
record_batch = pa.RecordBatch.from_pydict(data, schema=schema)
pa_reader = pa.RecordBatchReader.from_batches(record_batch.schema, [record_batch])
pa_table = pa.Table.from_batches([record_batch])
def create_dataset(tmp_path):
path = os.path.join(tmp_path, "test_source_dataset")
pa.dataset.write_dataset(pa_table, path, format="parquet")
return pa.dataset.dataset(path, format="parquet")
pa_dataset = create_dataset(tmp_path)
pa_scanner = pa_dataset.scanner()
input_types = [
("RecordBatchReader", pa_reader),
("RecordBatch", record_batch),
("Table", pa_table),
("Dataset", pa_dataset),
("Scanner", pa_scanner),
]
for input_type, input_data in input_types:
table_name = f"test_{input_type.lower()}"
ds = LanceTable.create(db, table_name, data=input_data).to_lance()
assert ds.schema == schema
assert ds.count_rows() == 5
assert ds.schema.field("id").type == pa.int64()
assert ds.schema.field("name").type == pa.string()
assert ds.schema.field("age").type == pa.int32()
result_table = ds.to_table()
assert result_table.column("id").to_pylist() == data["id"]
assert result_table.column("name").to_pylist() == data["name"]
assert result_table.column("age").to_pylist() == data["age"]
@pytest.mark.asyncio
async def test_close(db_async: AsyncConnection):
table = await db_async.create_table("some_table", data=[{"id": 0}])
@@ -274,7 +323,6 @@ def test_polars(db):
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
@@ -925,7 +973,36 @@ def test_hybrid_search(db, tmp_path):
.where("text='Arrrrggghhhhhhh'")
.to_list()
)
len(result) == 1
assert len(result) == 1
# with explicit query type
vector_query = list(range(emb.ndims()))
result = (
table.search(query_type="hybrid")
.vector(vector_query)
.text("Arrrrggghhhhhhh")
.to_arrow()
)
assert len(result) > 0
assert "_relevance_score" in result.column_names
# with vector_column_name
result = (
table.search(query_type="hybrid", vector_column_name="vector")
.vector(vector_query)
.text("Arrrrggghhhhhhh")
.to_arrow()
)
assert len(result) > 0
assert "_relevance_score" in result.column_names
# fail if only text or vector is provided
with pytest.raises(ValueError):
table.search(query_type="hybrid").to_list()
with pytest.raises(ValueError):
table.search(query_type="hybrid").vector(vector_query).to_list()
with pytest.raises(ValueError):
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
@pytest.mark.parametrize(

View File

@@ -15,7 +15,8 @@ import os
import pathlib
import pytest
from lancedb.util import get_uri_scheme, join_uri
import lancedb
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
def test_normalize_uri():
@@ -84,3 +85,29 @@ def test_local_join_uri_windows():
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"
def test_value_to_sql_string(tmp_path):
# Make sure we can convert Python string literals to SQL strings, even if
# they contain characters meaningful in SQL, such as ' and \.
values = ["anthony's", 'a "test" string', "anthony's \"favorite color\" wasn't red"]
expected_values = [
"'anthony''s'",
"'a \"test\" string'",
"'anthony''s \"favorite color\" wasn''t red'",
]
for value, expected in zip(values, expected_values):
assert value_to_sql(value) == expected
# Also test we can roundtrip those strings through update.
# This validates the query parser understands the strings we
# are creating.
db = lancedb.connect(tmp_path)
table = db.create_table(
"test",
[{"search": value, "replace": "something"} for value in values],
)
for value in values:
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
assert table.to_pandas().query("search == @value")["replace"].item() == value

View File

@@ -81,6 +81,7 @@ impl Connection {
data: Bound<'_, PyAny>,
storage_options: Option<HashMap<String, String>>,
data_storage_version: Option<String>,
enable_v2_manifest_paths: Option<bool>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone();
@@ -93,6 +94,10 @@ impl Connection {
builder = builder.storage_options(storage_options);
}
if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
}
if let Some(data_storage_version) = data_storage_version.as_ref() {
builder = builder.data_storage_version(
LanceFileVersion::from_str(data_storage_version)
@@ -113,6 +118,7 @@ impl Connection {
schema: Bound<'_, PyAny>,
storage_options: Option<HashMap<String, String>>,
data_storage_version: Option<String>,
enable_v2_manifest_paths: Option<bool>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.get_inner()?.clone();
@@ -126,6 +132,10 @@ impl Connection {
builder = builder.storage_options(storage_options);
}
if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
}
if let Some(data_storage_version) = data_storage_version.as_ref() {
builder = builder.data_storage_version(
LanceFileVersion::from_str(data_storage_version)

View File

@@ -14,8 +14,13 @@
use std::sync::Mutex;
use lancedb::index::scalar::FtsIndexBuilder;
use lancedb::{
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
index::{
scalar::BTreeIndexBuilder,
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
Index as LanceDbIndex,
},
DistanceType,
};
use pyo3::{
@@ -23,6 +28,8 @@ use pyo3::{
pyclass, pymethods, PyResult,
};
use crate::util::parse_distance_type;
#[pyclass]
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
@@ -100,9 +107,85 @@ impl Index {
}
#[staticmethod]
pub fn fts() -> PyResult<Self> {
pub fn fts(with_position: Option<bool>) -> Self {
let mut opts = FtsIndexBuilder::default();
if let Some(with_position) = with_position {
opts = opts.with_position(with_position);
}
Self {
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
}
}
#[staticmethod]
pub fn hnsw_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> PyResult<Self> {
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::FTS(Default::default()))),
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
})
}
#[staticmethod]
pub fn hnsw_sq(
distance_type: Option<String>,
num_partitions: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> PyResult<Self> {
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
}
if let Some(max_iterations) = max_iterations {
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
})
}
}

View File

@@ -303,4 +303,28 @@ impl Table {
})
})
}
pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner
.as_native()
.ok_or_else(|| PyValueError::new_err("This cannot be run on a remote table"))?
.uses_v2_manifest_paths()
.await
.infer_error()
})
}
pub fn migrate_manifest_paths_v2(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner
.as_native()
.ok_or_else(|| PyValueError::new_err("This cannot be run on a remote table"))?
.migrate_manifest_paths_v2()
.await
.infer_error()
})
}
}
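A hedged usage sketch from async Python, assuming the bindings above are surfaced under the same method names on native tables (remote tables raise, per the error above):

```python
# Hypothetical helper: check the current manifest path scheme and migrate
# if needed. Method names come from the bindings above; the exact Python
# surface may differ.
async def ensure_v2_manifest_paths(table):
    if not await table.uses_v2_manifest_paths():
        await table.migrate_manifest_paths_v2()
```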

View File

@@ -1,11 +1,11 @@
# Release process
There are five total packages we release. Three are the `lancedb` packages
for Python, Rust, and Node.js. The other two are the legacy `vectordb`
packages for Rust and node.js.
There are five total packages we release. Four are the `lancedb` packages
for Python, Rust, Java, and Node.js. The other one is the legacy `vectordb`
package for Node.js.
The Python package is versioned and released separately from the Rust and Node.js
ones. For Rust and Node.js, the release process is shared between `lancedb` and
The Python package is versioned and released separately from the Rust, Java, and Node.js
ones. For Node.js, the release process is shared between `lancedb` and
`vectordb` for now.
## Preview releases
@@ -24,20 +24,23 @@ indexes.
The release process uses a handful of GitHub actions to automate the process.
```text
┌─────────────────────┐
│Create Release Commit│
└─┬───────────────────┘
│ ┌────────────┐ ┌──►Python GH Release
├──►(tag) python-vX.Y.Z ───►│PyPI Publish├─┤
│ └────────────┘ └──►Python Wheels
│ ┌───────────┐
└──►(tag) vX.Y.Z ───┬──────►│NPM Publish├──┬──►Rust/Node GH Release
│ └───────────┘ │
│ └──►NPM Packages
│ ┌─────────────┐
──────►│Cargo Publish├───►Cargo Release
└─────────────┘
┌─────────────────────┐
│Create Release Commit│
└─┬───────────────────┘
│ ┌────────────┐ ┌──►Python GH Release
├──►(tag) python-vX.Y.Z ───►│PyPI Publish├─┤
│ └────────────┘ └──►Python Wheels
│ ┌───────────┐
└──►(tag) vX.Y.Z ───┬──────►│NPM Publish├──┬──►Rust/Node GH Release
│ └───────────┘ │
│ └──►NPM Packages
│ ┌─────────────┐
──────►│Cargo Publish├───►Cargo Release
└─────────────┘
│ ┌─────────────┐
└──────►│Maven Publish├───►Java Maven Repo Release
└─────────────┘
```
To start a release, trigger a `Create Release Commit` action from
@@ -56,10 +59,10 @@ To start a release, trigger a `Create Release Commit` action from
## Breaking changes
We try to avoid breaking changes, but sometimes they are necessary. When there
are breaking changes, we will increment the minor version. (This is valid
are breaking changes, we will increment the minor version. (This is valid
semantic versioning because we are still in `0.x` versions.)
When a PR makes a breaking change, the PR author should mark the PR using the
When a PR makes a breaking change, the PR author should mark the PR using the
conventional commit markers: either exclamation mark after the type
(such as `feat!: change signature of func`) or have `BREAKING CHANGE` in the
body of the PR. A CI job will add a `breaking-change` label to the PR, which is

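For illustration, a hypothetical PR title and body following that convention:

```text
feat!: change signature of func

BREAKING CHANGE: `func` now takes an options object instead of
positional arguments.
```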
View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.10.0-beta.1"
version = "0.11.0-beta.1"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -46,6 +46,10 @@ impl JsQuery {
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
.value(&mut cx);
let fast_search = query_obj
.get_opt::<JsBoolean, _, _>(&mut cx, "_fastSearch")?
.map(|val| val.value(&mut cx));
let is_electron = cx
.argument::<JsBoolean>(1)
.or_throw(&mut cx)?
@@ -70,6 +74,9 @@ impl JsQuery {
if let Some(limit) = limit {
builder = builder.limit(limit as usize);
};
if let Some(true) = fast_search {
builder = builder.fast_search();
}
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.10.0-beta.1"
version = "0.11.0-beta.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -27,6 +27,7 @@ lazy_static.workspace = true
lance = { workspace = true }
lance-datafusion.workspace = true
lance-index = { workspace = true }
lance-table = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
lance-encoding = { workspace = true }
@@ -44,7 +45,8 @@ serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.3.2", optional = true }
@@ -64,10 +66,11 @@ aws-sdk-s3 = { version = "1.38.0" }
aws-sdk-kms = { version = "1.37" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
http-body = "1" # Matching reqwest
[features]
default = []
remote = ["dep:reqwest"]
remote = ["dep:reqwest", "dep:http"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
openai = ["dep:async-openai", "dep:reqwest"]

View File

@@ -142,6 +142,7 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
pub(crate) table_definition: Option<TableDefinition>,
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
pub(crate) data_storage_version: Option<LanceFileVersion>,
pub(crate) enable_v2_manifest_paths: Option<bool>,
}
// Builder methods that only apply when we have initial data
@@ -156,6 +157,7 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
table_definition: None,
embeddings: Vec::new(),
data_storage_version: None,
enable_v2_manifest_paths: None,
}
}
@@ -188,6 +190,7 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
write_options: self.write_options,
embeddings: self.embeddings,
data_storage_version: self.data_storage_version,
enable_v2_manifest_paths: self.enable_v2_manifest_paths,
};
Ok((data, builder))
}
@@ -222,6 +225,7 @@ impl CreateTableBuilder<false, NoData> {
write_options: WriteOptions::default(),
embeddings: Vec::new(),
data_storage_version: None,
enable_v2_manifest_paths: None,
}
}
@@ -284,9 +288,26 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
self
}
/// Set whether to use V2 manifest paths for the table. (default: false)
///
/// These paths provide more efficient opening of tables with many
/// versions on object stores.
///
/// <div class="warning">Turning this on will make the dataset unreadable
/// for older versions of LanceDB (prior to 0.10.0).</div>
///
/// To migrate an existing dataset, use
/// [`NativeTable::migrate_manifest_paths_v2`] instead.
///
/// This has no effect in LanceDB Cloud.
pub fn enable_v2_manifest_paths(mut self, use_v2_manifest_paths: bool) -> Self {
self.enable_v2_manifest_paths = Some(use_v2_manifest_paths);
self
}
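From Python, the bindings earlier in this diff thread the same flag through table creation; a hedged sketch (the keyword name is assumed from the binding parameter):

```python
# Assumes the async connection forwards `enable_v2_manifest_paths` as a
# keyword to create_table, mirroring the binding parameter added above.
import lancedb

async def create_with_v2_paths(uri, data):
    db = await lancedb.connect_async(uri)
    return await db.create_table(
        "my_table",
        data,
        enable_v2_manifest_paths=True,  # unreadable for LanceDB < 0.10.0
    )
```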
/// Set the data storage version.
///
/// The default is `LanceFileVersion::Legacy`.
/// The default is `LanceFileVersion::Stable`.
pub fn data_storage_version(mut self, data_storage_version: LanceFileVersion) -> Self {
self.data_storage_version = Some(data_storage_version);
self
@@ -294,13 +315,9 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
/// Set to true to use the v1 format for data files
///
/// This is currently defaulted to true and can be set to false to opt-in
/// to the new format. This should only be used for experimentation and
/// evaluation. The new format is still in beta and may change in ways that
/// are not backwards compatible.
///
/// Once the new format is stable, the default will change to `false` for
/// several releases and then eventually this option will be removed.
/// This is set to false by default to enable the stable format.
/// This should only be used for experimentation and
/// evaluation. This option may be removed in future releases.
#[deprecated(since = "0.9.0", note = "use data_storage_version instead")]
pub fn use_legacy_format(mut self, use_legacy_format: bool) -> Self {
self.data_storage_version = if use_legacy_format {
@@ -314,8 +331,8 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
#[derive(Clone, Debug)]
pub struct OpenTableBuilder {
parent: Arc<dyn ConnectionInternal>,
name: String,
pub(crate) parent: Arc<dyn ConnectionInternal>,
pub(crate) name: String,
index_cache_size: u32,
lance_read_params: Option<ReadParams>,
}
@@ -976,7 +993,10 @@ impl ConnectionInternal for Database {
if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
}
write_params.data_storage_version = options.data_storage_version;
write_params.enable_v2_manifest_paths =
options.enable_v2_manifest_paths.unwrap_or_default();
match NativeTable::create(
&table_uri,
@@ -1071,6 +1091,25 @@ impl ConnectionInternal for Database {
}
}
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
impl Connection {
pub fn new_with_handler<T>(
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
) -> Self
where
T: Into<reqwest::Body>,
{
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
internal,
uri: "db://test".to_string(),
}
}
}
}
#[cfg(test)]
mod tests {
use arrow_schema::{DataType, Field, Schema};
@@ -1184,9 +1223,9 @@ mod tests {
assert_eq!(tables, vec!["table1".to_owned()]);
}
fn make_data() -> impl RecordBatchReader + Send + 'static {
fn make_data() -> Box<dyn RecordBatchReader + Send + 'static> {
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
BatchGenerator::new().col(id).batches(10, 2000)
Box::new(BatchGenerator::new().col(id).batches(10, 2000))
}
#[tokio::test]
@@ -1197,6 +1236,7 @@ mod tests {
let tbl = db
.create_table("v1_test", make_data())
.data_storage_version(LanceFileVersion::Legacy)
.execute()
.await
.unwrap();

View File

@@ -145,7 +145,7 @@ impl SentenceTransformersEmbeddingsBuilder {
let device = self.device.unwrap_or(Device::Cpu);
let repo = if let Some(revision) = self.revision {
Repo::with_revision(model_id, RepoType::Model, revision.to_string())
Repo::with_revision(model_id, RepoType::Model, revision)
} else {
Repo::new(model_id, RepoType::Model)
};

View File

@@ -59,9 +59,11 @@ pub enum Index {
IvfPq(IvfPqIndexBuilder),
/// IVF-HNSW index with Product Quantization
/// It is a variant of the HNSW algorithm that uses product quantization to compress the vectors.
IvfHnswPq(IvfHnswPqIndexBuilder),
/// IVF-HNSW index with Scalar Quantization
/// It is a variant of the HNSW algorithm that uses scalar quantization to compress the vectors.
IvfHnswSq(IvfHnswSqIndexBuilder),
}
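The async Python tests earlier in this diff drive these variants through the new `HnswPq` and `HnswSq` configs; a short sketch based on them:

```python
# Mirrors test_create_hnswpq_index / test_create_hnswsq_index above.
from lancedb.index import HnswPq, HnswSq

async def build_hnsw_index(table, scalar_quantization=False):
    config = HnswSq(num_partitions=10) if scalar_quantization else HnswPq(num_partitions=10)
    await table.create_index("vector", config=config)
```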

View File

@@ -51,9 +51,25 @@ pub struct LabelListIndexBuilder {}
/// Builder for a full text search index
///
/// A full text search index is an index on a string column that allows for full text search
#[derive(Debug, Clone, Default)]
pub struct FtsIndexBuilder {}
#[derive(Debug, Clone)]
pub struct FtsIndexBuilder {
pub(crate) with_position: bool,
}
impl FtsIndexBuilder {}
impl Default for FtsIndexBuilder {
fn default() -> Self {
Self {
with_position: true,
}
}
}
impl FtsIndexBuilder {
/// Set the with_position flag
pub fn with_position(mut self, with_position: bool) -> Self {
self.with_position = with_position;
self
}
}
pub use lance_index::scalar::FullTextSearchQuery;
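As the Python tests earlier in this diff show, building the index without positions trades away phrase-query support; a minimal sketch, given an existing `table`:

```python
# Based on test_search_fts_phrase_query above: an index built without
# positional data cannot serve phrase queries.
import pytest

table.create_fts_index("text", use_tantivy=False, with_position=False)
with pytest.raises(Exception):
    table.search('"puppy runs"').limit(10).to_list()

# Rebuilding with positions (the default) restores phrase-query support.
table.create_fts_index("text", use_tantivy=False, replace=True)
phrase_results = table.search('"puppy runs"').limit(10).to_list()
```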

View File

@@ -214,6 +214,11 @@ pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
max(1, num_partitions)
}
pub(crate) fn suggested_num_partitions_for_hnsw(rows: usize, dim: u32) -> u32 {
let num_partitions = (((rows as u64) * (dim as u64)) / (256 * 5_000_000)) as u32;
max(1, num_partitions)
}
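A worked example of the heuristic above, which allots one partition per roughly 1.28 billion vector elements (256 * 5,000,000) with a floor of one:

```python
# suggested_num_partitions_for_hnsw(rows, dim) = max(1, rows * dim // (256 * 5_000_000))
rows, dim = 10_000_000, 768
num_partitions = max(1, (rows * dim) // (256 * 5_000_000))
assert num_partitions == 6  # 7.68e9 elements / 1.28e9 elements per partition
```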
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
if dim % 16 == 0 {
// Should be more aggressive than this default.

View File

@@ -228,6 +228,7 @@ pub use table::Table;
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum DistanceType {
/// Euclidean distance. This is a very common distance metric that
/// accounts for both magnitude and direction when determining the distance

View File

@@ -402,6 +402,9 @@ pub trait QueryBase {
///
/// By default, it is false.
fn fast_search(self) -> Self;
/// Return the `_rowid` meta column from the Table.
fn with_row_id(self) -> Self;
}
pub trait HasQuery {
@@ -438,6 +441,11 @@ impl<T: HasQuery> QueryBase for T {
self.mut_query().fast_search = true;
self
}
fn with_row_id(mut self) -> Self {
self.mut_query().with_row_id = true;
self
}
}
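The synchronous Python API exposes the same switch (see the reranker tests above); a short sketch, assuming `table` and `query_vector` from context:

```python
# Request the _rowid meta column alongside the search results.
results = (
    table.search(query_vector, vector_column_name="vector")
    .with_row_id(True)
    .limit(10)
    .to_arrow()
)
assert "_rowid" in results.column_names
```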
/// Options for controlling the execution of a query
@@ -548,6 +556,11 @@ pub struct Query {
///
/// By default, this is false.
pub(crate) fast_search: bool,
/// If set to true, the query will return the `_rowid` meta column.
///
/// By default, this is false.
pub(crate) with_row_id: bool,
}
impl Query {
@@ -560,6 +573,7 @@ impl Query {
full_text_search: None,
select: Select::All,
fast_search: false,
with_row_id: false,
}
}
@@ -1160,4 +1174,24 @@ mod tests {
.unwrap();
assert!(!plan.contains("Take"));
}
#[tokio::test]
async fn test_with_row_id() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let results = table
.vector_search(&[0.1, 0.2, 0.3, 0.4])
.unwrap()
.with_row_id()
.limit(10)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
for batch in results {
assert!(batch.column_by_name("_rowid").is_some());
}
}
}

View File

@@ -21,3 +21,5 @@ pub mod client;
pub mod db;
pub mod table;
pub mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use std::{future::Future, time::Duration};
use reqwest::{
header::{HeaderMap, HeaderValue},
@@ -21,13 +21,66 @@ use reqwest::{
use crate::error::{Error, Result};
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients
#[derive(Clone, Debug)]
pub struct RestfulLanceDbClient {
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
client: reqwest::Client,
host: String,
sender: S,
}
impl RestfulLanceDbClient {
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
fn send(&self, req: RequestBuilder) -> impl Future<Output = Result<Response>> + Send;
}
// Default implementation of HttpSend which sends the request normally with reqwest
#[derive(Clone, Debug)]
pub struct Sender;
impl HttpSend for Sender {
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
Ok(request.send().await?)
}
}
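The Python tests in this diff rely on the same injection idea, swapping in a mock client before issuing requests; a minimal sketch taken from those tests:

```python
from unittest.mock import MagicMock

import lancedb

conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = MagicMock()
conn._client.post.return_value = {"status": "ok"}
```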
impl RestfulLanceDbClient<Sender> {
pub fn try_new(
db_url: &str,
api_key: &str,
region: &str,
host_override: Option<String>,
) -> Result<Self> {
let parsed_url = url::Url::parse(db_url)?;
debug_assert_eq!(parsed_url.scheme(), "db");
if !parsed_url.has_host() {
return Err(Error::Http {
message: format!("Invalid database URL (missing host) '{}'", db_url),
});
}
let db_name = parsed_url.host_str().unwrap();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(Self::default_headers(
api_key,
region,
db_name,
host_override.is_some(),
)?)
.build()?;
let host = match host_override {
Some(host_override) => host_override,
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
};
Ok(Self {
client,
host,
sender: Sender,
})
}
}
impl<S: HttpSend> RestfulLanceDbClient<S> {
pub fn host(&self) -> &str {
&self.host
}
@@ -66,36 +119,6 @@ impl RestfulLanceDbClient {
Ok(headers)
}
pub fn try_new(
db_url: &str,
api_key: &str,
region: &str,
host_override: Option<String>,
) -> Result<Self> {
let parsed_url = url::Url::parse(db_url)?;
debug_assert_eq!(parsed_url.scheme(), "db");
if !parsed_url.has_host() {
return Err(Error::Http {
message: format!("Invalid database URL (missing host) '{}'", db_url),
});
}
let db_name = parsed_url.host_str().unwrap();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(Self::default_headers(
api_key,
region,
db_name,
host_override.is_some(),
)?)
.build()?;
let host = match host_override {
Some(host_override) => host_override,
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
};
Ok(Self { client, host })
}
pub fn get(&self, uri: &str) -> RequestBuilder {
let full_uri = format!("{}{}", self.host, uri);
self.client.get(full_uri)
@@ -106,6 +129,10 @@ impl RestfulLanceDbClient {
self.client.post(full_uri)
}
pub async fn send(&self, req: RequestBuilder) -> Result<Response> {
self.sender.send(req).await
}
async fn rsp_to_str(response: Response) -> String {
let status = response.status();
response.text().await.unwrap_or_else(|_| status.to_string())
@@ -126,3 +153,49 @@ impl RestfulLanceDbClient {
}
}
}
#[cfg(test)]
pub mod test_utils {
use std::sync::Arc;
use super::*;
#[derive(Clone)]
pub struct MockSender {
f: Arc<dyn Fn(reqwest::Request) -> reqwest::Response + Send + Sync + 'static>,
}
impl std::fmt::Debug for MockSender {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MockSender")
}
}
impl HttpSend for MockSender {
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
let request = request.build().unwrap();
let response = (self.f)(request);
Ok(response)
}
}
pub fn client_with_handler<T>(
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
) -> RestfulLanceDbClient<MockSender>
where
T: Into<reqwest::Body>,
{
let wrapper = move |req: reqwest::Request| {
let response = handler(req);
response.into()
};
RestfulLanceDbClient {
client: reqwest::Client::new(),
host: "http://localhost".to_string(),
sender: MockSender {
f: Arc::new(wrapper),
},
}
}
}

View File

@@ -16,6 +16,7 @@ use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use http::StatusCode;
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use tokio::task::spawn_blocking;
@@ -27,11 +28,10 @@ use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;
use super::client::RestfulLanceDbClient;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::batches_to_ipc_bytes;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
use super::ARROW_STREAM_CONTENT_TYPE;
#[derive(Deserialize)]
struct ListTablesResponse {
@@ -39,8 +39,8 @@ struct ListTablesResponse {
}
#[derive(Debug)]
pub struct RemoteDatabase {
client: RestfulLanceDbClient,
pub struct RemoteDatabase<S: HttpSend = Sender> {
client: RestfulLanceDbClient<S>,
}
impl RemoteDatabase {
@@ -55,14 +55,32 @@ impl RemoteDatabase {
}
}
impl std::fmt::Display for RemoteDatabase {
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
use crate::remote::client::test_utils::client_with_handler;
use crate::remote::client::test_utils::MockSender;
impl RemoteDatabase<MockSender> {
pub fn new_mock<F, T>(handler: F) -> Self
where
F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
T: Into<reqwest::Body>,
{
let client = client_with_handler(handler);
Self { client }
}
}
}
impl<S: HttpSend> std::fmt::Display for RemoteDatabase<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RemoteDatabase(host={})", self.client.host())
}
}
#[async_trait]
impl ConnectionInternal for RemoteDatabase {
impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
let mut req = self.client.get("/v1/table/");
if let Some(limit) = options.limit {
@@ -71,7 +89,7 @@ impl ConnectionInternal for RemoteDatabase {
if let Some(start_after) = options.start_after {
req = req.query(&[("page_token", start_after)]);
}
let rsp = req.send().await?;
let rsp = self.client.send(req).await?;
let rsp = self.client.check_response(rsp).await?;
Ok(rsp.json::<ListTablesResponse>().await?.tables)
}
@@ -88,15 +106,24 @@ impl ConnectionInternal for RemoteDatabase {
.await
.unwrap()?;
let rsp = self
let req = self
.client
.post(&format!("/v1/table/{}/create/", options.name))
.body(data_buffer)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
// This is currently expected by LanceDB Cloud but will be removed soon.
.header("x-request-id", "na")
.send()
.await?;
.header("x-request-id", "na");
let rsp = self.client.send(req).await?;
if rsp.status() == StatusCode::BAD_REQUEST {
let body = rsp.text().await?;
if body.contains("already exists") {
return Err(crate::Error::TableAlreadyExists { name: options.name });
} else {
return Err(crate::Error::InvalidInput { message: body });
}
}
self.client.check_response(rsp).await?;
Ok(Table::new(Arc::new(RemoteTable::new(
@@ -105,19 +132,206 @@ impl ConnectionInternal for RemoteDatabase {
))))
}
async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<Table> {
todo!()
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
// We describe the table to confirm it exists before moving on.
// TODO: a TTL cache of table existence
let req = self
.client
.get(&format!("/v1/table/{}/describe/", options.name));
let resp = self.client.send(req).await?;
if resp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: options.name });
}
self.client.check_response(resp).await?;
Ok(Table::new(Arc::new(RemoteTable::new(
self.client.clone(),
options.name,
))))
}
async fn drop_table(&self, _name: &str) -> Result<()> {
todo!()
async fn drop_table(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
let resp = self.client.send(req).await?;
self.client.check_response(resp).await?;
Ok(())
}
async fn drop_db(&self) -> Result<()> {
todo!()
Err(crate::Error::NotSupported {
message: "Dropping databases is not supported in the remote API".to_string(),
})
}
fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use crate::{remote::db::ARROW_STREAM_CONTENT_TYPE, Connection};
#[tokio::test]
async fn test_table_names() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/table/");
assert_eq!(request.url().query(), None);
http::Response::builder()
.status(200)
.body(r#"{"tables": ["table1", "table2"]}"#)
.unwrap()
});
let names = conn.table_names().execute().await.unwrap();
assert_eq!(names, vec!["table1", "table2"]);
}
#[tokio::test]
async fn test_table_names_pagination() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/table/");
assert!(request.url().query().unwrap().contains("limit=2"));
assert!(request.url().query().unwrap().contains("page_token=table2"));
http::Response::builder()
.status(200)
.body(r#"{"tables": ["table3", "table4"], "page_token": "token"}"#)
.unwrap()
});
let names = conn
.table_names()
.start_after("table2")
.limit(2)
.execute()
.await
.unwrap();
assert_eq!(names, vec!["table3", "table4"]);
}
#[tokio::test]
async fn test_open_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/table/table1/describe/");
assert_eq!(request.url().query(), None);
http::Response::builder()
.status(200)
.body(r#"{"table": "table1"}"#)
.unwrap()
});
let table = conn.open_table("table1").execute().await.unwrap();
assert_eq!(table.name(), "table1");
// Storage options should be ignored.
let table = conn
.open_table("table1")
.storage_option("key", "value")
.execute()
.await
.unwrap();
assert_eq!(table.name(), "table1");
}
#[tokio::test]
async fn test_open_table_not_found() {
let conn = Connection::new_with_handler(|_| {
http::Response::builder()
.status(404)
.body("table not found")
.unwrap()
});
let result = conn.open_table("table1").execute().await;
assert!(result.is_err());
assert!(matches!(result, Err(crate::Error::TableNotFound { .. })));
}
#[tokio::test]
async fn test_create_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/create/");
assert_eq!(
request
.headers()
.get(reqwest::header::CONTENT_TYPE)
.unwrap(),
ARROW_STREAM_CONTENT_TYPE.as_bytes()
);
http::Response::builder().status(200).body("").unwrap()
});
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema());
let table = conn.create_table("table1", reader).execute().await.unwrap();
assert_eq!(table.name(), "table1");
}
#[tokio::test]
async fn test_create_table_already_exists() {
let conn = Connection::new_with_handler(|_| {
http::Response::builder()
.status(400)
.body("table table1 already exists")
.unwrap()
});
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema());
let result = conn.create_table("table1", reader).execute().await;
assert!(result.is_err());
assert!(
matches!(result, Err(crate::Error::TableAlreadyExists { name }) if name == "table1")
);
}
#[tokio::test]
async fn test_create_table_empty() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/create/");
assert_eq!(
request
.headers()
.get(reqwest::header::CONTENT_TYPE)
.unwrap(),
ARROW_STREAM_CONTENT_TYPE.as_bytes()
);
http::Response::builder().status(200).body("").unwrap()
});
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
conn.create_empty_table("table1", schema)
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_drop_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/drop/");
assert_eq!(request.url().query(), None);
assert!(request.body().is_none());
http::Response::builder().status(200).body("").unwrap()
});
conn.drop_table("table1").await.unwrap();
// NOTE: the API will return 200 even if the table does not exist. So we shouldn't expect 404.
}
}
