Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: python-v0.… → python-v0.… (41 commits)
Commits:
515ab5f417, 8d0055fe6b, 5f9d8509b3, f3b6a1f55b, aff25e3bf9, 8509f73221, 607476788e,
4d458d5829, e61ba7f4e2, 408bc96a44, 6ceaf8b06e, e2ca8daee1, f305f34d9b, a416925ca1,
2c4b07eb17, 33b402c861, 7b2cdd2269, d6b5054778, f0e7f5f665, f958f4d2e8, c1d9d6f70b,
1778219ea9, ee6c18f207, e606a455df, 8f0eb34109, 2f2721e242, f00b21c98c, 962b3afd17,
b72ac073ab, 3152ccd13c, d5021356b4, e82f63b40a, f81ce68e41, f5c25b6fff, 86978e7588,
7c314d61cc, 7a8d2f37c4, 11072b9edc, 915d828cee, d9a72adc58, d6cf2dafc6
@@ -1,5 +1,5 @@
[tool.bumpversion]
-current_version = "0.10.0"
+current_version = "0.11.0-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -66,6 +66,32 @@ glob = "nodejs/npm/*/package.json"
replace = "\"version\": \"{new_version}\","
search = "\"version\": \"{current_version}\","
+
+# vectodb node binary packages
+[[tool.bumpversion.files]]
+glob = "node/package.json"
+replace = "\"@lancedb/vectordb-darwin-arm64\": \"{new_version}\""
+search = "\"@lancedb/vectordb-darwin-arm64\": \"{current_version}\""
+
+[[tool.bumpversion.files]]
+glob = "node/package.json"
+replace = "\"@lancedb/vectordb-darwin-x64\": \"{new_version}\""
+search = "\"@lancedb/vectordb-darwin-x64\": \"{current_version}\""
+
+[[tool.bumpversion.files]]
+glob = "node/package.json"
+replace = "\"@lancedb/vectordb-linux-arm64-gnu\": \"{new_version}\""
+search = "\"@lancedb/vectordb-linux-arm64-gnu\": \"{current_version}\""
+
+[[tool.bumpversion.files]]
+glob = "node/package.json"
+replace = "\"@lancedb/vectordb-linux-x64-gnu\": \"{new_version}\""
+search = "\"@lancedb/vectordb-linux-x64-gnu\": \"{current_version}\""
+
+[[tool.bumpversion.files]]
+glob = "node/package.json"
+replace = "\"@lancedb/vectordb-win32-x64-msvc\": \"{new_version}\""
+search = "\"@lancedb/vectordb-win32-x64-msvc\": \"{current_version}\""
+
# Cargo files
# ------------
[[tool.bumpversion.files]]
@@ -77,3 +103,8 @@ search = "\nversion = \"{current_version}\""
filename = "rust/lancedb/Cargo.toml"
replace = "\nversion = \"{new_version}\""
search = "\nversion = \"{current_version}\""
+
+[[tool.bumpversion.files]]
+filename = "nodejs/Cargo.toml"
+replace = "\nversion = \"{new_version}\""
+search = "\nversion = \"{current_version}\""
.github/workflows/docs_test.yml (vendored, 4 lines changed)
@@ -24,7 +24,7 @@ env:
jobs:
test-python:
name: Test doc python code
-runs-on: "warp-ubuntu-latest-x64-4x"
+runs-on: ubuntu-24.04
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -60,7 +60,7 @@ jobs:
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
test-node:
name: Test doc nodejs code
-runs-on: "warp-ubuntu-latest-x64-4x"
+runs-on: ubuntu-24.04
timeout-minutes: 60
strategy:
fail-fast: false
.github/workflows/rust.yml (vendored, 23 lines changed)
@@ -26,15 +26,14 @@ env:
jobs:
lint:
timeout-minutes: 30
-runs-on: ubuntu-22.04
+runs-on: ubuntu-24.04
defaults:
run:
shell: bash
working-directory: rust
env:
# Need up-to-date compilers for kernels
-CC: gcc-12
-CXX: g++-12
+CC: clang-18
+CXX: clang++-18
steps:
- uses: actions/checkout@v4
with:
@@ -50,21 +49,21 @@ jobs:
- name: Run format
run: cargo fmt --all -- --check
- name: Run clippy
-run: cargo clippy --all --all-features -- -D warnings
+run: cargo clippy --workspace --tests --all-features -- -D warnings
linux:
timeout-minutes: 30
# To build all features, we need more disk space than is available
-# on the GitHub-provided runner. This is mostly due to the the
+# on the free OSS github runner. This is mostly due to the the
# sentence-transformers feature.
-runs-on: warp-ubuntu-latest-x64-4x
+runs-on: ubuntu-2404-4x-x64
defaults:
run:
shell: bash
working-directory: rust
env:
# Need up-to-date compilers for kernels
-CC: gcc-12
-CXX: g++-12
+CC: clang-18
+CXX: clang++-18
steps:
- uses: actions/checkout@v4
with:
@@ -77,6 +76,12 @@ jobs:
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
+- name: Make Swap
+run: |
+sudo fallocate -l 16G /swapfile
+sudo chmod 600 /swapfile
+sudo mkswap /swapfile
+sudo swapon /swapfile
- name: Start S3 integration test environment
working-directory: .
run: docker compose up --detach --wait
Cargo.toml (19 lines changed)
@@ -20,13 +20,13 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]

[workspace.dependencies]
-lance = { "version" = "=0.18.0", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.18.0" }
-lance-linalg = { "version" = "=0.18.0" }
-lance-table = { "version" = "=0.18.0" }
-lance-testing = { "version" = "=0.18.0" }
-lance-datafusion = { "version" = "=0.18.0" }
-lance-encoding = { "version" = "=0.18.0" }
+lance = { "version" = "=0.18.2", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.18.2" }
+lance-linalg = { "version" = "=0.18.2" }
+lance-table = { "version" = "=0.18.2" }
+lance-testing = { "version" = "=0.18.2" }
+lance-datafusion = { "version" = "=0.18.2" }
+lance-encoding = { "version" = "=0.18.2" }
# Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false }
arrow-array = "52.2"
@@ -38,16 +38,19 @@ arrow-arith = "52.2"
arrow-cast = "52.2"
async-trait = "0"
chrono = "0.4.35"
-datafusion-physical-plan = "40.0"
+datafusion-common = "41.0"
+datafusion-physical-plan = "41.0"
half = { "version" = "=2.4.1", default-features = false, features = [
"num-traits",
] }
futures = "0"
log = "0.4"
moka = { version = "0.11", features = ["future"] }
object_store = "0.10.2"
pin-project = "1.0.7"
snafu = "0.7.4"
url = "2"
num-traits = "0.2"
rand = "0.8"
regex = "1.10"
lazy_static = "1"
@@ -82,4 +82,4 @@ result = table.search([100, 100]).limit(2).to_pandas()

## Blogs, Tutorials & Videos
* 📈 <a href="https://blog.lancedb.com/benchmarking-random-access-in-lance/">2000x better performance with Lance over Parquet</a>
-* 🤖 <a href="https://github.com/lancedb/lancedb/blob/main/docs/src/notebooks/youtube_transcript_search.ipynb">Build a question and answer bot with LanceDB</a>
+* 🤖 <a href="https://github.com/lancedb/vectordb-recipes/tree/main/examples/Youtube-Search-QA-Bot">Build a question and answer bot with LanceDB</a>
@@ -34,6 +34,7 @@ theme:
- navigation.footer
- navigation.tracking
- navigation.instant
+- content.footnote.tooltips
icon:
repo: fontawesome/brands/github
annotation: material/arrow-right-circle
@@ -65,6 +66,11 @@ plugins:
markdown_extensions:
- admonition
- footnotes
+- pymdownx.critic
+- pymdownx.caret
+- pymdownx.keys
+- pymdownx.mark
+- pymdownx.tilde
- pymdownx.details
- pymdownx.highlight:
anchor_linenums: true
@@ -114,6 +120,7 @@ nav:
- Graph RAG: rag/graph_rag.md
- Self RAG: rag/self_rag.md
- Adaptive RAG: rag/adaptive_rag.md
+- SFR RAG: rag/sfr_rag.md
- Advanced Techniques:
- HyDE: rag/advanced_techniques/hyde.md
- FLARE: rag/advanced_techniques/flare.md
@@ -177,6 +184,7 @@ nav:
- Voxel51: integrations/voxel51.md
- PromptTools: integrations/prompttools.md
- dlt: integrations/dlt.md
+- phidata: integrations/phidata.md
- 🎯 Examples:
- Overview: examples/index.md
- 🐍 Python:
@@ -240,6 +248,7 @@ nav:
- Graph RAG: rag/graph_rag.md
- Self RAG: rag/self_rag.md
- Adaptive RAG: rag/adaptive_rag.md
+- SFR RAG: rag/sfr_rag.md
- Advanced Techniques:
- HyDE: rag/advanced_techniques/hyde.md
- FLARE: rag/advanced_techniques/flare.md
@@ -299,6 +308,7 @@ nav:
- Voxel51: integrations/voxel51.md
- PromptTools: integrations/prompttools.md
- dlt: integrations/dlt.md
+- phidata: integrations/phidata.md
- Examples:
- examples/index.md
- 🐍 Python:
@@ -354,4 +364,5 @@ extra:
- icon: fontawesome/brands/x-twitter
link: https://twitter.com/lancedb
- icon: fontawesome/brands/linkedin
link: https://www.linkedin.com/company/lancedb
@@ -1,5 +1,5 @@
# Huggingface embedding models
-We offer support for all huggingface models (which can be loaded via [transformers](https://huggingface.co/docs/transformers/en/index) library). The default model is `colbert-ir/colbertv2.0` which also has its own special callout - `registry.get("colbert")`
+We offer support for all Hugging Face models (which can be loaded via [transformers](https://huggingface.co/docs/transformers/en/index) library). The default model is `colbert-ir/colbertv2.0` which also has its own special callout - `registry.get("colbert")`. Some Hugging Face models might require custom models defined on the HuggingFace Hub in their own modeling files. You may enable this by setting `trust_remote_code=True`. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
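As a hedged illustration of the new `trust_remote_code` note (this sketch is not part of the diff; the `"huggingface"` registry key and the exact `create()` keyword names are assumptions):

```python
# Minimal sketch, assuming LanceDB's embedding registry exposes a
# "huggingface" entry and that create() forwards trust_remote_code
# to the underlying transformers loader (assumption).
from lancedb.embeddings import get_registry

registry = get_registry()
embed_fn = registry.get("huggingface").create(
    name="colbert-ir/colbertv2.0",
    trust_remote_code=True,  # only for model repos you trust and have read
)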

Example usage -
@@ -8,9 +8,15 @@ LanceDB provides language APIs, allowing you to embed a database in your languag
* 👾 [JavaScript](examples_js.md) examples
* 🦀 Rust examples (coming soon)

-## Applications powered by LanceDB
+## Python Applications powered by LanceDB

| Project Name | Description |
| --- | --- |
| **Ultralytics Explorer 🚀**<br>[](https://docs.ultralytics.com/datasets/explorer/)<br>[](https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/docs/en/datasets/explorer/explorer.ipynb) | - 🔍 **Explore CV Datasets**: Semantic search, SQL queries, vector similarity, natural language.<br>- 🖥️ **GUI & Python API**: Seamless dataset interaction.<br>- ⚡ **Efficient & Scalable**: Leverages LanceDB for large datasets.<br>- 📊 **Detailed Analysis**: Easily analyze data patterns.<br>- 🌐 **Browser GUI Demo**: Create embeddings, search images, run queries. |
| **Website Chatbot🤖**<br>[](https://github.com/lancedb/lancedb-vercel-chatbot)<br>[](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Flancedb%2Flancedb-vercel-chatbot&env=OPENAI_API_KEY&envDescription=OpenAI%20API%20Key%20for%20chat%20completion.&project-name=lancedb-vercel-chatbot&repository-name=lancedb-vercel-chatbot&demo-title=LanceDB%20Chatbot%20Demo&demo-description=Demo%20website%20chatbot%20with%20LanceDB.&demo-url=https%3A%2F%2Flancedb.vercel.app&demo-image=https%3A%2F%2Fi.imgur.com%2FazVJtvr.png) | - 🌐 **Chatbot from Sitemap/Docs**: Create a chatbot using site or document context.<br>- 🚀 **Embed LanceDB in Next.js**: Lightweight, on-prem storage.<br>- 🧠 **AI-Powered Context Retrieval**: Efficiently access relevant data.<br>- 🔧 **Serverless & Native JS**: Seamless integration with Next.js.<br>- ⚡ **One-Click Deploy on Vercel**: Quick and easy setup.. |
+
+## Nodejs Applications powered by LanceDB
+
+| Project Name | Description |
+| --- | --- |
+| **Langchain Writing Assistant✍️ **<br>[](https://github.com/lancedb/vectordb-recipes/tree/main/applications/node/lanchain_writing_assistant) | - **📂 Data Source Integration**: Use your own data by specifying data source file, and the app instantly processes it to provide insights. <br>- **🧠 Intelligent Suggestions**: Powered by LangChain.js and LanceDB, it improves writing productivity and accuracy. <br>- **💡 Enhanced Writing Experience**: It delivers real-time contextual insights and factual suggestions while the user writes. |
docs/src/integrations/phidata.md (new file, 383 lines)
@@ -0,0 +1,383 @@
**phidata** is a framework for building **AI Assistants** with long-term memory, contextual knowledge, and the ability to take actions using function calling. It helps turn general-purpose LLMs into specialized assistants tailored to your use case by extending its capabilities using **memory**, **knowledge**, and **tools**.
|
||||
|
||||
- **Memory**: Stores chat history in a **database** and enables LLMs to have long-term conversations.
|
||||
- **Knowledge**: Stores information in a **vector database** and provides LLMs with business context. (Here we will use LanceDB)
|
||||
- **Tools**: Enable LLMs to take actions like pulling data from an **API**, **sending emails** or **querying a database**, etc.
|
||||
|
||||

|
||||
|
||||
Memory & knowledge make LLMs smarter while tools make them autonomous.
|
||||
|
||||
LanceDB is a vector database and its integration into phidata makes it easy for us to provide a **knowledge base** to LLMs. It enables us to store information as [embeddings](../embeddings/understanding_embeddings.md) and search for the **results** similar to ours using **query**.
|
||||
|
||||
??? Question "What is Knowledge Base?"
|
||||
Knowledge Base is a database of information that the Assistant can search to improve its responses. This information is stored in a vector database and provides LLMs with business context, which makes them respond in a context-aware manner.
|
||||
|
||||
While any type of storage can act as a knowledge base, vector databases offer the best solution for retrieving relevant results from dense information quickly.
|
||||
|
||||
Let's see how using LanceDB inside phidata helps in making LLM more useful:
|
||||
|
||||
## Prerequisites: install and import necessary dependencies
|
||||
|
||||
**Create a virtual environment**
|
||||
|
||||
1. install virtualenv package
|
||||
```python
|
||||
pip install virtualenv
|
||||
```
|
||||
2. Create a directory for your project and go to the directory and create a virtual environment inside it.
|
||||
```python
|
||||
mkdir phi
|
||||
```
|
||||
```python
|
||||
cd phi
|
||||
```
|
||||
```python
|
||||
python -m venv phidata_
|
||||
```
|
||||
|
||||
**Activating virtual environment**
|
||||
|
||||
1. from inside the project directory, run the following command to activate the virtual environment.
|
||||
```python
|
||||
phidata_/Scripts/activate
|
||||
```
|
||||
|
||||
**Install the following packages in the virtual environment**
|
||||
```python
|
||||
pip install lancedb phidata youtube_transcript_api openai ollama pandas numpy
|
||||
```
|
||||
|
||||
**Create python files and import necessary libraries**
|
||||
|
||||
You need to create two files - `transcript.py` and `ollama_assistant.py` or `openai_assistant.py`
|
||||
|
||||
=== "openai_assistant.py"
|
||||
|
||||
```python
|
||||
import os, openai
|
||||
from rich.prompt import Prompt
|
||||
from phi.assistant import Assistant
|
||||
from phi.knowledge.text import TextKnowledgeBase
|
||||
from phi.vectordb.lancedb import LanceDb
|
||||
from phi.llm.openai import OpenAIChat
|
||||
from phi.embedder.openai import OpenAIEmbedder
|
||||
from transcript import extract_transcript
|
||||
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
|
||||
# The code below creates a file "transcript.txt" in the directory, the txt file will be used below
|
||||
youtube_url = "https://www.youtube.com/watch?v=Xs33-Gzl8Mo"
|
||||
segment_duration = 20
|
||||
transcript_text,dict_transcript = extract_transcript(youtube_url,segment_duration)
|
||||
```
|
||||
|
||||
=== "ollama_assistant.py"
|
||||
|
||||
```python
|
||||
from rich.prompt import Prompt
|
||||
from phi.assistant import Assistant
|
||||
from phi.knowledge.text import TextKnowledgeBase
|
||||
from phi.vectordb.lancedb import LanceDb
|
||||
from phi.llm.ollama import Ollama
|
||||
from phi.embedder.ollama import OllamaEmbedder
|
||||
from transcript import extract_transcript
|
||||
|
||||
# The code below creates a file "transcript.txt" in the directory, the txt file will be used below
|
||||
youtube_url = "https://www.youtube.com/watch?v=Xs33-Gzl8Mo"
|
||||
segment_duration = 20
|
||||
transcript_text,dict_transcript = extract_transcript(youtube_url,segment_duration)
|
||||
```
|
||||
|
||||
=== "transcript.py"
|
||||
|
||||
``` python
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
import re
|
||||
|
||||
def smodify(seconds):
|
||||
hours, remainder = divmod(seconds, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"
|
||||
|
||||
def extract_transcript(youtube_url,segment_duration):
|
||||
# Extract video ID from the URL
|
||||
video_id = re.search(r'(?<=v=)[\w-]+', youtube_url)
|
||||
if not video_id:
|
||||
video_id = re.search(r'(?<=be/)[\w-]+', youtube_url)
|
||||
if not video_id:
|
||||
return None
|
||||
|
||||
video_id = video_id.group(0)
|
||||
|
||||
# Attempt to fetch the transcript
|
||||
try:
|
||||
# Try to get the official transcript
|
||||
transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
|
||||
except Exception:
|
||||
# If no official transcript is found, try to get auto-generated transcript
|
||||
try:
|
||||
transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
|
||||
for transcript in transcript_list:
|
||||
transcript = transcript.translate('en').fetch()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Format the transcript into 120s chunks
|
||||
transcript_text,dict_transcript = format_transcript(transcript,segment_duration)
|
||||
# Open the file in write mode, which creates it if it doesn't exist
|
||||
with open("transcript.txt", "w",encoding="utf-8") as file:
|
||||
file.write(transcript_text)
|
||||
return transcript_text,dict_transcript
|
||||
|
||||
def format_transcript(transcript,segment_duration):
|
||||
chunked_transcript = []
|
||||
chunk_dict = []
|
||||
current_chunk = []
|
||||
current_time = 0
|
||||
# 2 minutes in seconds
|
||||
start_time_chunk = 0 # To track the start time of the current chunk
|
||||
|
||||
for segment in transcript:
|
||||
start_time = segment['start']
|
||||
end_time_x = start_time + segment['duration']
|
||||
text = segment['text']
|
||||
|
||||
# Add text to the current chunk
|
||||
current_chunk.append(text)
|
||||
|
||||
# Update the current time with the duration of the current segment
|
||||
# The duration of the current segment is given by segment['start'] - start_time_chunk
|
||||
if current_chunk:
|
||||
current_time = start_time - start_time_chunk
|
||||
|
||||
# If current chunk duration reaches or exceeds 2 minutes, save the chunk
|
||||
if current_time >= segment_duration:
|
||||
# Use the start time of the first segment in the current chunk as the timestamp
|
||||
chunked_transcript.append(f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}] " + " ".join(current_chunk))
|
||||
current_chunk = re.sub(r'[\xa0\n]', lambda x: '' if x.group() == '\xa0' else ' ', "\n".join(current_chunk))
|
||||
chunk_dict.append({"timestamp":f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}]", "text": "".join(current_chunk)})
|
||||
current_chunk = [] # Reset the chunk
|
||||
start_time_chunk = start_time + segment['duration'] # Update the start time for the next chunk
|
||||
current_time = 0 # Reset current time
|
||||
|
||||
# Add any remaining text in the last chunk
|
||||
if current_chunk:
|
||||
chunked_transcript.append(f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}] " + " ".join(current_chunk))
|
||||
current_chunk = re.sub(r'[\xa0\n]', lambda x: '' if x.group() == '\xa0' else ' ', "\n".join(current_chunk))
|
||||
chunk_dict.append({"timestamp":f"[{smodify(start_time_chunk)} to {smodify(end_time_x)}]", "text": "".join(current_chunk)})
|
||||
|
||||
return "\n\n".join(chunked_transcript), chunk_dict
|
||||
```
|
||||
|
||||
!!! warning
|
||||
If creating Ollama assistant, download and install Ollama [from here](https://ollama.com/) and then run the Ollama instance in the background. Also, download the required models using `ollama pull <model-name>`. Check out the models [here](https://ollama.com/library)
|
||||
|
||||
|
||||
**Run the following command to deactivate the virtual environment if needed**
|
||||
```python
|
||||
deactivate
|
||||
```
|
||||
|
||||
## **Step 1** - Create a Knowledge Base for AI Assistant using LanceDB
|
||||
|
||||
=== "openai_assistant.py"
|
||||
|
||||
```python
|
||||
# Create knowledge Base with OpenAIEmbedder in LanceDB
|
||||
knowledge_base = TextKnowledgeBase(
|
||||
path="transcript.txt",
|
||||
vector_db=LanceDb(
|
||||
embedder=OpenAIEmbedder(api_key = openai.api_key),
|
||||
table_name="transcript_documents",
|
||||
uri="./t3mp/.lancedb",
|
||||
),
|
||||
num_documents = 10
|
||||
)
|
||||
```
|
||||
|
||||
=== "ollama_assistant.py"
|
||||
|
||||
```python
|
||||
# Create knowledge Base with OllamaEmbedder in LanceDB
|
||||
knowledge_base = TextKnowledgeBase(
|
||||
path="transcript.txt",
|
||||
vector_db=LanceDb(
|
||||
embedder=OllamaEmbedder(model="nomic-embed-text",dimensions=768),
|
||||
table_name="transcript_documents",
|
||||
uri="./t2mp/.lancedb",
|
||||
),
|
||||
num_documents = 10
|
||||
)
|
||||
```
|
||||
Check out the list of **embedders** supported by **phidata** and their usage [here](https://docs.phidata.com/embedder/introduction).
|
||||
|
||||
Here we have used `TextKnowledgeBase`, which loads text/docx files to the knowledge base.
|
||||
|
||||
Let's see all the parameters that `TextKnowledgeBase` takes -
|
||||
|
||||
| Name| Type | Purpose | Default |
|
||||
|:----|:-----|:--------|:--------|
|
||||
|`path`|`Union[str, Path]`| Path to text file(s). It can point to a single text file or a directory of text files.| provided by user |
|
||||
|`formats`|`List[str]`| File formats accepted by this knowledge base. |`[".txt"]`|
|
||||
|`vector_db`|`VectorDb`| Vector Database for the Knowledge Base. phidata provides a wrapper around many vector DBs, you can import it like this - `from phi.vectordb.lancedb import LanceDb` | provided by user |
|
||||
|`num_documents`|`int`| Number of results (documents/vectors) that vector search should return. |`5`|
|
||||
|`reader`|`TextReader`| phidata provides many types of reader objects which read data, clean it and create chunks of data, encapsulate each chunk inside an object of the `Document` class, and return **`List[Document]`**. | `TextReader()` |
|
||||
|`optimize_on`|`int`| It is used to specify the number of documents on which to optimize the vector database. Supposed to create an index. |`1000`|
|
||||
|
||||
??? Tip "Wonder! What is `Document` class?"
|
||||
We know that, before storing the data in vectorDB, we need to split the data into smaller chunks upon which embeddings will be created and these embeddings along with the chunks will be stored in vectorDB. When the user queries over the vectorDB, some of these embeddings will be returned as the result based on the semantic similarity with the query.
|
||||
|
||||
When the user queries over vectorDB, the queries are converted into embeddings, and a nearest neighbor search is performed over these query embeddings which returns the embeddings that correspond to most semantically similar chunks(parts of our data) present in vectorDB.
|
||||
|
||||
Here, a “Document” is a class in phidata. Since there is an option to let phidata create and manage embeddings, it splits our data into smaller chunks(as expected). It does not directly create embeddings on it. Instead, it takes each chunk and encapsulates it inside the object of the `Document` class along with various other metadata related to the chunk. Then embeddings are created on these `Document` objects and stored in vectorDB.
|
||||
|
||||
```python
|
||||
class Document(BaseModel):
|
||||
"""Model for managing a document"""
|
||||
|
||||
content: str # <--- here data of chunk is stored
|
||||
id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
meta_data: Dict[str, Any] = {}
|
||||
embedder: Optional[Embedder] = None
|
||||
embedding: Optional[List[float]] = None
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
```
|
||||
|
||||
However, using phidata you can load many other types of data in the knowledge base(other than text). Check out [phidata Knowledge Base](https://docs.phidata.com/knowledge/introduction) for more information.
|
||||
|
||||
Let's dig deeper into the `vector_db` parameter and see what parameters `LanceDb` takes -
|
||||
|
||||
| Name| Type | Purpose | Default |
|
||||
|:----|:-----|:--------|:--------|
|
||||
|`embedder`|`Embedder`| phidata provides many Embedders that abstract the interaction with embedding APIs and utilize it to generate embeddings. Check out other embedders [here](https://docs.phidata.com/embedder/introduction) | `OpenAIEmbedder` |
|
||||
|`distance`|`List[str]`| The choice of distance metric used to calculate the similarity between vectors, which directly impacts search results and performance in vector databases. |`Distance.cosine`|
|
||||
|`connection`|`lancedb.db.LanceTable`| LanceTable can be accessed through `.connection`. You can connect to an existing table of LanceDB, created outside of phidata, and utilize it. If not provided, it creates a new table using `table_name` parameter and adds it to `connection`. |`None`|
|
||||
|`uri`|`str`| It specifies the directory location of **LanceDB database** and establishes a connection that can be used to interact with the database. | `"/tmp/lancedb"` |
|
||||
|`table_name`|`str`| If `connection` is not provided, it initializes and connects to a new **LanceDB table** with a specified(or default) name in the database present at `uri`. |`"phi"`|
|
||||
|`nprobes`|`int`| It refers to the number of partitions that the search algorithm examines to find the nearest neighbors of a given query vector. Higher values will yield better recall (more likely to find vectors if they exist) at the expense of latency. |`20`|
|
||||
|
||||
|
||||
!!! note
|
||||
Since we have just initialized the Knowledge Base, the vector DB table that corresponds to it is not yet populated with our data. It will be populated in **Step 3**, once we perform the `load` operation.
|
||||
|
||||
You can check the state of the LanceDB table using - `knowledge_base.vector_db.connection.to_pandas()`
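
For readers who want to inspect that table with plain LanceDB, a minimal sketch (an editor addition, not part of the phidata file; it reuses the `uri` and `table_name` from the `openai_assistant.py` example above):

```python
# Minimal sketch: open the table phidata created and peek at its contents.
import lancedb

db = lancedb.connect("./t3mp/.lancedb")
table = db.open_table("transcript_documents")
print(table.to_pandas().head())  # columns include "id", "vector", "payload"
```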
|
||||
|
||||
Now that the Knowledge Base is initialized, we can go to **step 2**.
|
||||
|
||||
## **Step 2** - Create an assistant with our choice of LLM and reference to the knowledge base.
|
||||
|
||||
|
||||
=== "openai_assistant.py"
|
||||
|
||||
```python
|
||||
# define an assistant with gpt-4o-mini llm and reference to the knowledge base created above
|
||||
assistant = Assistant(
|
||||
llm=OpenAIChat(model="gpt-4o-mini", max_tokens=1000, temperature=0.3,api_key = openai.api_key),
|
||||
description="""You are an Expert in explaining youtube video transcripts. You are a bot that takes transcript of a video and answer the question based on it.
|
||||
|
||||
This is transcript for the above timestamp: {relevant_document}
|
||||
The user input is: {user_input}
|
||||
generate highlights only when asked.
|
||||
When asked to generate highlights from the video, understand the context for each timestamp and create key highlight points, answer in following way -
|
||||
[timestamp] - highlight 1
|
||||
[timestamp] - highlight 2
|
||||
... so on
|
||||
|
||||
Your task is to understand the user question, and provide an answer using the provided contexts. Your answers are correct, high-quality, and written by an domain expert. If the provided context does not contain the answer, simply state,'The provided context does not have the answer.'""",
|
||||
knowledge_base=knowledge_base,
|
||||
add_references_to_prompt=True,
|
||||
)
|
||||
```
|
||||
|
||||
=== "ollama_assistant.py"
|
||||
|
||||
```python
|
||||
# define an assistant with llama3.1 llm and reference to the knowledge base created above
|
||||
assistant = Assistant(
|
||||
llm=Ollama(model="llama3.1"),
|
||||
description="""You are an Expert in explaining youtube video transcripts. You are a bot that takes transcript of a video and answer the question based on it.
|
||||
|
||||
This is transcript for the above timestamp: {relevant_document}
|
||||
The user input is: {user_input}
|
||||
generate highlights only when asked.
|
||||
When asked to generate highlights from the video, understand the context for each timestamp and create key highlight points, answer in following way -
|
||||
[timestamp] - highlight 1
|
||||
[timestamp] - highlight 2
|
||||
... so on
|
||||
|
||||
Your task is to understand the user question, and provide an answer using the provided contexts. Your answers are correct, high-quality, and written by an domain expert. If the provided context does not contain the answer, simply state,'The provided context does not have the answer.'""",
|
||||
knowledge_base=knowledge_base,
|
||||
add_references_to_prompt=True,
|
||||
)
|
||||
```
|
||||
|
||||
Assistants add **memory**, **knowledge**, and **tools** to LLMs. Here we will add only **knowledge** in this example.
|
||||
|
||||
Whenever we give a query to the LLM, the assistant retrieves relevant information from our **Knowledge Base** (a table in LanceDB) and passes it to the LLM along with the user query in a structured way.
|
||||
|
||||
- The `add_references_to_prompt=True` always adds information from the knowledge base to the prompt, regardless of whether it is relevant to the question.
|
||||
|
||||
To learn more about creating an assistant in phidata, check out the [phidata docs](https://docs.phidata.com/assistants/introduction).
|
||||
|
||||
## **Step 3** - Load data to Knowledge Base.
|
||||
|
||||
```python
|
||||
# load out data into the knowledge_base (populating the LanceTable)
|
||||
assistant.knowledge_base.load(recreate=False)
|
||||
```
|
||||
The above code loads the data to the Knowledge Base(LanceDB Table) and now it is ready to be used by the assistant.
|
||||
|
||||
| Name| Type | Purpose | Default |
|
||||
|:----|:-----|:--------|:--------|
|
||||
|`recreate`|`bool`| If True, it drops the existing table and recreates the table in the vectorDB. |`False`|
|
||||
|`upsert`|`bool`| If True and the vectorDB supports upsert, it will upsert documents to the vector db. | `False` |
|
||||
|`skip_existing`|`bool`| If True, skips documents that already exist in the vectorDB when inserting. |`True`|
|
||||
|
||||
??? tip "What is upsert?"
|
||||
Upsert is a database operation that combines "update" and "insert". It updates existing records if a document with the same identifier does exist, or inserts new records if no matching record exists. This is useful for maintaining the most current information without manually checking for existence.
|
||||
|
||||
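As a hedged illustration of those flags (an editor addition, not part of the file), a sketch of a reload that upserts instead of recreating; the parameter names come from the table above:

```python
# Minimal sketch: re-run the load without dropping the table, upserting
# documents when the vector DB supports it, and skipping existing ones.
assistant.knowledge_base.load(recreate=False, upsert=True, skip_existing=True)
```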
During the Load operation, phidata directly interacts with the LanceDB library and performs the loading of the table with our data in the following steps -
|
||||
|
||||
1. **Creates** and **initializes** the table if it does not exist.
|
||||
|
||||
2. Then it **splits** our data into smaller **chunks**.
|
||||
|
||||
??? question "How do they create chunks?"
|
||||
**phidata** provides many types of **Knowledge Bases** based on the type of data. Most of them :material-information-outline:{ title="except LlamaIndexKnowledgeBase and LangChainKnowledgeBase"} has a property method called `document_lists` of type `Iterator[List[Document]]`. During the load operation, this property method is invoked. It traverses on the data provided by us (in this case, a text file(s)) using `reader`. Then it **reads**, **creates chunks**, and **encapsulates** each chunk inside a `Document` object and yields **lists of `Document` objects** that contain our data.
|
||||
|
||||
3. Then **embeddings** are created on these chunks and **inserted** into the LanceDB Table
|
||||
|
||||
??? question "How do they insert your data as different rows in LanceDB Table?"
|
||||
The chunks of your data are in the form - **lists of `Document` objects**. It was yielded in the step above.
|
||||
|
||||
for each `Document` in `List[Document]`, it does the following operations:
|
||||
|
||||
- Creates embedding on `Document`.
|
||||
- Cleans the **content attribute**(chunks of our data is here) of `Document`.
|
||||
- Prepares data by creating `id` and loading `payload` with the metadata related to this chunk. (1)
|
||||
{ .annotate }
|
||||
|
||||
1. Three columns will be added to the table - `"id"`, `"vector"`, and `"payload"` (payload contains various metadata including **`content`**)
|
||||
|
||||
- Then add this data to LanceTable.
|
||||
|
||||
4. Now the internal state of `knowledge_base` is changed (embeddings are created and loaded in the table) and it is **ready to be used by the assistant**.
|
||||
|
||||
## **Step 4** - Start a cli chatbot with access to the Knowledge base
|
||||
|
||||
```python
|
||||
# start cli chatbot with knowledge base
|
||||
assistant.print_response("Ask me about something from the knowledge base")
|
||||
while True:
|
||||
message = Prompt.ask(f"[bold] :sunglasses: User [/bold]")
|
||||
if message in ("exit", "bye"):
|
||||
break
|
||||
assistant.print_response(message, markdown=True)
|
||||
```
|
||||
|
||||
|
||||
For more information and amazing cookbooks of phidata, read the [phidata documentation](https://docs.phidata.com/introduction) and also visit the [LanceDB x phidata documentation](https://docs.phidata.com/vectordb/lancedb).
|
||||
@@ -1,13 +1,73 @@
|
||||
# FiftyOne
|
||||
|
||||
FiftyOne is an open source toolkit for building high-quality datasets and computer vision models. It provides an API to create LanceDB tables and run similarity queries, both programmatically in Python and via point-and-click in the App.
|
||||
FiftyOne is an open source toolkit that enables users to curate better data and build better models. It includes tools for data exploration, visualization, and management, as well as features for collaboration and sharing.
|
||||
|
||||
Any developers, data scientists, and researchers who work with computer vision and machine learning can use FiftyOne to improve the quality of their datasets and deliver insights about their models.
|
||||
|
||||
|
||||

|
||||
|
||||
## Basic recipe
|
||||
**FiftyOne** provides an API to create LanceDB tables and run similarity queries, both **programmatically in Python** and via **point-and-click in the App**.
|
||||
|
||||
-The basic workflow shown below uses LanceDB to create a similarity index on your FiftyOne datasets:
+Let's get started and see how to use **LanceDB** to create a **similarity index** on your FiftyOne datasets.
|
||||
|
||||
## Overview
|
||||
|
||||
**[Embeddings](../embeddings/understanding_embeddings.md)** are foundational to all of the **vector search** features. In FiftyOne, embeddings are managed by the [**FiftyOne Brain**](https://docs.voxel51.com/user_guide/brain.html) that provides powerful machine learning techniques designed to transform how you curate your data from an art into a measurable science.
|
||||
|
||||
!!!question "Have you ever wanted to find the images most similar to an image in your dataset?"
|
||||
The **FiftyOne Brain** makes computing **visual similarity** really easy. You can compute the similarity of samples in your dataset using an embedding model and store the results in the **brain key**.
|
||||
|
||||
You can then sort your samples by similarity or use this information to find potential duplicate images.
|
||||
|
||||
Here we will be doing the following :
|
||||
|
||||
1. **Create Index** - In order to run similarity queries against our media, we need to **index** the data. We can do this via the `compute_similarity()` function.
|
||||
|
||||
- In the function, specify the **model** you want to use to generate the embedding vectors, and what **vector search engine** you want to use on the **backend** (here LanceDB).
|
||||
|
||||
!!!tip
|
||||
You can also give the similarity index a name(`brain_key`), which is useful if you want to run vector searches against multiple indexes.
|
||||
|
||||
2. **Query** - Once you have generated your similarity index, you can query your dataset with `sort_by_similarity()`. The query can be any of the following (a short sketch follows this list):

    - An ID (sample or patch)
    - A query vector of same dimension as the index
    - A list of IDs (samples or patches)
    - A text prompt (search semantically)
||||
|
||||
## Prerequisites: install necessary dependencies
|
||||
|
||||
1. **Create and activate a virtual environment**
|
||||
|
||||
Install virtualenv package and run the following command in your project directory.
|
||||
```python
|
||||
python -m venv fiftyone_
|
||||
```
|
||||
From inside the project directory run the following to activate the virtual environment.
|
||||
=== "Windows"
|
||||
|
||||
```python
|
||||
fiftyone_/Scripts/activate
|
||||
```
|
||||
|
||||
=== "macOS/Linux"
|
||||
|
||||
```python
|
||||
source fiftyone_/Scripts/activate
|
||||
```
|
||||
|
||||
2. **Install the following packages in the virtual environment**
|
||||
|
||||
To install FiftyOne, ensure you have activated any virtual environment that you are using, then run
|
||||
```python
|
||||
pip install fiftyone
|
||||
```
|
||||
|
||||
|
||||
## Understand basic workflow
|
||||
|
||||
The basic workflow shown below uses LanceDB to create a similarity index on your FiftyOne datasets:
|
||||
|
||||
1. Load a dataset into FiftyOne.
|
||||
|
||||
@@ -19,14 +79,10 @@ datasets:
|
||||
|
||||
5. If desired, delete the table.
|
||||
|
||||
The example below demonstrates this workflow.
|
||||
## Quick Example
|
||||
|
||||
!!! Note
|
||||
Let's jump on a quick example that demonstrates this workflow.
|
||||
|
||||
Install the LanceDB Python client to run the code shown below.
|
||||
```
|
||||
pip install lancedb
|
||||
```
|
||||
|
||||
```python
|
||||
|
||||
@@ -36,7 +92,10 @@ import fiftyone.zoo as foz
|
||||
|
||||
# Step 1: Load your data into FiftyOne
|
||||
dataset = foz.load_zoo_dataset("quickstart")
|
||||
```
|
||||
Make sure you install torch ([guide here](https://pytorch.org/get-started/locally/)) before proceeding.
|
||||
|
||||
```python
|
||||
# Steps 2 and 3: Compute embeddings and create a similarity index
|
||||
lancedb_index = fob.compute_similarity(
|
||||
dataset,
|
||||
@@ -45,8 +104,11 @@ lancedb_index = fob.compute_similarity(
|
||||
backend="lancedb",
|
||||
)
|
||||
```
|
||||
Once the similarity index has been generated, we can query our data in FiftyOne
|
||||
by specifying the `brain_key`:
|
||||
|
||||
!!! note
|
||||
Running the code above will download the clip model (2.6Gb)
|
||||
|
||||
Once the similarity index has been generated, we can query our data in FiftyOne by specifying the `brain_key`:
|
||||
|
||||
```python
|
||||
# Step 4: Query your data
|
||||
@@ -56,7 +118,22 @@ view = dataset.sort_by_similarity(
|
||||
brain_key="lancedb_index",
|
||||
k=10, # limit to 10 most similar samples
|
||||
)
|
||||
```
|
||||
The returned results are of type `DatasetView`.
|
||||
|
||||
!!! note
|
||||
`DatasetView` does not hold its contents in-memory. Views simply store the rule(s) that are applied to extract the content of interest from the underlying Dataset when the view is iterated/aggregated on.
|
||||
|
||||
This means, for example, that the contents of a `DatasetView` may change as the underlying Dataset is modified.
|
||||
|
||||
??? question "Can you query a view instead of dataset?"
|
||||
Yes, you can also query a view.
|
||||
|
||||
Performing a similarity search on a `DatasetView` will only return results from the view; if the view contains samples that were not included in the index, they will never be included in the result.
|
||||
|
||||
This means that you can index an entire Dataset once and then perform searches on subsets of the dataset by constructing views that contain the images of interest.
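
A sketch of that pattern (an editor addition; `match_tags` is just one of FiftyOne's standard view stages and is used here as an assumption about how the subset is built):

```python
# Minimal sketch: search only a subset of the dataset by querying a view.
subset = dataset.match_tags("validation")  # any view expression works here
results = subset.sort_by_similarity(
    "night-time street scenes",
    brain_key="lancedb_index",
    k=10,
)
```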
|
||||
|
||||
```python
|
||||
# Step 5 (optional): Cleanup
|
||||
|
||||
# Delete the LanceDB table
|
||||
@@ -66,4 +143,90 @@ lancedb_index.cleanup()
|
||||
dataset.delete_brain_run("lancedb_index")
|
||||
```
|
||||
|
||||
|
||||
## Using LanceDB backend
|
||||
By default, calling `compute_similarity()` or `sort_by_similarity()` will use the sklearn backend.
|
||||
|
||||
To use the LanceDB backend, simply set the optional `backend` parameter of `compute_similarity()` to `"lancedb"`:
|
||||
|
||||
```python
|
||||
import fiftyone.brain as fob
|
||||
#... rest of the code
|
||||
fob.compute_similarity(..., backend="lancedb", ...)
|
||||
```
|
||||
|
||||
Alternatively, you can configure FiftyOne to use the LanceDB backend by setting the following environment variable.
|
||||
|
||||
In your terminal, set the environment variable using:
|
||||
=== "Windows"
|
||||
|
||||
```python
|
||||
$Env:FIFTYONE_BRAIN_DEFAULT_SIMILARITY_BACKEND="lancedb" //powershell
|
||||
|
||||
set FIFTYONE_BRAIN_DEFAULT_SIMILARITY_BACKEND=lancedb //cmd
|
||||
```
|
||||
|
||||
=== "macOS/Linux"
|
||||
|
||||
```python
|
||||
export FIFTYONE_BRAIN_DEFAULT_SIMILARITY_BACKEND=lancedb
|
||||
```
|
||||
|
||||
!!! note
|
||||
This will only apply for the current terminal session. Once the terminal is closed, the environment variable is deleted.
|
||||
|
||||
Alternatively, you can **permanently** configure FiftyOne to use the LanceDB backend by creating a `brain_config.json` at `~/.fiftyone/brain_config.json`. The JSON file may contain any desired subset of config fields that you wish to customize.
|
||||
|
||||
```json
|
||||
{
|
||||
"default_similarity_backend": "lancedb"
|
||||
}
|
||||
```
|
||||
This will override the default `brain_config` and will set it according to your customization. You can check the configuration by running the following code :
|
||||
|
||||
```python
|
||||
import fiftyone.brain as fob
|
||||
# Print your current brain config
|
||||
print(fob.brain_config)
|
||||
```
|
||||
|
||||
## LanceDB config parameters
|
||||
|
||||
The LanceDB backend supports query parameters that can be used to customize your similarity queries. These parameters include:
|
||||
|
||||
| Name| Purpose | Default |
|
||||
|:----|:--------|:--------|
|
||||
|**table_name**|The name of the LanceDB table to use. If none is provided, a new table will be created|`None`|
|
||||
|**metric**|The embedding distance metric to use when creating a new table. The supported values are ("cosine", "euclidean")|`"cosine"`|
|
||||
|**uri**| The database URI to use. In this Database URI, tables will be created. |`"/tmp/lancedb"`|
|
||||
|
||||
There are two ways to specify/customize the parameters:
|
||||
|
||||
1. **Using `brain_config.json` file**
|
||||
|
||||
```json
|
||||
{
|
||||
"similarity_backends": {
|
||||
"lancedb": {
|
||||
"table_name": "your-table",
|
||||
"metric": "euclidean",
|
||||
"uri": "/tmp/lancedb"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **Directly passing to `compute_similarity()` to configure a specific new index** :
|
||||
|
||||
```python
|
||||
lancedb_index = fob.compute_similarity(
|
||||
...
|
||||
backend="lancedb",
|
||||
brain_key="lancedb_index",
|
||||
table_name="your-table",
|
||||
metric="euclidean",
|
||||
uri="/tmp/lancedb",
|
||||
)
|
||||
```
|
||||
|
||||
For a much more in depth walkthrough of the integration, visit the LanceDB x Voxel51 [docs page](https://docs.voxel51.com/integrations/lancedb.html).
|
||||
|
||||
docs/src/rag/sfr_rag.md (new file, 17 lines)
@@ -0,0 +1,17 @@
**SFR RAG 📑**
====================================================================
Salesforce AI Research introduces SFR-RAG, a 9-billion-parameter language model trained with a significant emphasis on reliable, precise, and faithful contextual generation abilities specific to real-world RAG use cases and relevant agentic tasks. These include precise factual knowledge extraction, distinguishing relevant from distracting contexts, citing appropriate sources along with answers, producing complex and multi-hop reasoning over multiple contexts, consistent format following, as well as refraining from hallucination over unanswerable queries.

**[Official Implementation](https://github.com/SalesforceAIResearch/SFR-RAG)**

<figure markdown="span">

<figcaption>Average Scores in ContextualBench: <a href="https://blog.salesforceairesearch.com/sfr-rag/">Source</a>
</figcaption>
</figure>

To reliably evaluate LLMs in contextual question-answering for RAG, Salesforce introduced [ContextualBench](https://huggingface.co/datasets/Salesforce/ContextualBench?ref=blog.salesforceairesearch.com), featuring 7 benchmarks like [HotpotQA](https://arxiv.org/abs/1809.09600?ref=blog.salesforceairesearch.com) and [2WikiHopQA](https://www.aclweb.org/anthology/2020.coling-main.580/?ref=blog.salesforceairesearch.com) with consistent setups.

SFR-RAG outperforms GPT-4o, achieving state-of-the-art results in 3 out of 7 benchmarks, and significantly surpasses Command-R+ while using 10 times fewer parameters. It also excels at handling context, even when facts are altered or conflicting.

[Salesforce AI Research Blog](https://blog.salesforceairesearch.com/sfr-rag/)
@@ -1,6 +1,9 @@
# Linear Combination Reranker

-This is the default re-ranker used by LanceDB hybrid search. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.
+!!! note
+    This is deprecated. It is recommended to use the `RRFReranker` instead, if you want to use a score-based reranker.
+
+It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.
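
As an illustration of setting the weight explicitly (an editor addition, not part of the diff; `LinearCombinationReranker` and the `weight` argument follow LanceDB's reranker API, but treat the exact signature as an assumption):

```python
# Minimal sketch: hybrid search reranked by a linear combination of scores.
import lancedb
from lancedb.rerankers import LinearCombinationReranker

db = lancedb.connect("/tmp/lancedb")
table = db.open_table("my_table")  # hypothetical table with an FTS index

reranker = LinearCombinationReranker(weight=0.7)  # 0.7 vector, 0.3 full-text
results = (
    table.search("vintage keyboards", query_type="hybrid")
    .rerank(reranker=reranker)
    .limit(10)
    .to_pandas()
)
```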

!!! note
    Supported Query Types: Hybrid
@@ -1,6 +1,6 @@
# Reciprocal Rank Fusion Reranker

-Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
+This is the default re-ranker used by LanceDB hybrid search. Reciprocal Rank Fusion (RRF) is an algorithm that evaluates the search scores by leveraging the positions/rank of the documents. The implementation follows this [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf).
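
For readers new to RRF, a brief sketch of the scoring rule from the cited paper (an editor addition; `k = 60` is the paper's default constant, and LanceDB's exact parameter naming is not shown in this diff):

```python
# RRF fuses ranked lists by summing reciprocal ranks: a document ranked
# highly in either the vector or the full-text list gets a high fused score.
def rrf_score(ranks: list[int], k: int = 60) -> float:
    # ranks: 1-based rank of the same document in each result list
    return sum(1.0 / (k + r) for r in ranks)

# e.g. a document ranked 1st in vector search and 3rd in full-text search:
print(rrf_score([1, 3]))  # ~0.0323
```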

!!! note
@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
-<version>0.10.0</version>
+<version>0.11.0-beta.1</version>
<relativePath>../pom.xml</relativePath>
</parent>

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
-<version>0.10.0</version>
+<version>0.11.0-beta.1</version>
<packaging>pom</packaging>

<name>LanceDB Parent</name>
node/package-lock.json (generated, 1440 lines changed) — diff suppressed because it is too large.
@@ -1,6 +1,6 @@
{
"name": "vectordb",
-"version": "0.10.0",
+"version": "0.11.0-beta.1",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
@@ -58,7 +58,7 @@
"ts-node-dev": "^2.0.0",
"typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3",
-"typescript": "*",
+"typescript": "^5.1.0",
"uuid": "^9.0.0"
},
"dependencies": {
@@ -88,10 +88,10 @@
}
},
"optionalDependencies": {
-"@lancedb/vectordb-darwin-arm64": "0.4.20",
-"@lancedb/vectordb-darwin-x64": "0.4.20",
-"@lancedb/vectordb-linux-arm64-gnu": "0.4.20",
-"@lancedb/vectordb-linux-x64-gnu": "0.4.20",
-"@lancedb/vectordb-win32-x64-msvc": "0.4.20"
+"@lancedb/vectordb-darwin-arm64": "0.11.0-beta.1",
+"@lancedb/vectordb-darwin-x64": "0.11.0-beta.1",
+"@lancedb/vectordb-linux-arm64-gnu": "0.11.0-beta.1",
+"@lancedb/vectordb-linux-x64-gnu": "0.11.0-beta.1",
+"@lancedb/vectordb-win32-x64-msvc": "0.11.0-beta.1"
}
}
@@ -220,7 +220,8 @@ export async function connect(
region: partOpts.region ?? defaultRegion,
timeout: partOpts.timeout ?? defaultRequestTimeout,
readConsistencyInterval: partOpts.readConsistencyInterval ?? undefined,
-storageOptions: partOpts.storageOptions ?? undefined
+storageOptions: partOpts.storageOptions ?? undefined,
+hostOverride: partOpts.hostOverride ?? undefined
}
if (opts.uri.startsWith("db://")) {
// Remote connection
@@ -723,9 +724,9 @@ export interface VectorIndex {
export interface IndexStats {
numIndexedRows: number | null
numUnindexedRows: number | null
-indexType: string | null
-distanceType: string | null
-completedAt: string | null
+indexType: string
+distanceType?: string
+numIndices?: number
}

/**
@@ -14,6 +14,7 @@

import { describe } from 'mocha'
import * as chai from 'chai'
+import { assert } from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import { v4 as uuidv4 } from 'uuid'

@@ -22,7 +23,6 @@ import { tmpdir } from 'os'
import * as fs from 'fs'
import * as path from 'path'

-const assert = chai.assert
chai.use(chaiAsPromised)

describe('LanceDB AWS Integration test', function () {
@@ -33,6 +33,7 @@ export class Query<T = number[]> {
private _filter?: string
private _metricType?: MetricType
private _prefilter: boolean
+private _fastSearch: boolean
protected readonly _embeddings?: EmbeddingFunction<T>

constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
@@ -46,6 +47,7 @@ export class Query<T = number[]> {
this._metricType = undefined
this._embeddings = embeddings
this._prefilter = false
+this._fastSearch = false
}

/***
@@ -110,6 +112,15 @@ export class Query<T = number[]> {
return this
}

+/**
+ * Skip searching un-indexed data. This can make search faster, but will miss
+ * any data that is not yet indexed.
+ */
+fastSearch (value: boolean): Query<T> {
+  this._fastSearch = value
+  return this
+}

/**
 * Execute the query and return the results as an Array of Objects
 */
@@ -131,9 +142,9 @@ export class Query<T = number[]> {
Object.keys(entry).forEach((key: string) => {
if (entry[key] instanceof Vector) {
// toJSON() returns f16 array correctly
-newObject[key] = (entry[key] as Vector).toJSON()
+newObject[key] = (entry[key] as any).toJSON()
} else {
-newObject[key] = entry[key]
+newObject[key] = entry[key] as any
}
})
return newObject as unknown as T
@@ -17,6 +17,7 @@ import axios, { type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'

import { type RemoteResponse, type RemoteRequest, Method } from '../middleware'
+import type { MetricType } from '..'

interface HttpLancedbClientMiddleware {
onRemoteRequest(
@@ -151,7 +152,9 @@ export class HttpLancedbClient {
prefilter: boolean,
refineFactor?: number,
columns?: string[],
-filter?: string
+filter?: string,
+metricType?: MetricType,
+fastSearch?: boolean
): Promise<ArrowTable<any>> {
const result = await this.post(
`/v1/table/${tableName}/query/`,
@@ -159,10 +162,12 @@ export class HttpLancedbClient {
vector,
k,
nprobes,
-refineFactor,
+refine_factor: refineFactor,
columns,
filter,
-prefilter
+prefilter,
+metric: metricType,
+fast_search: fastSearch
},
undefined,
undefined,
@@ -238,16 +238,18 @@ export class RemoteQuery<T = number[]> extends Query<T> {
(this as any)._prefilter,
(this as any)._refineFactor,
(this as any)._select,
-(this as any)._filter
+(this as any)._filter,
+(this as any)._metricType,
+(this as any)._fastSearch
)

return data.toArray().map((entry: Record<string, unknown>) => {
const newObject: Record<string, unknown> = {}
Object.keys(entry).forEach((key: string) => {
if (entry[key] instanceof Vector) {
-newObject[key] = (entry[key] as Vector).toArray()
+newObject[key] = (entry[key] as any).toArray()
} else {
-newObject[key] = entry[key]
+newObject[key] = entry[key] as any
}
})
return newObject as unknown as T
@@ -524,8 +526,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
numIndexedRows: body?.num_indexed_rows,
numUnindexedRows: body?.num_unindexed_rows,
indexType: body?.index_type,
-distanceType: body?.distance_type,
-completedAt: body?.completed_at
+distanceType: body?.distance_type
}
}
@@ -14,6 +14,7 @@

import { describe } from "mocha";
import { track } from "temp";
+import { assert, expect } from 'chai'
import * as chai from "chai";
import * as chaiAsPromised from "chai-as-promised";

@@ -44,8 +45,6 @@ import {
} from "apache-arrow";
import type { RemoteRequest, RemoteResponse } from "../middleware";

-const expect = chai.expect;
-const assert = chai.assert;
chai.use(chaiAsPromised);

describe("LanceDB client", function () {
@@ -169,7 +168,7 @@ describe("LanceDB client", function () {

// Should reject a bad filter
await expect(table.filter("id % 2 = 0 AND").execute()).to.be.rejectedWith(
-/.*sql parser error: Expected an expression:, found: EOF.*/
+/.*sql parser error: .*/
);
});

@@ -888,9 +887,12 @@ describe("LanceDB client", function () {
expect(indices[0].columns).to.have.lengthOf(1);
expect(indices[0].columns[0]).to.equal("vector");

-const stats = await table.indexStats(indices[0].uuid);
+const stats = await table.indexStats(indices[0].name);
expect(stats.numIndexedRows).to.equal(300);
expect(stats.numUnindexedRows).to.equal(0);
+expect(stats.indexType).to.equal("IVF_PQ");
+expect(stats.distanceType).to.equal("l2");
+expect(stats.numIndices).to.equal(1);
}).timeout(50_000);
});
@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
-version = "0.0.0"
+version = "0.11.0-beta.1"
license.workspace = true
description.workspace = true
repository.workspace = true
@@ -14,7 +14,7 @@ crate-type = ["cdylib"]
[dependencies]
arrow-ipc.workspace = true
futures.workspace = true
-lancedb = { path = "../rust/lancedb" }
+lancedb = { path = "../rust/lancedb", features = ["remote"] }
napi = { version = "2.16.8", default-features = false, features = [
"napi9",
"async",
93
nodejs/__test__/remote.test.ts
Normal file
93
nodejs/__test__/remote.test.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import * as http from "http";
|
||||
import { RequestListener } from "http";
|
||||
import { Connection, ConnectionOptions, connect } from "../lancedb";
|
||||
|
||||
async function withMockDatabase(
|
||||
listener: RequestListener,
|
||||
callback: (db: Connection) => void,
|
||||
connectionOptions?: ConnectionOptions,
|
||||
) {
|
||||
const server = http.createServer(listener);
|
||||
server.listen(8000);
|
||||
|
||||
const db = await connect(
|
||||
"db://dev",
|
||||
Object.assign(
|
||||
{
|
||||
apiKey: "fake",
|
||||
hostOverride: "http://localhost:8000",
|
||||
},
|
||||
connectionOptions,
|
||||
),
|
||||
);
|
||||
|
||||
try {
|
||||
await callback(db);
|
||||
} finally {
|
||||
server.close();
|
||||
}
|
||||
}
|
||||
|
||||
describe("remote connection", () => {
|
||||
it("should accept partial connection options", async () => {
|
||||
await connect("db://test", {
|
||||
apiKey: "fake",
|
||||
clientConfig: {
|
||||
timeoutConfig: { readTimeout: 5 },
|
||||
retryConfig: { retries: 2 },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should pass down apiKey and userAgent", async () => {
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
expect(req.headers["x-api-key"]).toEqual("fake");
|
||||
expect(req.headers["user-agent"]).toEqual(
|
||||
`LanceDB-Node-Client/${process.env.npm_package_version}`,
|
||||
);
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async (db) => {
|
||||
const tableNames = await db.tableNames();
|
||||
expect(tableNames).toEqual([]);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("allows customizing user agent", async () => {
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
expect(req.headers["user-agent"]).toEqual("MyApp/1.0");
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async (db) => {
|
||||
const tableNames = await db.tableNames();
|
||||
expect(tableNames).toEqual([]);
|
||||
},
|
||||
{
|
||||
clientConfig: {
|
||||
userAgent: "MyApp/1.0",
|
||||
},
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -479,6 +479,9 @@ describe("When creating an index", () => {
|
||||
expect(stats).toBeDefined();
|
||||
expect(stats?.numIndexedRows).toEqual(300);
|
||||
expect(stats?.numUnindexedRows).toEqual(0);
|
||||
expect(stats?.distanceType).toBeUndefined();
|
||||
expect(stats?.indexType).toEqual("BTREE");
|
||||
expect(stats?.numIndices).toEqual(1);
|
||||
});
|
||||
|
||||
test("when getting stats on non-existent index", async () => {
|
||||
|
||||
@@ -23,8 +23,6 @@ import {
|
||||
Connection as LanceDbConnection,
|
||||
} from "./native.js";
|
||||
|
||||
import { RemoteConnection, RemoteConnectionOptions } from "./remote";
|
||||
|
||||
export {
|
||||
WriteOptions,
|
||||
WriteMode,
|
||||
@@ -32,8 +30,10 @@ export {
|
||||
ColumnAlteration,
|
||||
ConnectionOptions,
|
||||
IndexStatistics,
|
||||
IndexMetadata,
|
||||
IndexConfig,
|
||||
ClientConfig,
|
||||
TimeoutConfig,
|
||||
RetryConfig,
|
||||
} from "./native.js";
|
||||
|
||||
export {
|
||||
@@ -88,7 +88,7 @@ export * as embedding from "./embedding";
|
||||
*/
|
||||
export async function connect(
|
||||
uri: string,
|
||||
opts?: Partial<ConnectionOptions | RemoteConnectionOptions>,
|
||||
opts?: Partial<ConnectionOptions>,
|
||||
): Promise<Connection>;
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI.
|
||||
@@ -109,13 +109,11 @@ export async function connect(
|
||||
* ```
|
||||
*/
|
||||
export async function connect(
|
||||
opts: Partial<RemoteConnectionOptions | ConnectionOptions> & { uri: string },
|
||||
opts: Partial<ConnectionOptions> & { uri: string },
|
||||
): Promise<Connection>;
|
||||
export async function connect(
|
||||
uriOrOptions:
|
||||
| string
|
||||
| (Partial<RemoteConnectionOptions | ConnectionOptions> & { uri: string }),
|
||||
opts: Partial<ConnectionOptions | RemoteConnectionOptions> = {},
|
||||
uriOrOptions: string | (Partial<ConnectionOptions> & { uri: string }),
|
||||
opts: Partial<ConnectionOptions> = {},
|
||||
): Promise<Connection> {
|
||||
let uri: string | undefined;
|
||||
if (typeof uriOrOptions !== "string") {
|
||||
@@ -130,9 +128,6 @@ export async function connect(
|
||||
throw new Error("uri is required");
|
||||
}
|
||||
|
||||
if (uri?.startsWith("db://")) {
|
||||
return new RemoteConnection(uri, opts as RemoteConnectionOptions);
|
||||
}
|
||||
opts = (opts as ConnectionOptions) ?? {};
|
||||
(<ConnectionOptions>opts).storageOptions = cleanseStorageOptions(
|
||||
(<ConnectionOptions>opts).storageOptions,
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import axios, {
|
||||
AxiosError,
|
||||
type AxiosResponse,
|
||||
type ResponseType,
|
||||
} from "axios";
|
||||
import { Table as ArrowTable } from "../arrow";
|
||||
import { tableFromIPC } from "../arrow";
|
||||
import { VectorQuery } from "../query";
|
||||
|
||||
export class RestfulLanceDBClient {
|
||||
#dbName: string;
|
||||
#region: string;
|
||||
#apiKey: string;
|
||||
#hostOverride?: string;
|
||||
#closed: boolean = false;
|
||||
#timeout: number = 12 * 1000; // 12 seconds;
|
||||
#session?: import("axios").AxiosInstance;
|
||||
|
||||
constructor(
|
||||
dbName: string,
|
||||
apiKey: string,
|
||||
region: string,
|
||||
hostOverride?: string,
|
||||
timeout?: number,
|
||||
) {
|
||||
this.#dbName = dbName;
|
||||
this.#apiKey = apiKey;
|
||||
this.#region = region;
|
||||
this.#hostOverride = hostOverride ?? this.#hostOverride;
|
||||
this.#timeout = timeout ?? this.#timeout;
|
||||
}
|
||||
|
||||
// todo: cache the session.
|
||||
get session(): import("axios").AxiosInstance {
|
||||
if (this.#session !== undefined) {
|
||||
return this.#session;
|
||||
} else {
|
||||
return axios.create({
|
||||
baseURL: this.url,
|
||||
headers: {
|
||||
// biome-ignore lint: external API
|
||||
Authorization: `Bearer ${this.#apiKey}`,
|
||||
},
|
||||
transformResponse: decodeErrorData,
|
||||
timeout: this.#timeout,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
get url(): string {
|
||||
return (
|
||||
this.#hostOverride ??
|
||||
`https://${this.#dbName}.${this.#region}.api.lancedb.com`
|
||||
);
|
||||
}
|
||||
|
||||
get headers(): { [key: string]: string } {
|
||||
const headers: { [key: string]: string } = {
|
||||
"x-api-key": this.#apiKey,
|
||||
"x-request-id": "na",
|
||||
};
|
||||
if (this.#region == "local") {
|
||||
headers["Host"] = `${this.#dbName}.${this.#region}.api.lancedb.com`;
|
||||
}
|
||||
if (this.#hostOverride) {
|
||||
headers["x-lancedb-database"] = this.#dbName;
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
isOpen(): boolean {
|
||||
return !this.#closed;
|
||||
}
|
||||
|
||||
private checkNotClosed(): void {
|
||||
if (this.#closed) {
|
||||
throw new Error("Connection is closed");
|
||||
}
|
||||
}
|
||||
|
||||
close(): void {
|
||||
this.#session = undefined;
|
||||
this.#closed = true;
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
async get(uri: string, params?: Record<string, any>): Promise<any> {
|
||||
this.checkNotClosed();
|
||||
uri = new URL(uri, this.url).toString();
|
||||
let response;
|
||||
try {
|
||||
response = await this.session.get(uri, {
|
||||
headers: this.headers,
|
||||
params,
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof AxiosError && e.response) {
|
||||
response = e.response;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
RestfulLanceDBClient.checkStatus(response!);
|
||||
return response!.data;
|
||||
}
|
||||
|
||||
// biome-ignore lint/suspicious/noExplicitAny: api response
|
||||
async post(uri: string, body?: any): Promise<any>;
|
||||
async post(
|
||||
uri: string,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: api request
|
||||
body: any,
|
||||
additional: {
|
||||
config?: { responseType: "arraybuffer" };
|
||||
headers?: Record<string, string>;
|
||||
params?: Record<string, string>;
|
||||
},
|
||||
): Promise<Buffer>;
|
||||
async post(
|
||||
uri: string,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: api request
|
||||
body?: any,
|
||||
additional?: {
|
||||
config?: { responseType: ResponseType };
|
||||
headers?: Record<string, string>;
|
||||
params?: Record<string, string>;
|
||||
},
|
||||
// biome-ignore lint/suspicious/noExplicitAny: api response
|
||||
): Promise<any> {
|
||||
this.checkNotClosed();
|
||||
uri = new URL(uri, this.url).toString();
|
||||
additional = Object.assign(
|
||||
{ config: { responseType: "json" } },
|
||||
additional,
|
||||
);
|
||||
|
||||
const headers = { ...this.headers, ...additional.headers };
|
||||
|
||||
if (!headers["Content-Type"]) {
|
||||
headers["Content-Type"] = "application/json";
|
||||
}
|
||||
let response;
|
||||
try {
|
||||
response = await this.session.post(uri, body, {
|
||||
headers,
|
||||
responseType: additional!.config!.responseType,
|
||||
params: new Map(Object.entries(additional.params ?? {})),
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof AxiosError && e.response) {
|
||||
response = e.response;
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
RestfulLanceDBClient.checkStatus(response!);
|
||||
if (additional!.config!.responseType === "arraybuffer") {
|
||||
return response!.data;
|
||||
} else {
|
||||
return JSON.parse(response!.data);
|
||||
}
|
||||
}
|
||||
|
||||
async listTables(limit = 10, pageToken = ""): Promise<string[]> {
|
||||
const json = await this.get("/v1/table", { limit, pageToken });
|
||||
return json.tables;
|
||||
}
|
||||
|
||||
async query(tableName: string, query: VectorQuery): Promise<ArrowTable> {
|
||||
const tbl = await this.post(`/v1/table/${tableName}/query`, query, {
|
||||
config: {
|
||||
responseType: "arraybuffer",
|
||||
},
|
||||
});
|
||||
return tableFromIPC(tbl);
|
||||
}
|
||||
|
||||
static checkStatus(response: AxiosResponse): void {
|
||||
if (response.status === 404) {
|
||||
throw new Error(`Not found: ${response.data}`);
|
||||
} else if (response.status >= 400 && response.status < 500) {
|
||||
throw new Error(
|
||||
`Bad Request: ${response.status}, error: ${response.data}`,
|
||||
);
|
||||
} else if (response.status >= 500 && response.status < 600) {
|
||||
throw new Error(
|
||||
`Internal Server Error: ${response.status}, error: ${response.data}`,
|
||||
);
|
||||
} else if (response.status !== 200) {
|
||||
throw new Error(
|
||||
`Unknown Error: ${response.status}, error: ${response.data}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function decodeErrorData(data: unknown) {
|
||||
if (Buffer.isBuffer(data)) {
|
||||
const decoded = data.toString("utf-8");
|
||||
return decoded;
|
||||
}
|
||||
return data;
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
import { Schema } from "apache-arrow";
|
||||
import {
|
||||
Data,
|
||||
SchemaLike,
|
||||
fromTableToStreamBuffer,
|
||||
makeEmptyTable,
|
||||
} from "../arrow";
|
||||
import {
|
||||
Connection,
|
||||
CreateTableOptions,
|
||||
OpenTableOptions,
|
||||
TableNamesOptions,
|
||||
} from "../connection";
|
||||
import { Table } from "../table";
|
||||
import { TTLCache } from "../util";
|
||||
import { RestfulLanceDBClient } from "./client";
|
||||
import { RemoteTable } from "./table";
|
||||
|
||||
export interface RemoteConnectionOptions {
|
||||
apiKey?: string;
|
||||
region?: string;
|
||||
hostOverride?: string;
|
||||
timeout?: number;
|
||||
}
|
||||
|
||||
export class RemoteConnection extends Connection {
|
||||
#dbName: string;
|
||||
#apiKey: string;
|
||||
#region: string;
|
||||
#client: RestfulLanceDBClient;
|
||||
#tableCache = new TTLCache(300_000);
|
||||
|
||||
constructor(
|
||||
url: string,
|
||||
{ apiKey, region, hostOverride, timeout }: RemoteConnectionOptions,
|
||||
) {
|
||||
super();
|
||||
apiKey = apiKey ?? process.env.LANCEDB_API_KEY;
|
||||
region = region ?? process.env.LANCEDB_REGION;
|
||||
|
||||
if (!apiKey) {
|
||||
throw new Error("apiKey is required when connecting to LanceDB Cloud");
|
||||
}
|
||||
|
||||
if (!region) {
|
||||
throw new Error("region is required when connecting to LanceDB Cloud");
|
||||
}
|
||||
|
||||
const parsed = new URL(url);
|
||||
if (parsed.protocol !== "db:") {
|
||||
throw new Error(
|
||||
`invalid protocol: ${parsed.protocol}, only accepts db://`,
|
||||
);
|
||||
}
|
||||
|
||||
this.#dbName = parsed.hostname;
|
||||
this.#apiKey = apiKey;
|
||||
this.#region = region;
|
||||
this.#client = new RestfulLanceDBClient(
|
||||
this.#dbName,
|
||||
this.#apiKey,
|
||||
this.#region,
|
||||
hostOverride,
|
||||
timeout,
|
||||
);
|
||||
}
|
||||
|
||||
isOpen(): boolean {
|
||||
return this.#client.isOpen();
|
||||
}
|
||||
close(): void {
|
||||
return this.#client.close();
|
||||
}
|
||||
|
||||
display(): string {
|
||||
return `RemoteConnection(${this.#dbName})`;
|
||||
}
|
||||
|
||||
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
|
||||
const response = await this.#client.get("/v1/table/", {
|
||||
limit: options?.limit ?? 10,
|
||||
// biome-ignore lint/style/useNamingConvention: <explanation>
|
||||
page_token: options?.startAfter ?? "",
|
||||
});
|
||||
const body = await response.body();
|
||||
for (const table of body.tables) {
|
||||
this.#tableCache.set(table, true);
|
||||
}
|
||||
return body.tables;
|
||||
}
|
||||
|
||||
async openTable(
|
||||
name: string,
|
||||
_options?: Partial<OpenTableOptions> | undefined,
|
||||
): Promise<Table> {
|
||||
if (this.#tableCache.get(name) === undefined) {
|
||||
await this.#client.post(
|
||||
`/v1/table/${encodeURIComponent(name)}/describe/`,
|
||||
);
|
||||
this.#tableCache.set(name, true);
|
||||
}
|
||||
return new RemoteTable(this.#client, name, this.#dbName);
|
||||
}
|
||||
|
||||
async createTable(
|
||||
nameOrOptions:
|
||||
| string
|
||||
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
|
||||
data?: Data,
|
||||
options?: Partial<CreateTableOptions> | undefined,
|
||||
): Promise<Table> {
|
||||
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
|
||||
const { name, data, ...options } = nameOrOptions;
|
||||
return this.createTable(name, data, options);
|
||||
}
|
||||
if (data === undefined) {
|
||||
throw new Error("data is required");
|
||||
}
|
||||
if (options?.mode) {
|
||||
console.warn(
|
||||
"option 'mode' is not supported in LanceDB Cloud",
|
||||
"LanceDB Cloud only supports the default 'create' mode.",
|
||||
"If the table already exists, an error will be thrown.",
|
||||
);
|
||||
}
|
||||
if (options?.embeddingFunction) {
|
||||
console.warn(
|
||||
"embedding_functions is not yet supported on LanceDB Cloud.",
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 ",
|
||||
"for this feature.",
|
||||
);
|
||||
}
|
||||
|
||||
const { buf } = await Table.parseTableData(
|
||||
data,
|
||||
options,
|
||||
true /** streaming */,
|
||||
);
|
||||
|
||||
await this.#client.post(
|
||||
`/v1/table/${encodeURIComponent(nameOrOptions)}/create/`,
|
||||
buf,
|
||||
{
|
||||
config: {
|
||||
responseType: "arraybuffer",
|
||||
},
|
||||
headers: { "Content-Type": "application/vnd.apache.arrow.stream" },
|
||||
},
|
||||
);
|
||||
this.#tableCache.set(nameOrOptions, true);
|
||||
return new RemoteTable(this.#client, nameOrOptions, this.#dbName);
|
||||
}
|
||||
|
||||
async createEmptyTable(
|
||||
name: string,
|
||||
schema: SchemaLike,
|
||||
options?: Partial<CreateTableOptions> | undefined,
|
||||
): Promise<Table> {
|
||||
if (options?.mode) {
|
||||
console.warn(`mode is not supported on LanceDB Cloud`);
|
||||
}
|
||||
|
||||
if (options?.embeddingFunction) {
|
||||
console.warn(
|
||||
"embeddingFunction is not yet supported on LanceDB Cloud.",
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 ",
|
||||
"for this feature.",
|
||||
);
|
||||
}
|
||||
const emptyTable = makeEmptyTable(schema);
|
||||
const buf = await fromTableToStreamBuffer(emptyTable);
|
||||
|
||||
await this.#client.post(
|
||||
`/v1/table/${encodeURIComponent(name)}/create/`,
|
||||
buf,
|
||||
{
|
||||
config: {
|
||||
responseType: "arraybuffer",
|
||||
},
|
||||
headers: { "Content-Type": "application/vnd.apache.arrow.stream" },
|
||||
},
|
||||
);
|
||||
|
||||
this.#tableCache.set(name, true);
|
||||
return new RemoteTable(this.#client, name, this.#dbName);
|
||||
}
|
||||
|
||||
async dropTable(name: string): Promise<void> {
|
||||
await this.#client.post(`/v1/table/${encodeURIComponent(name)}/drop/`);
|
||||
|
||||
this.#tableCache.delete(name);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
export { RestfulLanceDBClient } from "./client";
|
||||
export { type RemoteConnectionOptions, RemoteConnection } from "./connection";
|
||||
export { RemoteTable } from "./table";
|
||||
@@ -1,226 +0,0 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { Table as ArrowTable } from "apache-arrow";
|
||||
|
||||
import { Data, IntoVector } from "../arrow";
|
||||
|
||||
import { IndexStatistics } from "..";
|
||||
import { CreateTableOptions } from "../connection";
|
||||
import { IndexOptions } from "../indices";
|
||||
import { MergeInsertBuilder } from "../merge";
|
||||
import { VectorQuery } from "../query";
|
||||
import { AddDataOptions, Table, UpdateOptions } from "../table";
|
||||
import { IntoSql, toSQL } from "../util";
|
||||
import { RestfulLanceDBClient } from "./client";
|
||||
|
||||
export class RemoteTable extends Table {
|
||||
#client: RestfulLanceDBClient;
|
||||
#name: string;
|
||||
|
||||
// Used in the display() method
|
||||
#dbName: string;
|
||||
|
||||
get #tablePrefix() {
|
||||
return `/v1/table/${encodeURIComponent(this.#name)}/`;
|
||||
}
|
||||
|
||||
get name(): string {
|
||||
return this.#name;
|
||||
}
|
||||
|
||||
public constructor(
|
||||
client: RestfulLanceDBClient,
|
||||
tableName: string,
|
||||
dbName: string,
|
||||
) {
|
||||
super();
|
||||
this.#client = client;
|
||||
this.#name = tableName;
|
||||
this.#dbName = dbName;
|
||||
}
|
||||
|
||||
isOpen(): boolean {
|
||||
return !this.#client.isOpen();
|
||||
}
|
||||
|
||||
close(): void {
|
||||
this.#client.close();
|
||||
}
|
||||
|
||||
display(): string {
|
||||
return `RemoteTable(${this.#dbName}; ${this.#name})`;
|
||||
}
|
||||
|
||||
async schema(): Promise<import("apache-arrow").Schema> {
|
||||
const resp = await this.#client.post(`${this.#tablePrefix}/describe/`);
|
||||
// TODO: parse this into a valid arrow schema
|
||||
return resp.schema;
|
||||
}
|
||||
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
|
||||
const { buf, mode } = await Table.parseTableData(
|
||||
data,
|
||||
options as CreateTableOptions,
|
||||
true,
|
||||
);
|
||||
await this.#client.post(`${this.#tablePrefix}/insert/`, buf, {
|
||||
params: {
|
||||
mode,
|
||||
},
|
||||
headers: {
|
||||
"Content-Type": "application/vnd.apache.arrow.stream",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async update(
|
||||
optsOrUpdates:
|
||||
| (Map<string, string> | Record<string, string>)
|
||||
| ({
|
||||
values: Map<string, IntoSql> | Record<string, IntoSql>;
|
||||
} & Partial<UpdateOptions>)
|
||||
| ({
|
||||
valuesSql: Map<string, string> | Record<string, string>;
|
||||
} & Partial<UpdateOptions>),
|
||||
options?: Partial<UpdateOptions>,
|
||||
): Promise<void> {
|
||||
const isValues =
|
||||
"values" in optsOrUpdates && typeof optsOrUpdates.values !== "string";
|
||||
const isValuesSql =
|
||||
"valuesSql" in optsOrUpdates &&
|
||||
typeof optsOrUpdates.valuesSql !== "string";
|
||||
const isMap = (obj: unknown): obj is Map<string, string> => {
|
||||
return obj instanceof Map;
|
||||
};
|
||||
|
||||
let predicate;
|
||||
let columns: [string, string][];
|
||||
switch (true) {
|
||||
case isMap(optsOrUpdates):
|
||||
columns = Array.from(optsOrUpdates.entries());
|
||||
predicate = options?.where;
|
||||
break;
|
||||
case isValues && isMap(optsOrUpdates.values):
|
||||
columns = Array.from(optsOrUpdates.values.entries()).map(([k, v]) => [
|
||||
k,
|
||||
toSQL(v),
|
||||
]);
|
||||
predicate = optsOrUpdates.where;
|
||||
break;
|
||||
case isValues && !isMap(optsOrUpdates.values):
|
||||
columns = Object.entries(optsOrUpdates.values).map(([k, v]) => [
|
||||
k,
|
||||
toSQL(v),
|
||||
]);
|
||||
predicate = optsOrUpdates.where;
|
||||
break;
|
||||
|
||||
case isValuesSql && isMap(optsOrUpdates.valuesSql):
|
||||
columns = Array.from(optsOrUpdates.valuesSql.entries());
|
||||
predicate = optsOrUpdates.where;
|
||||
break;
|
||||
case isValuesSql && !isMap(optsOrUpdates.valuesSql):
|
||||
columns = Object.entries(optsOrUpdates.valuesSql).map(([k, v]) => [
|
||||
k,
|
||||
v,
|
||||
]);
|
||||
predicate = optsOrUpdates.where;
|
||||
break;
|
||||
default:
|
||||
columns = Object.entries(optsOrUpdates as Record<string, string>);
|
||||
predicate = options?.where;
|
||||
}
|
||||
|
||||
await this.#client.post(`${this.#tablePrefix}/update/`, {
|
||||
predicate: predicate ?? null,
|
||||
updates: columns,
|
||||
});
|
||||
}
|
||||
async countRows(filter?: unknown): Promise<number> {
|
||||
const payload = { predicate: filter };
|
||||
return await this.#client.post(`${this.#tablePrefix}/count_rows/`, payload);
|
||||
}
|
||||
|
||||
async delete(predicate: unknown): Promise<void> {
|
||||
const payload = { predicate };
|
||||
await this.#client.post(`${this.#tablePrefix}/delete/`, payload);
|
||||
}
|
||||
async createIndex(
|
||||
column: string,
|
||||
options?: Partial<IndexOptions>,
|
||||
): Promise<void> {
|
||||
if (options !== undefined) {
|
||||
console.warn("options are not yet supported on the LanceDB cloud");
|
||||
}
|
||||
const indexType = "vector";
|
||||
const metric = "L2";
|
||||
const data = {
|
||||
column,
|
||||
// biome-ignore lint/style/useNamingConvention: external API
|
||||
index_type: indexType,
|
||||
// biome-ignore lint/style/useNamingConvention: external API
|
||||
metric_type: metric,
|
||||
};
|
||||
await this.#client.post(`${this.#tablePrefix}/create_index`, data);
|
||||
}
|
||||
query(): import("..").Query {
|
||||
throw new Error("query() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
|
||||
search(_query: string | IntoVector): VectorQuery {
|
||||
throw new Error("search() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
vectorSearch(_vector: unknown): import("..").VectorQuery {
|
||||
throw new Error("vectorSearch() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
addColumns(_newColumnTransforms: unknown): Promise<void> {
|
||||
throw new Error("addColumns() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
alterColumns(_columnAlterations: unknown): Promise<void> {
|
||||
throw new Error("alterColumns() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
dropColumns(_columnNames: unknown): Promise<void> {
|
||||
throw new Error("dropColumns() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
async version(): Promise<number> {
|
||||
const resp = await this.#client.post(`${this.#tablePrefix}/describe/`);
|
||||
return resp.version;
|
||||
}
|
||||
checkout(_version: unknown): Promise<void> {
|
||||
throw new Error("checkout() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
checkoutLatest(): Promise<void> {
|
||||
throw new Error(
|
||||
"checkoutLatest() is not yet supported on the LanceDB cloud",
|
||||
);
|
||||
}
|
||||
restore(): Promise<void> {
|
||||
throw new Error("restore() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
optimize(_options?: unknown): Promise<import("../native").OptimizeStats> {
|
||||
throw new Error("optimize() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
async listIndices(): Promise<import("../native").IndexConfig[]> {
|
||||
return await this.#client.post(`${this.#tablePrefix}/index/list/`);
|
||||
}
|
||||
toArrow(): Promise<ArrowTable> {
|
||||
throw new Error("toArrow() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
mergeInsert(_on: string | string[]): MergeInsertBuilder {
|
||||
throw new Error("mergeInsert() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
async indexStats(_name: string): Promise<IndexStatistics | undefined> {
|
||||
throw new Error("indexStats() is not yet supported on the LanceDB cloud");
|
||||
}
|
||||
}
|
||||
208
nodejs/native.d.ts
vendored
208
nodejs/native.d.ts
vendored
@@ -1,208 +0,0 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/* auto-generated by NAPI-RS */
|
||||
|
||||
/** A description of an index currently configured on a column */
|
||||
export interface IndexConfig {
|
||||
/** The name of the index */
|
||||
name: string
|
||||
/** The type of the index */
|
||||
indexType: string
|
||||
/**
|
||||
* The columns in the index
|
||||
*
|
||||
* Currently this is always an array of size 1. In the future there may
|
||||
* be more columns to represent composite indices.
|
||||
*/
|
||||
columns: Array<string>
|
||||
}
|
||||
/** Statistics about a compaction operation. */
|
||||
export interface CompactionStats {
|
||||
/** The number of fragments removed */
|
||||
fragmentsRemoved: number
|
||||
/** The number of new, compacted fragments added */
|
||||
fragmentsAdded: number
|
||||
/** The number of data files removed */
|
||||
filesRemoved: number
|
||||
/** The number of new, compacted data files added */
|
||||
filesAdded: number
|
||||
}
|
||||
/** Statistics about a cleanup operation */
|
||||
export interface RemovalStats {
|
||||
/** The number of bytes removed */
|
||||
bytesRemoved: number
|
||||
/** The number of old versions removed */
|
||||
oldVersionsRemoved: number
|
||||
}
|
||||
/** Statistics about an optimize operation */
|
||||
export interface OptimizeStats {
|
||||
/** Statistics about the compaction operation */
|
||||
compaction: CompactionStats
|
||||
/** Statistics about the removal operation */
|
||||
prune: RemovalStats
|
||||
}
|
||||
/**
|
||||
* A definition of a column alteration. The alteration changes the column at
|
||||
* `path` to have the new name `name`, to be nullable if `nullable` is true,
|
||||
* and to have the data type `data_type`. At least one of `rename` or `nullable`
|
||||
* must be provided.
|
||||
*/
|
||||
export interface ColumnAlteration {
|
||||
/**
|
||||
* The path to the column to alter. This is a dot-separated path to the column.
|
||||
* If it is a top-level column then it is just the name of the column. If it is
|
||||
* a nested column then it is the path to the column, e.g. "a.b.c" for a column
|
||||
* `c` nested inside a column `b` nested inside a column `a`.
|
||||
*/
|
||||
path: string
|
||||
/**
|
||||
* The new name of the column. If not provided then the name will not be changed.
|
||||
* This must be distinct from the names of all other columns in the table.
|
||||
*/
|
||||
rename?: string
|
||||
/** Set the new nullability. Note that a nullable column cannot be made non-nullable. */
|
||||
nullable?: boolean
|
||||
}
|
||||
/** A definition of a new column to add to a table. */
|
||||
export interface AddColumnsSql {
|
||||
/** The name of the new column. */
|
||||
name: string
|
||||
/**
|
||||
* The values to populate the new column with, as a SQL expression.
|
||||
* The expression can reference other columns in the table.
|
||||
*/
|
||||
valueSql: string
|
||||
}
|
||||
export interface IndexStatistics {
|
||||
/** The number of rows indexed by the index */
|
||||
numIndexedRows: number
|
||||
/** The number of rows not indexed */
|
||||
numUnindexedRows: number
|
||||
/** The type of the index */
|
||||
indexType?: string
|
||||
/** The metadata for each index */
|
||||
indices: Array<IndexMetadata>
|
||||
}
|
||||
export interface IndexMetadata {
|
||||
metricType?: string
|
||||
indexType?: string
|
||||
}
|
||||
export interface ConnectionOptions {
|
||||
/**
|
||||
* (For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||
* updates to the table from other processes. If None, then consistency is not
|
||||
* checked. For performance reasons, this is the default. For strong
|
||||
* consistency, set this to zero seconds. Then every read will check for
|
||||
* updates from other processes. As a compromise, you can set this to a
|
||||
* non-zero value for eventual consistency. If more than that interval
|
||||
* has passed since the last check, then the table will be checked for updates.
|
||||
* Note: this consistency only applies to read operations. Write operations are
|
||||
* always consistent.
|
||||
*/
|
||||
readConsistencyInterval?: number
|
||||
/**
|
||||
* (For LanceDB OSS only): configuration for object storage.
|
||||
*
|
||||
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
|
||||
*/
|
||||
storageOptions?: Record<string, string>
|
||||
}
|
||||
/** Write mode for writing a table. */
|
||||
export const enum WriteMode {
|
||||
Create = 'Create',
|
||||
Append = 'Append',
|
||||
Overwrite = 'Overwrite'
|
||||
}
|
||||
/** Write options when creating a Table. */
|
||||
export interface WriteOptions {
|
||||
/** Write mode for writing to a table. */
|
||||
mode?: WriteMode
|
||||
}
|
||||
export interface OpenTableOptions {
|
||||
storageOptions?: Record<string, string>
|
||||
}
|
||||
export class Connection {
|
||||
/** Create a new Connection instance from the given URI. */
|
||||
static new(uri: string, options: ConnectionOptions): Promise<Connection>
|
||||
display(): string
|
||||
isOpen(): boolean
|
||||
close(): void
|
||||
/** List all tables in the dataset. */
|
||||
tableNames(startAfter?: string | undefined | null, limit?: number | undefined | null): Promise<Array<string>>
|
||||
/**
|
||||
* Create table from a Apache Arrow IPC (file) buffer.
|
||||
*
|
||||
* Parameters:
|
||||
* - name: The name of the table.
|
||||
* - buf: The buffer containing the IPC file.
|
||||
*
|
||||
*/
|
||||
createTable(name: string, buf: Buffer, mode: string, storageOptions?: Record<string, string> | undefined | null, useLegacyFormat?: boolean | undefined | null): Promise<Table>
|
||||
createEmptyTable(name: string, schemaBuf: Buffer, mode: string, storageOptions?: Record<string, string> | undefined | null, useLegacyFormat?: boolean | undefined | null): Promise<Table>
|
||||
openTable(name: string, storageOptions?: Record<string, string> | undefined | null, indexCacheSize?: number | undefined | null): Promise<Table>
|
||||
/** Drop table with the name. Or raise an error if the table does not exist. */
|
||||
dropTable(name: string): Promise<void>
|
||||
}
|
||||
export class Index {
|
||||
static ivfPq(distanceType?: string | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): Index
|
||||
static btree(): Index
|
||||
}
|
||||
/** Typescript-style Async Iterator over RecordBatches */
|
||||
export class RecordBatchIterator {
|
||||
next(): Promise<Buffer | null>
|
||||
}
|
||||
/** A builder used to create and run a merge insert operation */
|
||||
export class NativeMergeInsertBuilder {
|
||||
whenMatchedUpdateAll(condition?: string | undefined | null): NativeMergeInsertBuilder
|
||||
whenNotMatchedInsertAll(): NativeMergeInsertBuilder
|
||||
whenNotMatchedBySourceDelete(filter?: string | undefined | null): NativeMergeInsertBuilder
|
||||
execute(buf: Buffer): Promise<void>
|
||||
}
|
||||
export class Query {
|
||||
onlyIf(predicate: string): void
|
||||
select(columns: Array<[string, string]>): void
|
||||
limit(limit: number): void
|
||||
nearestTo(vector: Float32Array): VectorQuery
|
||||
execute(maxBatchLength?: number | undefined | null): Promise<RecordBatchIterator>
|
||||
explainPlan(verbose: boolean): Promise<string>
|
||||
}
|
||||
export class VectorQuery {
|
||||
column(column: string): void
|
||||
distanceType(distanceType: string): void
|
||||
postfilter(): void
|
||||
refineFactor(refineFactor: number): void
|
||||
nprobes(nprobe: number): void
|
||||
bypassVectorIndex(): void
|
||||
onlyIf(predicate: string): void
|
||||
select(columns: Array<[string, string]>): void
|
||||
limit(limit: number): void
|
||||
execute(maxBatchLength?: number | undefined | null): Promise<RecordBatchIterator>
|
||||
explainPlan(verbose: boolean): Promise<string>
|
||||
}
|
||||
export class Table {
|
||||
name: string
|
||||
display(): string
|
||||
isOpen(): boolean
|
||||
close(): void
|
||||
/** Return Schema as empty Arrow IPC file. */
|
||||
schema(): Promise<Buffer>
|
||||
add(buf: Buffer, mode: string): Promise<void>
|
||||
countRows(filter?: string | undefined | null): Promise<number>
|
||||
delete(predicate: string): Promise<void>
|
||||
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
|
||||
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
|
||||
query(): Query
|
||||
vectorSearch(vector: Float32Array): VectorQuery
|
||||
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
|
||||
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
|
||||
dropColumns(columns: Array<string>): Promise<void>
|
||||
version(): Promise<number>
|
||||
checkout(version: number): Promise<void>
|
||||
checkoutLatest(): Promise<void>
|
||||
restore(): Promise<void>
|
||||
optimize(olderThanMs?: number | undefined | null): Promise<OptimizeStats>
|
||||
listIndices(): Promise<Array<IndexConfig>>
|
||||
indexStats(indexName: string): Promise<IndexStatistics | null>
|
||||
mergeInsert(on: Array<string>): NativeMergeInsertBuilder
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
97
nodejs/package-lock.json
generated
97
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.10.0-beta.1",
|
||||
"version": "0.11.0-beta.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.10.0-beta.1",
|
||||
"version": "0.11.0-beta.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -18,7 +18,6 @@
|
||||
"win32"
|
||||
],
|
||||
"dependencies": {
|
||||
"axios": "^1.7.2",
|
||||
"reflect-metadata": "^0.2.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -30,6 +29,7 @@
|
||||
"@napi-rs/cli": "^2.18.3",
|
||||
"@types/axios": "^0.14.0",
|
||||
"@types/jest": "^29.1.2",
|
||||
"@types/node": "^22.7.4",
|
||||
"@types/tmp": "^0.2.6",
|
||||
"apache-arrow-13": "npm:apache-arrow@13.0.0",
|
||||
"apache-arrow-14": "npm:apache-arrow@14.0.0",
|
||||
@@ -4648,11 +4648,12 @@
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "20.14.11",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.14.11.tgz",
|
||||
"integrity": "sha512-kprQpL8MMeszbz6ojB5/tU8PLN4kesnN8Gjzw349rDlNgsSzg90lAVj3llK99Dh7JON+t9AuscPPFW6mPbTnSA==",
|
||||
"version": "22.7.4",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.7.4.tgz",
|
||||
"integrity": "sha512-y+NPi1rFzDs1NdQHHToqeiX2TIS79SWEAw9GYhkkx8bD0ChpfqC+n2j5OXOCpzfojBEBt6DnEnnG9MY0zk1XLg==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~5.26.4"
|
||||
"undici-types": "~6.19.2"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/node-fetch": {
|
||||
@@ -4665,6 +4666,12 @@
|
||||
"form-data": "^4.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/node/node_modules/undici-types": {
|
||||
"version": "6.19.8",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz",
|
||||
"integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/@types/pad-left": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/pad-left/-/pad-left-2.1.1.tgz",
|
||||
@@ -4963,6 +4970,21 @@
|
||||
"arrow2csv": "bin/arrow2csv.cjs"
|
||||
}
|
||||
},
|
||||
"node_modules/apache-arrow-15/node_modules/@types/node": {
|
||||
"version": "20.16.10",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.16.10.tgz",
|
||||
"integrity": "sha512-vQUKgWTjEIRFCvK6CyriPH3MZYiYlNy0fKiEYHWbcoWLEgs4opurGGKlebrTLqdSMIbXImH6XExNiIyNUv3WpA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.19.2"
|
||||
}
|
||||
},
|
||||
"node_modules/apache-arrow-15/node_modules/undici-types": {
|
||||
"version": "6.19.8",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz",
|
||||
"integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/apache-arrow-16": {
|
||||
"name": "apache-arrow",
|
||||
"version": "16.0.0",
|
||||
@@ -4984,6 +5006,21 @@
|
||||
"arrow2csv": "bin/arrow2csv.cjs"
|
||||
}
|
||||
},
|
||||
"node_modules/apache-arrow-16/node_modules/@types/node": {
|
||||
"version": "20.16.10",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.16.10.tgz",
|
||||
"integrity": "sha512-vQUKgWTjEIRFCvK6CyriPH3MZYiYlNy0fKiEYHWbcoWLEgs4opurGGKlebrTLqdSMIbXImH6XExNiIyNUv3WpA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.19.2"
|
||||
}
|
||||
},
|
||||
"node_modules/apache-arrow-16/node_modules/undici-types": {
|
||||
"version": "6.19.8",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz",
|
||||
"integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/apache-arrow-17": {
|
||||
"name": "apache-arrow",
|
||||
"version": "17.0.0",
|
||||
@@ -5011,12 +5048,42 @@
|
||||
"integrity": "sha512-BwR5KP3Es/CSht0xqBcUXS3qCAUVXwpRKsV2+arxeb65atasuXG9LykC9Ab10Cw3s2raH92ZqOeILaQbsB2ACg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/apache-arrow-17/node_modules/@types/node": {
|
||||
"version": "20.16.10",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.16.10.tgz",
|
||||
"integrity": "sha512-vQUKgWTjEIRFCvK6CyriPH3MZYiYlNy0fKiEYHWbcoWLEgs4opurGGKlebrTLqdSMIbXImH6XExNiIyNUv3WpA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.19.2"
|
||||
}
|
||||
},
|
||||
"node_modules/apache-arrow-17/node_modules/flatbuffers": {
|
||||
"version": "24.3.25",
|
||||
"resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-24.3.25.tgz",
|
||||
"integrity": "sha512-3HDgPbgiwWMI9zVB7VYBHaMrbOO7Gm0v+yD2FV/sCKj+9NDeVL7BOBYUuhWAQGKWOzBo8S9WdMvV0eixO233XQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/apache-arrow-17/node_modules/undici-types": {
|
||||
"version": "6.19.8",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz",
|
||||
"integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/apache-arrow/node_modules/@types/node": {
|
||||
"version": "20.16.10",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.16.10.tgz",
|
||||
"integrity": "sha512-vQUKgWTjEIRFCvK6CyriPH3MZYiYlNy0fKiEYHWbcoWLEgs4opurGGKlebrTLqdSMIbXImH6XExNiIyNUv3WpA==",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.19.2"
|
||||
}
|
||||
},
|
||||
"node_modules/apache-arrow/node_modules/undici-types": {
|
||||
"version": "6.19.8",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz",
|
||||
"integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/argparse": {
|
||||
"version": "1.0.10",
|
||||
"resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz",
|
||||
@@ -5046,12 +5113,14 @@
|
||||
"node_modules/asynckit": {
|
||||
"version": "0.4.0",
|
||||
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
|
||||
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q=="
|
||||
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.7.2",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.2.tgz",
|
||||
"integrity": "sha512-2A8QhOMrbomlDuiLeK9XibIBzuHeRcqqNOHp0Cyp5EoJ1IFDh+XZH3A6BkXtv0K4gFGCI0Y4BM7B1wOEi0Rmgw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.6",
|
||||
"form-data": "^4.0.0",
|
||||
@@ -5536,6 +5605,7 @@
|
||||
"version": "1.0.8",
|
||||
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
|
||||
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"delayed-stream": "~1.0.0"
|
||||
},
|
||||
@@ -5723,6 +5793,7 @@
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
|
||||
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
|
||||
"devOptional": true,
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
@@ -6248,6 +6319,7 @@
|
||||
"version": "1.15.6",
|
||||
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
|
||||
"integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
|
||||
"dev": true,
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
@@ -6267,6 +6339,7 @@
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz",
|
||||
"integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
"combined-stream": "^1.0.8",
|
||||
@@ -7773,6 +7846,7 @@
|
||||
"version": "1.52.0",
|
||||
"resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz",
|
||||
"integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==",
|
||||
"devOptional": true,
|
||||
"engines": {
|
||||
"node": ">= 0.6"
|
||||
}
|
||||
@@ -7781,6 +7855,7 @@
|
||||
"version": "2.1.35",
|
||||
"resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz",
|
||||
"integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"mime-db": "1.52.0"
|
||||
},
|
||||
@@ -8393,7 +8468,8 @@
|
||||
"node_modules/proxy-from-env": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
|
||||
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg=="
|
||||
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/pump": {
|
||||
"version": "3.0.0",
|
||||
@@ -9561,7 +9637,8 @@
|
||||
"node_modules/undici-types": {
|
||||
"version": "5.26.5",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
|
||||
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA=="
|
||||
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/update-browserslist-db": {
|
||||
"version": "1.0.13",
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"vector database",
|
||||
"ann"
|
||||
],
|
||||
"version": "0.10.0",
|
||||
"version": "0.11.0-beta.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
@@ -40,6 +40,7 @@
|
||||
"@napi-rs/cli": "^2.18.3",
|
||||
"@types/axios": "^0.14.0",
|
||||
"@types/jest": "^29.1.2",
|
||||
"@types/node": "^22.7.4",
|
||||
"@types/tmp": "^0.2.6",
|
||||
"apache-arrow-13": "npm:apache-arrow@13.0.0",
|
||||
"apache-arrow-14": "npm:apache-arrow@14.0.0",
|
||||
@@ -66,8 +67,8 @@
|
||||
"os": ["darwin", "linux", "win32"],
|
||||
"scripts": {
|
||||
"artifacts": "napi artifacts",
|
||||
"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
|
||||
"build:release": "napi build --platform --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
|
||||
"build:debug": "napi build --platform --no-const-enum --dts ../lancedb/native.d.ts --js ../lancedb/native.js lancedb",
|
||||
"build:release": "napi build --platform --no-const-enum --release --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
|
||||
"build": "npm run build:debug && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts && shx cp lancedb/*.node dist/",
|
||||
"build-release": "npm run build:release && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts",
|
||||
"lint-ci": "biome ci .",
|
||||
@@ -81,7 +82,6 @@
|
||||
"version": "napi version"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.7.2",
|
||||
"reflect-metadata": "^0.2.2"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
|
||||
@@ -68,6 +68,24 @@ impl Connection {
|
||||
builder = builder.storage_option(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
let client_config = options.client_config.unwrap_or_default();
|
||||
builder = builder.client_config(client_config.into());
|
||||
|
||||
if let Some(api_key) = options.api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
|
||||
if let Some(region) = options.region {
|
||||
builder = builder.region(®ion);
|
||||
} else {
|
||||
builder = builder.region("us-east-1");
|
||||
}
|
||||
|
||||
if let Some(host_override) = options.host_override {
|
||||
builder = builder.host_override(&host_override);
|
||||
}
|
||||
|
||||
Ok(Self::inner_new(
|
||||
builder
|
||||
.execute()
|
||||
|
||||
@@ -22,6 +22,7 @@ mod index;
|
||||
mod iterator;
|
||||
pub mod merge;
|
||||
mod query;
|
||||
pub mod remote;
|
||||
mod table;
|
||||
mod util;
|
||||
|
||||
@@ -42,6 +43,19 @@ pub struct ConnectionOptions {
|
||||
///
|
||||
/// The available options are described at https://lancedb.github.io/lancedb/guides/storage/
|
||||
pub storage_options: Option<HashMap<String, String>>,
|
||||
|
||||
/// (For LanceDB cloud only): configuration for the remote HTTP client.
|
||||
pub client_config: Option<remote::ClientConfig>,
|
||||
/// (For LanceDB cloud only): the API key to use with LanceDB Cloud.
|
||||
///
|
||||
/// Can also be set via the environment variable `LANCEDB_API_KEY`.
|
||||
pub api_key: Option<String>,
|
||||
/// (For LanceDB cloud only): the region to use for LanceDB cloud.
|
||||
/// Defaults to 'us-east-1'.
|
||||
pub region: Option<String>,
|
||||
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
|
||||
/// for testing purposes.
|
||||
pub host_override: Option<String>,
|
||||
}
|
||||
|
||||
/// Write mode for writing a table.
|
||||
|
||||
120
nodejs/src/remote.rs
Normal file
120
nodejs/src/remote.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use napi_derive::*;
|
||||
|
||||
/// Timeout configuration for remote HTTP client.
|
||||
#[napi(object)]
|
||||
#[derive(Debug)]
|
||||
pub struct TimeoutConfig {
|
||||
/// The timeout for establishing a connection in seconds. Default is 120
|
||||
/// seconds (2 minutes). This can also be set via the environment variable
|
||||
/// `LANCE_CLIENT_CONNECT_TIMEOUT`, as an integer number of seconds.
|
||||
pub connect_timeout: Option<f64>,
|
||||
/// The timeout for reading data from the server in seconds. Default is 300
|
||||
/// seconds (5 minutes). This can also be set via the environment variable
|
||||
/// `LANCE_CLIENT_READ_TIMEOUT`, as an integer number of seconds.
|
||||
pub read_timeout: Option<f64>,
|
||||
/// The timeout for keeping idle connections in the connection pool in seconds.
|
||||
/// Default is 300 seconds (5 minutes). This can also be set via the
|
||||
/// environment variable `LANCE_CLIENT_CONNECTION_TIMEOUT`, as an integer
|
||||
/// number of seconds.
|
||||
pub pool_idle_timeout: Option<f64>,
|
||||
}
|
||||
|
||||
/// Retry configuration for the remote HTTP client.
|
||||
#[napi(object)]
|
||||
#[derive(Debug)]
|
||||
pub struct RetryConfig {
|
||||
/// The maximum number of retries for a request. Default is 3. You can also
|
||||
/// set this via the environment variable `LANCE_CLIENT_MAX_RETRIES`.
|
||||
pub retries: Option<u8>,
|
||||
/// The maximum number of retries for connection errors. Default is 3. You
|
||||
/// can also set this via the environment variable `LANCE_CLIENT_CONNECT_RETRIES`.
|
||||
pub connect_retries: Option<u8>,
|
||||
/// The maximum number of retries for read errors. Default is 3. You can also
|
||||
/// set this via the environment variable `LANCE_CLIENT_READ_RETRIES`.
|
||||
pub read_retries: Option<u8>,
|
||||
/// The backoff factor to apply between retries. Default is 0.25. Between each retry
|
||||
/// the client will wait for the amount of seconds:
|
||||
/// `{backoff factor} * (2 ** ({number of previous retries}))`. So for the default
|
||||
/// of 0.25, the first retry will wait 0.25 seconds, the second retry will wait 0.5
|
||||
/// seconds, the third retry will wait 1 second, etc.
|
||||
///
|
||||
/// You can also set this via the environment variable
|
||||
/// `LANCE_CLIENT_RETRY_BACKOFF_FACTOR`.
|
||||
pub backoff_factor: Option<f64>,
|
||||
/// The jitter to apply to the backoff factor, in seconds. Default is 0.25.
|
||||
///
|
||||
/// A random value between 0 and `backoff_jitter` will be added to the backoff
|
||||
/// factor in seconds. So for the default of 0.25 seconds, between 0 and 250
|
||||
/// milliseconds will be added to the sleep between each retry.
|
||||
///
|
||||
/// You can also set this via the environment variable
|
||||
/// `LANCE_CLIENT_RETRY_BACKOFF_JITTER`.
|
||||
pub backoff_jitter: Option<f64>,
|
||||
/// The HTTP status codes for which to retry the request. Default is
|
||||
/// [429, 500, 502, 503].
|
||||
///
|
||||
/// You can also set this via the environment variable
|
||||
/// `LANCE_CLIENT_RETRY_STATUSES`. Use a comma-separated list of integers.
|
||||
pub statuses: Option<Vec<u16>>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ClientConfig {
|
||||
pub user_agent: Option<String>,
|
||||
pub retry_config: Option<RetryConfig>,
|
||||
pub timeout_config: Option<TimeoutConfig>,
|
||||
}
|
||||
|
||||
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
fn from(config: TimeoutConfig) -> Self {
|
||||
Self {
|
||||
connect_timeout: config
|
||||
.connect_timeout
|
||||
.map(std::time::Duration::from_secs_f64),
|
||||
read_timeout: config.read_timeout.map(std::time::Duration::from_secs_f64),
|
||||
pool_idle_timeout: config
|
||||
.pool_idle_timeout
|
||||
.map(std::time::Duration::from_secs_f64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RetryConfig> for lancedb::remote::RetryConfig {
|
||||
fn from(config: RetryConfig) -> Self {
|
||||
Self {
|
||||
retries: config.retries,
|
||||
connect_retries: config.connect_retries,
|
||||
read_retries: config.read_retries,
|
||||
backoff_factor: config.backoff_factor.map(|v| v as f32),
|
||||
backoff_jitter: config.backoff_jitter.map(|v| v as f32),
|
||||
statuses: config.statuses,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
user_agent: config
|
||||
.user_agent
|
||||
.unwrap_or(concat!("LanceDB-Node-Client/", env!("CARGO_PKG_VERSION")).to_string()),
|
||||
retry_config: config.retry_config.map(Into::into).unwrap_or_default(),
|
||||
timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -337,7 +337,7 @@ impl Table {

    #[napi(catch_unwind)]
    pub async fn index_stats(&self, index_name: String) -> napi::Result<Option<IndexStatistics>> {
        let tbl = self.inner_ref()?.as_native().unwrap();
        let tbl = self.inner_ref()?;
        let stats = tbl.index_stats(&index_name).await.default_error()?;
        Ok(stats.map(IndexStatistics::from))
    }

@@ -480,32 +480,22 @@ pub struct IndexStatistics {
    /// The number of rows not indexed
    pub num_unindexed_rows: f64,
    /// The type of the index
    pub index_type: Option<String>,
    /// The metadata for each index
    pub indices: Vec<IndexMetadata>,
    pub index_type: String,
    /// The type of the distance function used by the index. This is only
    /// present for vector indices. Scalar and full text search indices do
    /// not have a distance function.
    pub distance_type: Option<String>,
    /// The number of parts this index is split into.
    pub num_indices: Option<u32>,
}

impl From<lancedb::index::IndexStatistics> for IndexStatistics {
    fn from(value: lancedb::index::IndexStatistics) -> Self {
        Self {
            num_indexed_rows: value.num_indexed_rows as f64,
            num_unindexed_rows: value.num_unindexed_rows as f64,
            index_type: value.index_type.map(|t| format!("{:?}", t)),
            indices: value.indices.into_iter().map(Into::into).collect(),
        }
    }
}

#[napi(object)]
pub struct IndexMetadata {
    pub metric_type: Option<String>,
    pub index_type: Option<String>,
}

impl From<lancedb::index::IndexMetadata> for IndexMetadata {
    fn from(value: lancedb::index::IndexMetadata) -> Self {
        Self {
            metric_type: value.metric_type,
            index_type: value.index_type,
            index_type: value.index_type.to_string(),
            distance_type: value.distance_type.map(|d| d.to_string()),
            num_indices: value.num_indices,
        }
    }
}
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.14.0-beta.0"
current_version = "0.14.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.14.0-beta.0"
version = "0.14.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

@@ -22,8 +22,6 @@ pyo3 = { version = "0.21", features = ["extension-module", "abi3-py38", "gil-ref
# pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
pyo3-asyncio-0-21 = { version = "0.21.0", features = ["attributes", "tokio-runtime"] }

# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
pin-project = "1.1.5"
futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] }

@@ -35,4 +33,6 @@ pyo3-build-config = { version = "0.20.3", features = [
] }

[features]
default = ["remote"]
fp16kernels = ["lancedb/fp16kernels"]
remote = ["lancedb/remote"]

@@ -3,7 +3,7 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
    "deprecation",
    "pylance==0.18.0",
    "pylance==0.18.2",
    "requests>=2.31.0",
    "retry>=0.9.2",
    "tqdm>=4.27.0",
@@ -19,6 +19,8 @@ from typing import Dict, Optional, Union, Any

__version__ = importlib.metadata.version("lancedb")

from lancedb.remote import ClientConfig

from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection

@@ -120,7 +122,7 @@ async def connect_async(
    region: str = "us-east-1",
    host_override: Optional[str] = None,
    read_consistency_interval: Optional[timedelta] = None,
    request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
    client_config: Optional[Union[ClientConfig, Dict[str, Any]]] = None,
    storage_options: Optional[Dict[str, str]] = None,
) -> AsyncConnection:
    """Connect to a LanceDB database.

@@ -148,6 +150,10 @@ async def connect_async(
        the last check, then the table will be checked for updates. Note: this
        consistency only applies to read operations. Write operations are
        always consistent.
    client_config: ClientConfig or dict, optional
        Configuration options for the LanceDB Cloud HTTP client. If a dict, then
        the keys are the attributes of the ClientConfig class. If None, then the
        default configuration is used.
    storage_options: dict, optional
        Additional options for the storage backend. See available options at
        https://lancedb.github.io/lancedb/guides/storage/

@@ -160,7 +166,13 @@ async def connect_async(
    ... # For a local directory, provide a path to the database
    ... db = await lancedb.connect_async("~/.lancedb")
    ... # For object storage, use a URI prefix
    ... db = await lancedb.connect_async("s3://my-bucket/lancedb")
    ... db = await lancedb.connect_async("s3://my-bucket/lancedb",
    ...     storage_options={
    ...         "aws_access_key_id": "***"})
    ... # Connect to LanceDB cloud
    ... db = await lancedb.connect_async("db://my_database", api_key="ldb_...",
    ...     client_config={
    ...         "retry_config": {"retries": 5}})

    Returns
    -------

@@ -172,6 +184,9 @@ async def connect_async(
    else:
        read_consistency_interval_secs = None

    if isinstance(client_config, dict):
        client_config = ClientConfig(**client_config)

    return AsyncConnection(
        await lancedb_connect(
            sanitize_uri(uri),

@@ -179,6 +194,7 @@ async def connect_async(
            region,
            host_override,
            read_consistency_interval_secs,
            client_config,
            storage_options,
        )
    )
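
Illustrative sketch (editorial, not part of the diff above): the new `client_config` argument accepts either a plain dict, as in the docstring example, or the `ClientConfig` dataclass from `lancedb.remote`. The URI, API key, and timeout value below are placeholders.

import lancedb
from lancedb.remote import ClientConfig, RetryConfig, TimeoutConfig

async def connect_with_config():
    # Dataclass form of the same configuration shown as a dict in the docstring.
    config = ClientConfig(
        retry_config=RetryConfig(retries=5),
        timeout_config=TimeoutConfig(connect_timeout=30),  # seconds, coerced to timedelta
    )
    return await lancedb.connect_async(
        "db://my_database", api_key="ldb_...", client_config=config
    )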
@@ -20,7 +20,7 @@ from .util import safe_import_pandas

pd = safe_import_pandas()

DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
URI = Union[str, Path]
VECTOR_COLUMN_NAME = "vector"

@@ -96,7 +96,7 @@ class DBConnection(EnforceOverrides):
            User must provide at least one of `data` or `schema`.
            Acceptable types are:

            - dict or list-of-dict
            - list-of-dict

            - pandas.DataFrame

@@ -579,7 +579,7 @@ class AsyncConnection(object):
            User must provide at least one of `data` or `schema`.
            Acceptable types are:

            - dict or list-of-dict
            - list-of-dict

            - pandas.DataFrame
@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from abc import ABC, abstractmethod
from typing import List, Union

@@ -34,7 +25,7 @@ class EmbeddingFunction(BaseModel, ABC):

    __slots__ = ("__weakref__",)  # pydantic 1.x compatibility
    max_retries: int = (
        7  # Setitng 0 disables retires. Maybe this should not be enabled by default,
        7  # Setting 0 disables retires. Maybe this should not be enabled by default,
    )
    _ndims: int = PrivateAttr()

@@ -46,22 +37,37 @@ class EmbeddingFunction(BaseModel, ABC):
        return cls(**kwargs)

    @abstractmethod
    def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
    def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
        """
        Compute the embeddings for a given user query

        Returns
        -------
        A list of embeddings for each input. The embedding of each input can be None
        when the embedding is not valid.
        """
        pass

    @abstractmethod
    def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for the source column in the database
    def compute_source_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
        """Compute the embeddings for the source column in the database

        Returns
        -------
        A list of embeddings for each input. The embedding of each input can be None
        when the embedding is not valid.
        """
        pass

    def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for a given user query with retries
    def compute_query_embeddings_with_retry(
        self, *args, **kwargs
    ) -> list[Union[np.array, None]]:
        """Compute the embeddings for a given user query with retries

        Returns
        -------
        A list of embeddings for each input. The embedding of each input can be None
        when the embedding is not valid.
        """
        return retry_with_exponential_backoff(
            self.compute_query_embeddings, max_retries=self.max_retries

@@ -70,9 +76,15 @@ class EmbeddingFunction(BaseModel, ABC):
            **kwargs,
        )

    def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for the source column in the database with retries
    def compute_source_embeddings_with_retry(
        self, *args, **kwargs
    ) -> list[Union[np.array, None]]:
        """Compute the embeddings for the source column in the database with retries.

        Returns
        -------
        A list of embeddings for each input. The embedding of each input can be None
        when the embedding is not valid.
        """
        return retry_with_exponential_backoff(
            self.compute_source_embeddings, max_retries=self.max_retries

@@ -94,8 +106,14 @@ class EmbeddingFunction(BaseModel, ABC):
        from ..pydantic import PYDANTIC_VERSION

        if PYDANTIC_VERSION.major < 2:
            return dict(self)
        return self.model_dump()
            return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
        return self.model_dump(
            exclude={
                field_name
                for field_name in self.model_fields
                if field_name.startswith("_")
            }
        )

    @abstractmethod
    def ndims(self):

@@ -144,18 +162,20 @@ class TextEmbeddingFunction(EmbeddingFunction):
    A callable ABC for embedding functions that take text as input
    """

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
    def compute_query_embeddings(
        self, query: str, *args, **kwargs
    ) -> list[Union[np.array, None]]:
        return self.compute_source_embeddings(query, *args, **kwargs)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
    def compute_source_embeddings(
        self, texts: TEXT, *args, **kwargs
    ) -> list[Union[np.array, None]]:
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    @abstractmethod
    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[np.array]:
        """
        Generate the embeddings for the given texts
        """
    ) -> list[Union[np.array, None]]:
        """Generate the embeddings for the given texts"""
        pass
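
A minimal sketch of the new contract (editorial, not part of the diff): `generate_embeddings` may now return None for inputs whose embedding is not valid, and the table layer can drop those rows with `on_bad_vectors="drop"`. The class name and registry key are illustrative; import paths are assumed from the modules shown in this diff.

from typing import List, Union

import numpy as np

from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register


@register("null-safe-embedding")
class NullSafeEmbedding(TextEmbeddingFunction):
    def ndims(self):
        return 8

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> list[Union[np.array, None]]:
        # Return None for inputs that cannot be embedded instead of raising,
        # matching the list[Union[np.array, None]] return type introduced above.
        return [None if not t else np.random.randn(self.ndims()) for t in texts]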
@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union

@@ -19,6 +10,7 @@ from .registry import register

if TYPE_CHECKING:
    import numpy as np
    import ollama


@register("ollama")

@@ -39,17 +31,20 @@ class OllamaEmbeddings(TextEmbeddingFunction):
    def ndims(self):
        return len(self.generate_embeddings(["foo"])[0])

    def _compute_embedding(self, text):
        return self._ollama_client.embeddings(
            model=self.name,
            prompt=text,
            options=self.options,
            keep_alive=self.keep_alive,
        )["embedding"]
    def _compute_embedding(self, text) -> Union["np.array", None]:
        return (
            self._ollama_client.embeddings(
                model=self.name,
                prompt=text,
                options=self.options,
                keep_alive=self.keep_alive,
            )["embedding"]
            or None
        )

    def generate_embeddings(
        self, texts: Union[List[str], "np.ndarray"]
    ) -> List["np.array"]:
    ) -> list[Union["np.array", None]]:
        """
        Get the embeddings for the given texts

@@ -63,7 +58,7 @@ class OllamaEmbeddings(TextEmbeddingFunction):
        return embeddings

    @cached_property
    def _ollama_client(self):
    def _ollama_client(self) -> "ollama.Client":
        ollama = attempt_import_or_raise("ollama")
        # ToDo explore ollama.AsyncClient
        return ollama.Client(host=self.host, **self.ollama_client_kwargs)
@@ -1,17 +1,9 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union
import logging

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction

@@ -89,17 +81,26 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        openai = attempt_import_or_raise("openai")

        # TODO retry, rate limit, token limit
        if self.name == "text-embedding-ada-002":
            rs = self._openai_client.embeddings.create(input=texts, model=self.name)
        else:
            kwargs = {
                "input": texts,
                "model": self.name,
            }
            if self.dim:
                kwargs["dimensions"] = self.dim
            rs = self._openai_client.embeddings.create(**kwargs)
        try:
            if self.name == "text-embedding-ada-002":
                rs = self._openai_client.embeddings.create(input=texts, model=self.name)
            else:
                kwargs = {
                    "input": texts,
                    "model": self.name,
                }
                if self.dim:
                    kwargs["dimensions"] = self.dim
                rs = self._openai_client.embeddings.create(**kwargs)
        except openai.BadRequestError:
            logging.exception("Bad request: %s", texts)
            return [None] * len(texts)
        except Exception:
            logging.exception("OpenAI embeddings error")
            raise
        return [v.embedding for v in rs.data]

    @cached_property
@@ -40,6 +40,11 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
        The device to use for the model. Default is "cpu".
    show_progress_bar : bool
        Whether to show a progress bar when loading the model. Default is True.
    trust_remote_code : bool
        Whether or not to allow for custom models defined on the HuggingFace
        Hub in their own modeling files. This option should only be set to True
        for repositories you trust and in which you have read the code, as it
        will execute code present on the Hub on your local machine.

    to download package, run :
    `pip install transformers`

@@ -49,6 +54,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):

    name: str = "colbert-ir/colbertv2.0"
    device: str = "cpu"
    trust_remote_code: bool = False
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()

@@ -57,7 +63,9 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
        self._ndims = None
        transformers = attempt_import_or_raise("transformers")
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
        self._model = transformers.AutoModel.from_pretrained(self.name)
        self._model = transformers.AutoModel.from_pretrained(
            self.name, trust_remote_code=self.trust_remote_code
        )
        self._model.to(self.device)

        if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat
@@ -104,4 +104,4 @@ class LanceMergeInsertBuilder(object):
        fill_value: float, default 0.
            The value to use when filling vectors. Only used if on_bad_vectors="fill".
        """
        self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
        return self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
@@ -36,6 +36,7 @@ from . import __version__
from .arrow import AsyncRecordBatchReader
from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import safe_import_pandas

if TYPE_CHECKING:

@@ -575,12 +576,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        self._reranker = None
        self._str_query = str_query

    def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
    def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
        """Set the distance metric to use.

        Parameters
        ----------
        metric: "L2" or "cosine"
        metric: "L2" or "cosine" or "dot"
            The distance metric to use. By default "L2" is used.

        Returns

@@ -588,7 +589,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        LanceVectorQueryBuilder
            The LanceQueryBuilder object.
        """
        self._metric = metric
        self._metric = metric.lower()
        return self

    def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:

@@ -679,6 +680,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
        if self._reranker is not None:
            rs_table = result_set.read_all()
            result_set = self._reranker.rerank_vector(self._str_query, rs_table)
            check_reranker_result(result_set)
            # convert result_set back to RecordBatchReader
            result_set = pa.RecordBatchReader.from_batches(
                result_set.schema, result_set.to_batches()

@@ -811,6 +813,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
        results = results.read_all()
        if self._reranker is not None:
            results = self._reranker.rerank_fts(self._query, results)
            check_reranker_result(results)
        return results

    def tantivy_to_arrow(self) -> pa.Table:

@@ -953,8 +956,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
    def __init__(
        self,
        table: "Table",
        query: str = None,
        vector_column: str = None,
        query: Optional[str] = None,
        vector_column: Optional[str] = None,
        fts_columns: Union[str, List[str]] = [],
    ):
        super().__init__(table)

@@ -1060,10 +1063,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
            self._fts_query._query, vector_results, fts_results
        )

        if not isinstance(results, pa.Table):  # Enforce type
            raise TypeError(
                f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
            )
        check_reranker_result(results)

        # apply limit after reranking
        results = results.slice(length=self._limit)

@@ -1112,8 +1112,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):

    def rerank(
        self,
        normalize="score",
        reranker: Reranker = RRFReranker(),
        normalize: str = "score",
    ) -> LanceHybridQueryBuilder:
        """
        Rerank the hybrid search results using the specified reranker. The reranker

@@ -1121,12 +1121,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):

        Parameters
        ----------
        reranker: Reranker, default RRFReranker()
            The reranker to use. Must be an instance of Reranker class.
        normalize: str, default "score"
            The method to normalize the scores. Can be "rank" or "score". If "rank",
            the scores are converted to ranks and then normalized. If "score", the
            scores are normalized directly.
        reranker: Reranker, default RRFReranker()
            The reranker to use. Must be an instance of Reranker class.
        Returns
        -------
        LanceHybridQueryBuilder
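
Illustrative usage of the reordered `rerank()` signature above (editorial, not part of the diff): the reranker is now the first positional argument. The table and query below are placeholders, and the sketch assumes a table that already has a vector column and an FTS index.

from lancedb.rerankers.rrf import RRFReranker

def hybrid_with_rerank(table, query: str):
    # `table` is an existing LanceDB table; the query string is a placeholder.
    return (
        table.search(query, query_type="hybrid")
        .rerank(RRFReranker(), normalize="score")
        .limit(10)
        .to_pandas()
    )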
@@ -12,9 +12,12 @@
# limitations under the License.

import abc
from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional

import attrs
from lancedb import __version__
import pyarrow as pa
from pydantic import BaseModel

@@ -62,3 +65,109 @@ class LanceDBClient(abc.ABC):
    def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
        """Query the LanceDB server for the given table and query."""
        pass


@dataclass
class TimeoutConfig:
    """Timeout configuration for remote HTTP client.

    Attributes
    ----------
    connect_timeout: Optional[timedelta]
        The timeout for establishing a connection. Default is 120 seconds (2 minutes).
        This can also be set via the environment variable
        `LANCE_CLIENT_CONNECT_TIMEOUT`, as an integer number of seconds.
    read_timeout: Optional[timedelta]
        The timeout for reading data from the server. Default is 300 seconds
        (5 minutes). This can also be set via the environment variable
        `LANCE_CLIENT_READ_TIMEOUT`, as an integer number of seconds.
    pool_idle_timeout: Optional[timedelta]
        The timeout for keeping idle connections in the connection pool. Default
        is 300 seconds (5 minutes). This can also be set via the environment variable
        `LANCE_CLIENT_CONNECTION_TIMEOUT`, as an integer number of seconds.
    """

    connect_timeout: Optional[timedelta] = None
    read_timeout: Optional[timedelta] = None
    pool_idle_timeout: Optional[timedelta] = None

    @staticmethod
    def __to_timedelta(value) -> Optional[timedelta]:
        if value is None:
            return None
        elif isinstance(value, timedelta):
            return value
        elif isinstance(value, (int, float)):
            return timedelta(seconds=value)
        else:
            raise ValueError(
                f"Invalid value for timeout: {value}, must be a timedelta "
                "or number of seconds"
            )

    def __post_init__(self):
        self.connect_timeout = self.__to_timedelta(self.connect_timeout)
        self.read_timeout = self.__to_timedelta(self.read_timeout)
        self.pool_idle_timeout = self.__to_timedelta(self.pool_idle_timeout)


@dataclass
class RetryConfig:
    """Retry configuration for the remote HTTP client.

    Attributes
    ----------
    retries: Optional[int]
        The maximum number of retries for a request. Default is 3. You can also set this
        via the environment variable `LANCE_CLIENT_MAX_RETRIES`.
    connect_retries: Optional[int]
        The maximum number of retries for connection errors. Default is 3. You can also
        set this via the environment variable `LANCE_CLIENT_CONNECT_RETRIES`.
    read_retries: Optional[int]
        The maximum number of retries for read errors. Default is 3. You can also set
        this via the environment variable `LANCE_CLIENT_READ_RETRIES`.
    backoff_factor: Optional[float]
        The backoff factor to apply between retries. Default is 0.25. Between each retry
        the client will wait for the amount of seconds:
        `{backoff factor} * (2 ** ({number of previous retries}))`. So for the default
        of 0.25, the first retry will wait 0.25 seconds, the second retry will wait 0.5
        seconds, the third retry will wait 1 second, etc.

        You can also set this via the environment variable
        `LANCE_CLIENT_RETRY_BACKOFF_FACTOR`.
    backoff_jitter: Optional[float]
        The jitter to apply to the backoff factor, in seconds. Default is 0.25.

        A random value between 0 and `backoff_jitter` will be added to the backoff
        factor in seconds. So for the default of 0.25 seconds, between 0 and 250
        milliseconds will be added to the sleep between each retry.

        You can also set this via the environment variable
        `LANCE_CLIENT_RETRY_BACKOFF_JITTER`.
    statuses: Optional[List[int]
        The HTTP status codes for which to retry the request. Default is
        [429, 500, 502, 503].

        You can also set this via the environment variable
        `LANCE_CLIENT_RETRY_STATUSES`. Use a comma-separated list of integers.
    """

    retries: Optional[int] = None
    connect_retries: Optional[int] = None
    read_retries: Optional[int] = None
    backoff_factor: Optional[float] = None
    backoff_jitter: Optional[float] = None
    statuses: Optional[List[int]] = None


@dataclass
class ClientConfig:
    user_agent: str = f"LanceDB-Python-Client/{__version__}"
    retry_config: Optional[RetryConfig] = None
    timeout_config: Optional[TimeoutConfig] = None

    def __post_init__(self):
        if isinstance(self.retry_config, dict):
            self.retry_config = RetryConfig(**self.retry_config)
        if isinstance(self.timeout_config, dict):
            self.timeout_config = TimeoutConfig(**self.timeout_config)
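
A small sketch of the backoff schedule documented in the `RetryConfig` docstring above (editorial, not library code): the sleep before retry n is `backoff_factor * 2**n` plus a random jitter in `[0, backoff_jitter]` seconds.

import random

def retry_sleeps(retries: int = 3, backoff_factor: float = 0.25, backoff_jitter: float = 0.25):
    # With the defaults this yields roughly [0.25, 0.5, 1.0] seconds before jitter.
    return [
        backoff_factor * (2 ** attempt) + random.uniform(0, backoff_jitter)
        for attempt in range(retries)
    ]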
@@ -103,19 +103,29 @@ class RestfulLanceDBClient:

    @staticmethod
    def _check_status(resp: requests.Response):
        # Leaving request id empty for now, as we'll be replacing this impl
        # with the Rust one shortly.
        if resp.status_code == 404:
            raise LanceDBClientError(f"Not found: {resp.text}")
            raise LanceDBClientError(
                f"Not found: {resp.text}", request_id="", status_code=404
            )
        elif 400 <= resp.status_code < 500:
            raise LanceDBClientError(
                f"Bad Request: {resp.status_code}, error: {resp.text}"
                f"Bad Request: {resp.status_code}, error: {resp.text}",
                request_id="",
                status_code=resp.status_code,
            )
        elif 500 <= resp.status_code < 600:
            raise LanceDBClientError(
                f"Internal Server Error: {resp.status_code}, error: {resp.text}"
                f"Internal Server Error: {resp.status_code}, error: {resp.text}",
                request_id="",
                status_code=resp.status_code,
            )
        elif resp.status_code != 200:
            raise LanceDBClientError(
                f"Unknown Error: {resp.status_code}, error: {resp.text}"
                f"Unknown Error: {resp.status_code}, error: {resp.text}",
                request_id="",
                status_code=resp.status_code,
            )

    @_check_not_closed
@@ -12,5 +12,102 @@
# limitations under the License.


from typing import Optional


class LanceDBClientError(RuntimeError):
    """An error that occurred in the LanceDB client.

    Attributes
    ----------
    message: str
        The error message.
    request_id: str
        The id of the request that failed. This can be provided in error reports
        to help diagnose the issue.
    status_code: int
        The HTTP status code of the response. May be None if the request
        failed before the response was received.
    """

    def __init__(
        self, message: str, request_id: str, status_code: Optional[int] = None
    ):
        super().__init__(message)
        self.request_id = request_id
        self.status_code = status_code


class HttpError(LanceDBClientError):
    """An error that occurred during an HTTP request.

    Attributes
    ----------
    message: str
        The error message.
    request_id: str
        The id of the request that failed. This can be provided in error reports
        to help diagnose the issue.
    status_code: int
        The HTTP status code of the response. May be None if the request
        failed before the response was received.
    """

    pass


class RetryError(LanceDBClientError):
    """An error that occurs when the client has exceeded the maximum number of retries.

    The retry strategy can be adjusted by setting the
    [retry_config](lancedb.remote.ClientConfig.retry_config) in the client
    configuration. This is passed in the `client_config` argument of
    [connect](lancedb.connect) and [connect_async](lancedb.connect_async).

    The __cause__ attribute of this exception will be the last exception that
    caused the retry to fail. It will be an
    [HttpError][lancedb.remote.errors.HttpError] instance.

    Attributes
    ----------
    message: str
        The retry error message, which will describe which retry limit was hit.
    request_id: str
        The id of the request that failed. This can be provided in error reports
        to help diagnose the issue.
    request_failures: int
        The number of request failures.
    connect_failures: int
        The number of connect failures.
    read_failures: int
        The number of read failures.
    max_request_failures: int
        The maximum number of request failures.
    max_connect_failures: int
        The maximum number of connect failures.
    max_read_failures: int
        The maximum number of read failures.
    status_code: int
        The HTTP status code of the last response. May be None if the request
        failed before the response was received.
    """

    def __init__(
        self,
        message: str,
        request_id: str,
        request_failures: int,
        connect_failures: int,
        read_failures: int,
        max_request_failures: int,
        max_connect_failures: int,
        max_read_failures: int,
        status_code: Optional[int],
    ):
        super().__init__(message, request_id, status_code)
        self.request_failures = request_failures
        self.connect_failures = connect_failures
        self.read_failures = read_failures
        self.max_request_failures = max_request_failures
        self.max_connect_failures = max_connect_failures
        self.max_read_failures = max_read_failures
@@ -26,7 +26,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry

from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data
from ..util import inf_vector_column_query, value_to_sql
from ..util import value_to_sql, infer_vector_column_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection

@@ -266,7 +266,7 @@ class RemoteTable(Table):

    def search(
        self,
        query: Union[VEC, str],
        query: Union[VEC, str] = None,
        vector_column_name: Optional[str] = None,
        query_type="auto",
        fts_columns: Optional[Union[str, List[str]]] = None,

@@ -305,8 +305,6 @@ class RemoteTable(Table):
            - *default None*.
            Acceptable types are: list, np.ndarray, PIL.Image.Image

            - If None then the select/where/limit clauses are applied to filter
            the table
        vector_column_name: str, optional
            The name of the vector column to search.

@@ -329,11 +327,15 @@ class RemoteTable(Table):
            - and also the "_distance" column which is the distance between the query
            vector and the returned vector.
        """
        if vector_column_name is None and query is not None and query_type != "fts":
            try:
                vector_column_name = inf_vector_column_query(self.schema)
            except Exception as e:
                raise e
        # empty query builder is not supported in saas, raise error
        if query is None and query_type != "hybrid":
            raise ValueError("Empty query is not supported")
        vector_column_name = infer_vector_column_name(
            schema=self.schema,
            query_type=query_type,
            query=query,
            vector_column_name=vector_column_name,
        )

        return LanceQueryBuilder.create(
            self,
@@ -105,7 +105,7 @@ class Reranker(ABC):
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
    ) -> pa.Table:
        """
        Rerank function receives the individual results from the vector and FTS search
        results. You can choose to use any of the results to generate the final results,
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from numpy import NaN
import pyarrow as pa

from .base import Reranker

@@ -58,14 +59,42 @@ class LinearCombinationReranker(Reranker):
    def merge_results(
        self, vector_results: pa.Table, fts_results: pa.Table, fill: float
    ):
        # If both are empty then just return an empty table
        if len(vector_results) == 0 and len(fts_results) == 0:
            return vector_results
        # If one is empty then return the other
        # If one is empty then return the other and add _relevance_score
        # column equal the existing vector or fts score
        if len(vector_results) == 0:
            return fts_results
            results = fts_results.append_column(
                "_relevance_score",
                pa.array(fts_results["_score"], type=pa.float32()),
            )
            if self.score == "relevance":
                results = self._keep_relevance_score(results)
            elif self.score == "all":
                results = results.append_column(
                    "_distance",
                    pa.array([NaN] * len(fts_results), type=pa.float32()),
                )
            return results

        if len(fts_results) == 0:
            return vector_results
            # invert the distance to relevance score
            results = vector_results.append_column(
                "_relevance_score",
                pa.array(
                    [
                        self._invert_score(distance)
                        for distance in vector_results["_distance"].to_pylist()
                    ],
                    type=pa.float32(),
                ),
            )
            if self.score == "relevance":
                results = self._keep_relevance_score(results)
            elif self.score == "all":
                results = results.append_column(
                    "_score",
                    pa.array([NaN] * len(vector_results), type=pa.float32()),
                )
            return results

        # sort both input tables on _rowid
        combined_list = []
python/python/lancedb/rerankers/util.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

import pyarrow as pa


def check_reranker_result(result):
    if not isinstance(result, pa.Table):  # Enforce type
        raise TypeError(
            f"rerank_hybrid must return a pyarrow.Table, got {type(result)}"
        )

    # Enforce that `_relevance_score` column is present in the result of every
    # rerank_hybrid method
    if "_relevance_score" not in result.column_names:
        raise ValueError(
            "rerank_hybrid must return a pyarrow.Table with a column"
            "named `_relevance_score`"
        )
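
Hedged sketch of a reranker that satisfies `check_reranker_result` above (editorial, not part of the diff): `rerank_hybrid` must return a pyarrow Table containing a `_relevance_score` column. The class name and the deliberately trivial scoring are illustrative; import paths follow the modules shown in this diff.

import pyarrow as pa

from lancedb.rerankers.base import Reranker
from lancedb.rerankers.util import check_reranker_result


class ConstantScoreReranker(Reranker):
    def rerank_hybrid(self, query, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        # Deliberately trivial: keep only the vector hits and give each a constant score.
        scores = pa.array([1.0] * len(vector_results), type=pa.float32())
        return vector_results.append_column("_relevance_score", scores)

# check_reranker_result(reranked_table) raises TypeError/ValueError if the contract is violated.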
@@ -31,7 +31,6 @@ import pyarrow.compute as pc
import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from lance.vector import vec_to_table

from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry

@@ -50,7 +49,7 @@ from .query import (
from .util import (
    fs_from_uri,
    get_uri_scheme,
    inf_vector_column_query,
    infer_vector_column_name,
    join_uri,
    safe_import_pandas,
    safe_import_polars,

@@ -87,6 +86,9 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
    if isinstance(data, LanceModel):
        raise ValueError("Cannot add a single LanceModel to a table. Use a list.")

    if isinstance(data, dict):
        raise ValueError("Cannot add a single dictionary to a table. Use a list.")

    if isinstance(data, list):
        # convert to list of dict if data is a bunch of LanceModels
        if isinstance(data[0], LanceModel):

@@ -98,8 +100,6 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
            return pa.Table.from_batches(data, schema=schema)
        else:
            return pa.Table.from_pylist(data)
    elif isinstance(data, dict):
        return vec_to_table(data)
    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
        # Do not add schema here, since schema may contains the vector column
        table = pa.Table.from_pandas(data, preserve_index=False)

@@ -554,7 +554,7 @@ class Table(ABC):
        data: DATA
            The data to insert into the table. Acceptable types are:

            - dict or list-of-dict
            - list-of-dict

            - pandas.DataFrame

@@ -1409,7 +1409,7 @@ class LanceTable(Table):

        Parameters
        ----------
        data: list-of-dict, dict, pd.DataFrame
        data: list-of-dict, pd.DataFrame
            The data to insert into the table.
        mode: str
            The mode to use when writing the data. Valid values are

@@ -1630,11 +1630,12 @@ class LanceTable(Table):
            and also the "_distance" column which is the distance between the query
            vector and the returned vector.
        """
        if vector_column_name is None and query is not None and query_type != "fts":
            try:
                vector_column_name = inf_vector_column_query(self.schema)
            except Exception as e:
                raise e
        vector_column_name = infer_vector_column_name(
            schema=self.schema,
            query_type=query_type,
            query=query,
            vector_column_name=vector_column_name,
        )

        return LanceQueryBuilder.create(
            self,
@@ -1998,22 +1999,26 @@ def _sanitize_vector_column(
                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
            )
            vec_arr = data[vector_column_name].combine_chunks()
        vec_arr = ensure_fixed_size_list(vec_arr)
        data = data.set_column(
            data.column_names.index(vector_column_name), vector_column_name, vec_arr
        )
    elif not pa.types.is_fixed_size_list(vec_arr.type):
        raise TypeError(f"Unsupported vector column type: {vec_arr.type}")

    vec_arr = ensure_fixed_size_list(vec_arr)
    data = data.set_column(
        data.column_names.index(vector_column_name), vector_column_name, vec_arr
    )

    # Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
    # kernel over f16 types.
    values_np = vec_arr.values.to_numpy(zero_copy_only=False)
    if np.isnan(values_np).any():
        data = _sanitize_nans(
            data, fill_value, on_bad_vectors, vec_arr, vector_column_name
        )

    if pa.types.is_float16(vec_arr.values.type):
        # Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
        # kernel over f16 types yet.
        values_np = vec_arr.values.to_numpy(zero_copy_only=True)
        if np.isnan(values_np).any():
            data = _sanitize_nans(
                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
            )
    else:
        if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
            data = _sanitize_nans(
                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
            )
    return data


@@ -2057,8 +2062,15 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
    return data


def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
def _sanitize_nans(
    data,
    fill_value,
    on_bad_vectors,
    vec_arr: pa.FixedSizeListArray,
    vector_column_name: str,
):
    """Sanitize NaNs in vectors"""
    assert pa.types.is_fixed_size_list(vec_arr.type)
    if on_bad_vectors == "error":
        raise ValueError(
            f"Vector column {vector_column_name} has NaNs. "

@@ -2078,9 +2090,11 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
            data.column_names.index(vector_column_name), vector_column_name, vec_arr
        )
    elif on_bad_vectors == "drop":
        is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False)
        is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
        data = data.filter(is_full)
        # Drop is very slow to be able to filter out NaNs in a fixed size list array
        np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
        np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
        not_nulls = np.any(np_arr, axis=1)
        data = data.filter(~not_nulls)
    return data
@@ -2334,7 +2348,7 @@ class AsyncTable:
        data: DATA
            The data to insert into the table. Acceptable types are:

            - dict or list-of-dict
            - list-of-dict

            - pandas.DataFrame

@@ -2450,7 +2464,31 @@ class AsyncTable:
        on_bad_vectors: str,
        fill_value: float,
    ):
        pass
        schema = await self.schema()
        if on_bad_vectors is None:
            on_bad_vectors = "error"
        if fill_value is None:
            fill_value = 0.0
        data, _ = _sanitize_data(
            new_data,
            schema,
            metadata=schema.metadata,
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
        )
        if isinstance(data, pa.Table):
            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
        await self._inner.execute_merge_insert(
            data,
            dict(
                on=merge._on,
                when_matched_update_all=merge._when_matched_update_all,
                when_matched_update_all_condition=merge._when_matched_update_all_condition,
                when_not_matched_insert_all=merge._when_not_matched_insert_all,
                when_not_matched_by_source_delete=merge._when_not_matched_by_source_delete,
                when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
            ),
        )

    async def delete(self, where: str):
        """Delete rows from the table.

@@ -2669,6 +2707,26 @@ class AsyncTable:
        """
        return await self._inner.list_indices()

    async def index_stats(self, index_name: str) -> Optional[IndexStatistics]:
        """
        Retrieve statistics about an index

        Parameters
        ----------
        index_name: str
            The name of the index to retrieve statistics for

        Returns
        -------
        IndexStatistics or None
            The statistics about the index. Returns None if the index does not exist.
        """
        stats = await self._inner.index_stats(index_name)
        if stats is None:
            return None
        else:
            return IndexStatistics(**stats)

    async def uses_v2_manifest_paths(self) -> bool:
        """
        Check if the table is using the new v2 manifest paths.

@@ -2699,3 +2757,31 @@ class AsyncTable:
        to check if the table is already using the new path style.
        """
        await self._inner.migrate_manifest_paths_v2()


@dataclass
class IndexStatistics:
    """
    Statistics about an index.

    Attributes
    ----------
    num_indexed_rows: int
        The number of rows that are covered by this index.
    num_unindexed_rows: int
        The number of rows that are not covered by this index.
    index_type: str
        The type of index that was created.
    distance_type: Optional[str]
        The distance type used by the index.
    num_indices: Optional[int]
        The number of parts the index is split into.
    """

    num_indexed_rows: int
    num_unindexed_rows: int
    index_type: Literal[
        "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
    ]
    distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
    num_indices: Optional[int] = None
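
Illustrative sketch of the new `AsyncTable.index_stats` API introduced above (editorial, not part of the diff); it mirrors the async index tests later in this changeset. The table object is a placeholder, and "vector_idx" is the default index name those tests rely on.

async def report_index(table):
    stats = await table.index_stats("vector_idx")
    if stats is None:
        print("index does not exist")
    else:
        print(stats.index_type, stats.distance_type, stats.num_indexed_rows)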
@@ -9,7 +9,7 @@ import pathlib
import warnings
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple, Union
from typing import Tuple, Union, Optional, Any
from urllib.parse import urlparse

import numpy as np

@@ -212,6 +212,23 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
    return vector_col_name


def infer_vector_column_name(
    schema: pa.Schema,
    query_type: str,
    query: Optional[Any],  # inferred later in query builder
    vector_column_name: Optional[str],
):
    if (vector_column_name is None and query is not None and query_type != "fts") or (
        vector_column_name is None and query_type == "hybrid"
    ):
        try:
            vector_column_name = inf_vector_column_query(schema)
        except Exception as e:
            raise e

    return vector_column_name


@singledispatch
def value_to_sql(value):
    raise NotImplementedError("SQL conversion is not implemented for this type")
@@ -354,7 +354,7 @@ async def test_create_mode_async(tmp_path):
    )
    await db.create_table("test", data=data)

    with pytest.raises(RuntimeError):
    with pytest.raises(ValueError, match="already exists"):
        await db.create_table("test", data=data)

    new_data = pd.DataFrame(

@@ -382,7 +382,7 @@ async def test_create_exist_ok_async(tmp_path):
    )
    tbl = await db.create_table("test", data=data)

    with pytest.raises(RuntimeError):
    with pytest.raises(ValueError, match="already exists"):
        await db.create_table("test", data=data)

    # open the table but don't add more rows
@@ -86,6 +86,47 @@ def test_embedding_function(tmp_path):
    assert np.allclose(actual, expected)


def test_embedding_with_bad_results(tmp_path):
    @register("mock-embedding")
    class MockEmbeddingFunction(TextEmbeddingFunction):
        def ndims(self):
            return 128

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> list[Union[np.array, None]]:
            return [
                None if i % 2 == 0 else np.random.randn(self.ndims())
                for i in range(len(texts))
            ]

    db = lancedb.connect(tmp_path)
    registry = EmbeddingFunctionRegistry.get_instance()
    model = registry.get("mock-embedding").create()

    class Schema(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    table = db.create_table("test", schema=Schema, mode="overwrite")
    table.add(
        [{"text": "hello world"}, {"text": "bar"}],
        on_bad_vectors="drop",
    )

    df = table.to_pandas()
    assert len(table) == 1
    assert df.iloc[0]["text"] == "bar"

    # table = db.create_table("test2", schema=Schema, mode="overwrite")
    # table.add(
    #     [{"text": "hello world"}, {"text": "bar"}],
    # )
    # assert len(table) == 2
    # tbl = table.to_arrow()
    # assert tbl["vector"].null_count == 1


@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
    def _get_schema_from_model(model):

@@ -142,3 +183,45 @@ def test_add_optional_vector(tmp_path):
    expected = LanceSchema(id="id", text="text")
    tbl.add([expected])
    assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()


@pytest.mark.parametrize(
    "embedding_type",
    [
        "openai",
        "sentence-transformers",
        "huggingface",
        "ollama",
        "cohere",
        "instructor",
    ],
)
def test_embedding_function_safe_model_dump(embedding_type):
    registry = get_registry()

    # Note: Some embedding types might require specific parameters
    try:
        model = registry.get(embedding_type).create()
    except Exception as e:
        pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")

    dumped_model = model.safe_model_dump()

    assert all(
        not k.startswith("_") for k in dumped_model.keys()
    ), f"{embedding_type}: Dumped model contains keys starting with underscore"

    assert (
        "max_retries" in dumped_model
    ), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"

    assert isinstance(
        dumped_model, dict
    ), f"{embedding_type}: Dumped model is not a dictionary"

    for key in model.__dict__:
        if key.startswith("_"):
            assert key not in dumped_model, (
                f"{embedding_type}: Private attribute '{key}' "
                f"is present in dumped model"
            )
@@ -442,3 +442,42 @@ def test_watsonx_embedding(tmp_path):
    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("ollama") is None, reason="Ollama not installed"
)
def test_ollama_embedding(tmp_path):
    model = get_registry().get("ollama").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)

    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()

    result = tbl.search("hello").limit(1).to_pandas()
    assert result["text"][0] == "hello world"

    # Test safe_model_dump
    dumped_model = model.safe_model_dump()
    assert isinstance(dumped_model, dict)
    assert "name" in dumped_model
    assert "max_retries" in dumped_model
    assert dumped_model["max_retries"] == 0
    assert all(not k.startswith("_") for k in dumped_model.keys())

    # Test serialization of the dumped model
    import json

    try:
        json.dumps(dumped_model)
    except TypeError:
        pytest.fail("Failed to JSON serialize the dumped model")
@@ -63,17 +63,24 @@ async def test_create_scalar_index(some_table: AsyncTable):
@pytest.mark.asyncio
async def test_create_bitmap_index(some_table: AsyncTable):
    await some_table.create_index("id", config=Bitmap())
    # TODO: Fix via https://github.com/lancedb/lance/issues/2039
    # indices = await some_table.list_indices()
    # assert str(indices) == '[Index(Bitmap, columns=["id"])]'
    indices = await some_table.list_indices()
    assert str(indices) == '[Index(Bitmap, columns=["id"])]'
    indices = await some_table.list_indices()
    assert len(indices) == 1
    index_name = indices[0].name
    stats = await some_table.index_stats(index_name)
    assert stats.index_type == "BITMAP"
    assert stats.distance_type is None
    assert stats.num_indexed_rows == await some_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1


@pytest.mark.asyncio
async def test_create_label_list_index(some_table: AsyncTable):
    await some_table.create_index("tags", config=LabelList())
    # TODO: Fix via https://github.com/lancedb/lance/issues/2039
    # indices = await some_table.list_indices()
    # assert str(indices) == '[Index(LabelList, columns=["id"])]'
    indices = await some_table.list_indices()
    assert str(indices) == '[Index(LabelList, columns=["tags"])]'


@pytest.mark.asyncio

@@ -91,6 +98,14 @@ async def test_create_vector_index(some_table: AsyncTable):
    assert len(indices) == 1
    assert indices[0].index_type == "IvfPq"
    assert indices[0].columns == ["vector"]
    assert indices[0].name == "vector_idx"

    stats = await some_table.index_stats("vector_idx")
    assert stats.index_type == "IVF_PQ"
    assert stats.distance_type == "l2"
    assert stats.num_indexed_rows == await some_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1


@pytest.mark.asyncio
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import contextlib
import http.server
import threading
from unittest.mock import MagicMock
import uuid

import lancedb
from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest


class FakeLanceDBClient:
@@ -81,3 +87,106 @@ def test_create_table_with_recordbatches():
    table = conn.create_table("test", [batch], schema=batch.schema)
    assert table.name == "test"
    assert client.post.call_args[0][0] == "/v1/table/test/create/"


def make_mock_http_handler(handler):
    class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
        def do_GET(self):
            handler(self)

        def do_POST(self):
            handler(self)

    return MockLanceDBHandler


@contextlib.asynccontextmanager
async def mock_lancedb_connection(handler):
    with http.server.HTTPServer(
        ("localhost", 8080), make_mock_http_handler(handler)
    ) as server:
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = await lancedb.connect_async(
            "db://dev",
            api_key="fake",
            host_override="http://localhost:8080",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
                    "connect_timeout": 1,
                },
            },
        )

        try:
            yield db
        finally:
            server.shutdown()
            handle.join()


@pytest.mark.asyncio
async def test_async_remote_db():
    def handler(request):
        # We created a UUID request id
        request_id = request.headers["x-request-id"]
        assert uuid.UUID(request_id).version == 4

        # We set a user agent with the current library version
        user_agent = request.headers["User-Agent"]
        assert user_agent == f"LanceDB-Python-Client/{lancedb.__version__}"

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection(handler) as db:
        table_names = await db.table_names()
        assert table_names == []


@pytest.mark.asyncio
async def test_http_error():
    request_id_holder = {"request_id": None}

    def handler(request):
        request_id_holder["request_id"] = request.headers["x-request-id"]

        request.send_response(507)
        request.end_headers()
        request.wfile.write(b"Internal Server Error")

    async with mock_lancedb_connection(handler) as db:
        with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
            await db.table_names()

        assert exc_info.value.request_id == request_id_holder["request_id"]
        assert exc_info.value.status_code == 507


@pytest.mark.asyncio
async def test_retry_error():
    request_id_holder = {"request_id": None}

    def handler(request):
        request_id_holder["request_id"] = request.headers["x-request-id"]

        request.send_response(429)
        request.end_headers()
        request.wfile.write(b"Try again later")

    async with mock_lancedb_connection(handler) as db:
        with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
            await db.table_names()

        assert exc_info.value.request_id == request_id_holder["request_id"]
        assert exc_info.value.status_code == 429

        cause = exc_info.value.__cause__
        assert isinstance(cause, HttpError)
        assert "Try again later" in str(cause)
        assert cause.request_id == request_id_holder["request_id"]
        assert cause.status_code == 429

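The three tests above pin down the new error surface of the async remote client. As a minimal sketch of how application code might use it (the module path lancedb.remote.errors and the request_id/status_code attributes come from the tests; the connection arguments are placeholders):

import lancedb
from lancedb.remote.errors import HttpError, RetryError


async def list_tables(uri: str, api_key: str) -> list:
    db = await lancedb.connect_async(
        uri, api_key=api_key, client_config={"retry_config": {"retries": 2}}
    )
    try:
        return await db.table_names()
    except RetryError as err:
        # Retry budget exhausted; the last HTTP failure is chained as __cause__.
        print(f"retries exhausted (request_id={err.request_id}): {err.__cause__}")
        raise
    except HttpError as err:
        # Non-retryable HTTP failure; request_id helps correlate with server logs.
        print(f"request {err.request_id} failed with status {err.status_code}")
        raise
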
@@ -120,12 +120,14 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
err = (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Vector search setting
|
||||
result = (
|
||||
@@ -135,7 +137,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
result_explicit = (
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
@@ -158,7 +162,26 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
@@ -172,7 +195,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
result_deduped = reranker.rerank_multivector(
|
||||
[rs1, rs2, rs1], query, deduplicate=True
|
||||
)
|
||||
assert len(result_deduped) < 20
|
||||
assert len(result_deduped) <= 20
|
||||
result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
|
||||
assert len(result) == 20 and result == result_arrow
|
||||
|
||||
@@ -213,7 +236,7 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.rerank(reranker, normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
@@ -228,12 +251,30 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").text(
|
||||
query
|
||||
).to_arrow()
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
|
||||
# Test with empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
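
Every branch of these reranker tests relies on the same monotonicity check on _relevance_score; a tiny standalone illustration of what np.all(np.diff(...) <= 0) accepts and rejects:

import numpy as np

descending = np.array([0.92, 0.80, 0.80, 0.15])  # ties are allowed
not_descending = np.array([0.40, 0.90, 0.10])

assert np.all(np.diff(descending) <= 0)
assert not np.all(np.diff(not_descending) <= 0)
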
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
@@ -193,6 +193,24 @@ def test_empty_table(db):
|
||||
tbl.add(data=data)
|
||||
|
||||
|
||||
def test_add_dictionary(db):
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
        ]
    )
    tbl = LanceTable.create(db, "test", schema=schema)
    data = {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}
    with pytest.raises(ValueError) as excep_info:
        tbl.add(data=data)
    assert (
        str(excep_info.value)
        == "Cannot add a single dictionary to a table. Use a list."
    )

def test_add(db):
|
||||
schema = pa.schema(
|
||||
[
|
||||
@@ -636,11 +654,13 @@ def test_merge_insert(db):
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
|
||||
# replace-range
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
|
||||
"a > 2"
|
||||
).execute(new_data)
|
||||
(
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete("a > 2")
|
||||
.execute(new_data)
|
||||
)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
@@ -658,6 +678,75 @@ def test_merge_insert(db):
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_merge_insert_async(db_async: AsyncConnection):
    data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    table = await db_async.create_table("some_table", data=data)
    assert await table.count_rows() == 3
    version = await table.version()

    new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})

    # upsert
    await (
        table.merge_insert("a")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute(new_data)
    )
    expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
    assert (await table.to_arrow()).sort_by("a") == expected

    await table.checkout(version)
    await table.restore()

    # conditional update
    await (
        table.merge_insert("a")
        .when_matched_update_all(where="target.b = 'b'")
        .execute(new_data)
    )
    expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
    assert (await table.to_arrow()).sort_by("a") == expected

    await table.checkout(version)
    await table.restore()

    # insert-if-not-exists
    await table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
    expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
    assert (await table.to_arrow()).sort_by("a") == expected

    await table.checkout(version)
    await table.restore()

    # replace-range
    new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
    await (
        table.merge_insert("a")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .when_not_matched_by_source_delete("a > 2")
        .execute(new_data)
    )
    expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
    assert (await table.to_arrow()).sort_by("a") == expected

    await table.checkout(version)
    await table.restore()

    # replace-range no condition
    await (
        table.merge_insert("a")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .when_not_matched_by_source_delete()
        .execute(new_data)
    )
    expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
    assert (await table.to_arrow()).sort_by("a") == expected

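The async test above exercises the same builder chain as the existing sync merge_insert; for reference, a minimal sync upsert looks like the sketch below (the local path and table name are placeholders):

import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb-merge-demo")
table = db.create_table("demo", pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}))

new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})

# upsert: update rows whose "a" matches, insert the rest
(
    table.merge_insert("a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(new_data)
)
assert table.to_arrow().sort_by("a") == pa.table(
    {"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}
)
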
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
@@ -973,7 +1062,36 @@ def test_hybrid_search(db, tmp_path):
|
||||
.where("text='Arrrrggghhhhhhh'")
|
||||
.to_list()
|
||||
)
|
||||
len(result) == 1
|
||||
assert len(result) == 1
|
||||
|
||||
# with explicit query type
|
||||
vector_query = list(range(emb.ndims()))
|
||||
result = (
|
||||
table.search(query_type="hybrid")
|
||||
.vector(vector_query)
|
||||
.text("Arrrrggghhhhhhh")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
|
||||
# with vector_column_name
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(vector_query)
|
||||
.text("Arrrrggghhhhhhh")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
|
||||
# fail if only text or vector is provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").to_list()
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").vector(vector_query).to_list()
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -7,7 +7,7 @@ use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::From
|
||||
use lancedb::connection::{Connection as LanceConnection, CreateTableMode, LanceFileVersion};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pyfunction, pymethods, Bound, PyAny, PyRef, PyResult, Python,
|
||||
pyclass, pyfunction, pymethods, Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_asyncio_0_21::tokio::future_into_py;
|
||||
|
||||
@@ -187,6 +187,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect(
|
||||
py: Python,
|
||||
uri: String,
|
||||
@@ -194,6 +195,7 @@ pub fn connect(
|
||||
region: Option<String>,
|
||||
host_override: Option<String>,
|
||||
read_consistency_interval: Option<f64>,
|
||||
client_config: Option<PyClientConfig>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
future_into_py(py, async move {
|
||||
@@ -214,6 +216,70 @@ pub fn connect(
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
#[cfg(feature = "remote")]
|
||||
if let Some(client_config) = client_config {
|
||||
builder = builder.client_config(client_config.into());
|
||||
}
|
||||
Ok(Connection::new(builder.execute().await.infer_error()?))
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
pub struct PyClientConfig {
    user_agent: String,
    retry_config: Option<PyClientRetryConfig>,
    timeout_config: Option<PyClientTimeoutConfig>,
}

#[derive(FromPyObject)]
pub struct PyClientRetryConfig {
    retries: Option<u8>,
    connect_retries: Option<u8>,
    read_retries: Option<u8>,
    backoff_factor: Option<f32>,
    backoff_jitter: Option<f32>,
    statuses: Option<Vec<u16>>,
}

#[derive(FromPyObject)]
pub struct PyClientTimeoutConfig {
    connect_timeout: Option<Duration>,
    read_timeout: Option<Duration>,
    pool_idle_timeout: Option<Duration>,
}

#[cfg(feature = "remote")]
|
||||
impl From<PyClientRetryConfig> for lancedb::remote::RetryConfig {
|
||||
fn from(value: PyClientRetryConfig) -> Self {
|
||||
Self {
|
||||
retries: value.retries,
|
||||
connect_retries: value.connect_retries,
|
||||
read_retries: value.read_retries,
|
||||
backoff_factor: value.backoff_factor,
|
||||
backoff_jitter: value.backoff_jitter,
|
||||
statuses: value.statuses,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<PyClientTimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
fn from(value: PyClientTimeoutConfig) -> Self {
|
||||
Self {
|
||||
connect_timeout: value.connect_timeout,
|
||||
read_timeout: value.read_timeout,
|
||||
pool_idle_timeout: value.pool_idle_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(value: PyClientConfig) -> Self {
|
||||
Self {
|
||||
user_agent: value.user_agent,
|
||||
retry_config: value.retry_config.map(Into::into).unwrap_or_default(),
|
||||
timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
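
Because these structs derive FromPyObject, the Python binding can accept plain dictionaries for client_config; a sketch of the accepted shape, mirroring the field names above and the fixture used in the remote tests (database name and key are placeholders, and integer timeouts are in seconds as in that fixture):

import asyncio

import lancedb


async def main():
    db = await lancedb.connect_async(
        "db://my_database",
        api_key="fake-key",
        client_config={
            "retry_config": {"retries": 5, "backoff_factor": 0.5},
            "timeout_config": {"connect_timeout": 10},
        },
    )
    print(await db.table_names())


asyncio.run(main())
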
@@ -14,7 +14,9 @@
|
||||
|
||||
use pyo3::{
|
||||
exceptions::{PyIOError, PyNotImplementedError, PyOSError, PyRuntimeError, PyValueError},
|
||||
PyResult,
|
||||
intern,
|
||||
types::{PyAnyMethods, PyNone},
|
||||
PyErr, PyResult, Python,
|
||||
};
|
||||
|
||||
use lancedb::error::Error as LanceError;
|
||||
@@ -38,12 +40,79 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
LanceError::InvalidInput { .. }
|
||||
| LanceError::InvalidTableName { .. }
|
||||
| LanceError::TableNotFound { .. }
|
||||
| LanceError::Schema { .. } => self.value_error(),
|
||||
| LanceError::Schema { .. }
|
||||
| LanceError::TableAlreadyExists { .. } => self.value_error(),
|
||||
LanceError::CreateDir { .. } => self.os_error(),
|
||||
LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
|
||||
LanceError::NotSupported { .. } => {
|
||||
Err(PyNotImplementedError::new_err(err.to_string()))
|
||||
}
|
||||
LanceError::Http {
|
||||
request_id,
|
||||
source,
|
||||
status_code,
|
||||
} => Python::with_gil(|py| {
|
||||
let message = err.to_string();
|
||||
let http_err_cls = py
|
||||
.import_bound(intern!(py, "lancedb.remote.errors"))?
|
||||
.getattr(intern!(py, "HttpError"))?;
|
||||
let err = http_err_cls.call1((
|
||||
message,
|
||||
request_id,
|
||||
status_code.map(|s| s.as_u16()),
|
||||
))?;
|
||||
|
||||
if let Some(cause) = source.source() {
|
||||
// The HTTP error already includes the first cause. But
|
||||
// we can add the rest of the chain if there is any more.
|
||||
let cause_err = http_from_rust_error(
|
||||
py,
|
||||
cause,
|
||||
request_id,
|
||||
status_code.map(|s| s.as_u16()),
|
||||
)?;
|
||||
err.setattr(intern!(py, "__cause__"), cause_err)?;
|
||||
}
|
||||
|
||||
Err(PyErr::from_value_bound(err))
|
||||
}),
|
||||
LanceError::Retry {
|
||||
request_id,
|
||||
request_failures,
|
||||
max_request_failures,
|
||||
connect_failures,
|
||||
max_connect_failures,
|
||||
read_failures,
|
||||
max_read_failures,
|
||||
source,
|
||||
status_code,
|
||||
} => Python::with_gil(|py| {
|
||||
let cause_err = http_from_rust_error(
|
||||
py,
|
||||
source.as_ref(),
|
||||
request_id,
|
||||
status_code.map(|s| s.as_u16()),
|
||||
)?;
|
||||
|
||||
let message = err.to_string();
|
||||
let retry_error_cls = py
|
||||
.import_bound(intern!(py, "lancedb.remote.errors"))?
|
||||
.getattr("RetryError")?;
|
||||
let err = retry_error_cls.call1((
|
||||
message,
|
||||
request_id,
|
||||
*request_failures,
|
||||
*connect_failures,
|
||||
*read_failures,
|
||||
*max_request_failures,
|
||||
*max_connect_failures,
|
||||
*max_read_failures,
|
||||
status_code.map(|s| s.as_u16()),
|
||||
))?;
|
||||
|
||||
err.setattr(intern!(py, "__cause__"), cause_err)?;
|
||||
Err(PyErr::from_value_bound(err))
|
||||
}),
|
||||
_ => self.runtime_error(),
|
||||
},
|
||||
}
|
||||
@@ -61,3 +130,24 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
self.map_err(|err| PyValueError::new_err(err.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn http_from_rust_error(
|
||||
py: Python<'_>,
|
||||
err: &dyn std::error::Error,
|
||||
request_id: &str,
|
||||
status_code: Option<u16>,
|
||||
) -> PyResult<PyErr> {
|
||||
let message = err.to_string();
|
||||
let http_err_cls = py.import("lancedb.remote.errors")?.getattr("HttpError")?;
|
||||
let py_err = http_err_cls.call1((message, request_id, status_code))?;
|
||||
|
||||
// Reset the traceback since it doesn't provide additional information.
|
||||
let py_err = py_err.call_method1(intern!(py, "with_traceback"), (PyNone::get_bound(py),))?;
|
||||
|
||||
if let Some(cause) = err.source() {
|
||||
let cause_err = http_from_rust_error(py, cause, request_id, status_code)?;
|
||||
py_err.setattr(intern!(py, "__cause__"), cause_err)?;
|
||||
}
|
||||
|
||||
Ok(PyErr::from_value(py_err))
|
||||
}
|
||||
|
||||
@@ -200,6 +200,8 @@ pub struct IndexConfig {
|
||||
/// Currently this is always a list of size 1. In the future there may
|
||||
/// be more columns to represent composite indices.
|
||||
pub columns: Vec<String>,
|
||||
/// Name of the index.
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -215,6 +217,7 @@ impl From<lancedb::index::IndexConfig> for IndexConfig {
|
||||
Self {
|
||||
index_type,
|
||||
columns: value.columns,
|
||||
name: value.name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ use lancedb::table::{
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods,
|
||||
types::{PyDict, PyString},
|
||||
Bound, PyAny, PyRef, PyResult, Python,
|
||||
types::{PyDict, PyDictMethods, PyString},
|
||||
Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject,
|
||||
};
|
||||
use pyo3_asyncio_0_21::tokio::future_into_py;
|
||||
|
||||
@@ -204,6 +204,33 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn index_stats(self_: PyRef<'_, Self>, index_name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let stats = inner.index_stats(&index_name).await.infer_error()?;
|
||||
if let Some(stats) = stats {
|
||||
Python::with_gil(|py| {
|
||||
let dict = PyDict::new_bound(py);
|
||||
dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
|
||||
dict.set_item("num_unindexed_rows", stats.num_unindexed_rows)?;
|
||||
dict.set_item("index_type", stats.index_type.to_string())?;
|
||||
|
||||
if let Some(distance_type) = stats.distance_type {
|
||||
dict.set_item("distance_type", distance_type.to_string())?;
|
||||
}
|
||||
|
||||
if let Some(num_indices) = stats.num_indices {
|
||||
dict.set_item("num_indices", num_indices)?;
|
||||
}
|
||||
|
||||
Ok(Some(dict.to_object(py)))
|
||||
})
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
})
|
||||
}
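
The dictionary assembled here is what surfaces as AsyncTable.index_stats on the Python side; a short sketch of consuming it, using only the fields set above (the helper name is illustrative):

async def describe_indices(table):
    # `table` is an AsyncTable, as in the index tests earlier in this diff.
    for index in await table.list_indices():
        stats = await table.index_stats(index.name)
        line = f"{index.name}: {stats.index_type}, {stats.num_indexed_rows} rows indexed"
        if stats.distance_type is not None:  # only present for vector indices
            line += f", distance={stats.distance_type}"
        print(line)
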
pub fn __repr__(&self) -> String {
|
||||
match &self.inner {
|
||||
None => format!("ClosedTable({})", self.name),
|
||||
@@ -304,6 +331,31 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn execute_merge_insert<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
data: Bound<'a, PyAny>,
|
||||
parameters: MergeInsertParams,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let batches: ArrowArrayStreamReader = ArrowArrayStreamReader::from_pyarrow_bound(&data)?;
|
||||
let on = parameters.on.iter().map(|s| s.as_str()).collect::<Vec<_>>();
|
||||
let mut builder = self_.inner_ref()?.merge_insert(&on);
|
||||
if parameters.when_matched_update_all {
|
||||
builder.when_matched_update_all(parameters.when_matched_update_all_condition);
|
||||
}
|
||||
if parameters.when_not_matched_insert_all {
|
||||
builder.when_not_matched_insert_all();
|
||||
}
|
||||
if parameters.when_not_matched_by_source_delete {
|
||||
builder
|
||||
.when_not_matched_by_source_delete(parameters.when_not_matched_by_source_condition);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
builder.execute(Box::new(batches)).await.infer_error()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
@@ -328,3 +380,14 @@ impl Table {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
#[pyo3(from_item_all)]
|
||||
pub struct MergeInsertParams {
|
||||
on: Vec<String>,
|
||||
when_matched_update_all: bool,
|
||||
when_matched_update_all_condition: Option<String>,
|
||||
when_not_matched_insert_all: bool,
|
||||
when_not_matched_by_source_delete: bool,
|
||||
when_not_matched_by_source_condition: Option<String>,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-node"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0-beta.1"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
@@ -46,6 +46,10 @@ impl JsQuery {
|
||||
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
|
||||
.value(&mut cx);
|
||||
|
||||
let fast_search = query_obj
|
||||
.get_opt::<JsBoolean, _, _>(&mut cx, "_fastSearch")?
|
||||
.map(|val| val.value(&mut cx));
|
||||
|
||||
let is_electron = cx
|
||||
.argument::<JsBoolean>(1)
|
||||
.or_throw(&mut cx)?
|
||||
@@ -70,6 +74,9 @@ impl JsQuery {
|
||||
if let Some(limit) = limit {
|
||||
builder = builder.limit(limit as usize);
|
||||
};
|
||||
if let Some(true) = fast_search {
|
||||
builder = builder.fast_search();
|
||||
}
|
||||
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
|
||||
|
||||
@@ -470,49 +470,42 @@ impl JsTable {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let index_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
rt.spawn(async move {
|
||||
let load_stats = futures::try_join!(
|
||||
table.as_native().unwrap().count_indexed_rows(&index_uuid),
|
||||
table.as_native().unwrap().count_unindexed_rows(&index_uuid)
|
||||
);
|
||||
let load_stats = table.index_stats(index_name).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let (indexed_rows, unindexed_rows) = load_stats.or_throw(&mut cx)?;
|
||||
let stats = load_stats.or_throw(&mut cx)?;
|
||||
|
||||
let output = JsObject::new(&mut cx);
|
||||
if let Some(stats) = stats {
|
||||
let output = JsObject::new(&mut cx);
|
||||
let num_indexed_rows = cx.number(stats.num_indexed_rows as f64);
|
||||
output.set(&mut cx, "numIndexedRows", num_indexed_rows)?;
|
||||
let num_unindexed_rows = cx.number(stats.num_unindexed_rows as f64);
|
||||
output.set(&mut cx, "numUnindexedRows", num_unindexed_rows)?;
|
||||
if let Some(distance_type) = stats.distance_type {
|
||||
let distance_type = cx.string(distance_type.to_string());
|
||||
output.set(&mut cx, "distanceType", distance_type)?;
|
||||
}
|
||||
let index_type = cx.string(stats.index_type.to_string());
|
||||
output.set(&mut cx, "indexType", index_type)?;
|
||||
|
||||
match indexed_rows {
|
||||
Some(x) => {
|
||||
let i = cx.number(x as f64);
|
||||
output.set(&mut cx, "numIndexedRows", i)?;
|
||||
if let Some(num_indices) = stats.num_indices {
|
||||
let num_indices = cx.number(num_indices as f64);
|
||||
output.set(&mut cx, "numIndices", num_indices)?;
|
||||
}
|
||||
None => {
|
||||
let null = cx.null();
|
||||
output.set(&mut cx, "numIndexedRows", null)?;
|
||||
}
|
||||
};
|
||||
|
||||
match unindexed_rows {
|
||||
Some(x) => {
|
||||
let i = cx.number(x as f64);
|
||||
output.set(&mut cx, "numUnindexedRows", i)?;
|
||||
}
|
||||
None => {
|
||||
let null = cx.null();
|
||||
output.set(&mut cx, "numUnindexedRows", null)?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(output)
|
||||
Ok(output.as_value(&mut cx))
|
||||
} else {
|
||||
Ok(JsNull::new(&mut cx).as_value(&mut cx))
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0-beta.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -19,6 +19,7 @@ arrow-ord = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
arrow-ipc.workspace = true
|
||||
chrono = { workspace = true }
|
||||
datafusion-common.workspace = true
|
||||
datafusion-physical-plan.workspace = true
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
@@ -31,6 +32,7 @@ lance-table = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
lance-testing = { workspace = true }
|
||||
lance-encoding = { workspace = true }
|
||||
moka = { workspace = true}
|
||||
pin-project = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
log.workspace = true
|
||||
@@ -46,7 +48,9 @@ async-openai = { version = "0.20.0", optional = true }
|
||||
serde_with = { version = "3.8.1" }
|
||||
# For remote feature
|
||||
reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
rand = { version = "0.8.3", features = ["small_rng"], optional = true}
|
||||
http = { version = "1", optional = true } # Matching what is in reqwest
|
||||
uuid = { version = "1.7.0", features = ["v4"], optional = true }
|
||||
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
|
||||
polars = { version = ">=0.37,<0.40.0", optional = true }
|
||||
hf-hub = { version = "0.3.2", optional = true }
|
||||
@@ -70,7 +74,7 @@ http-body = "1" # Matching reqwest
|
||||
|
||||
[features]
|
||||
default = []
|
||||
remote = ["dep:reqwest", "dep:http"]
|
||||
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
openai = ["dep:async-openai", "dep:reqwest"]
|
||||
|
||||
@@ -32,6 +32,8 @@ use crate::embeddings::{
|
||||
};
|
||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::remote::client::ClientConfig;
|
||||
use crate::table::{NativeTable, TableDefinition, WriteOptions};
|
||||
use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
@@ -431,6 +433,7 @@ pub(crate) trait ConnectionInternal:
|
||||
data: Box<dyn RecordBatchReader + Send>,
|
||||
) -> Result<Table>;
|
||||
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table>;
|
||||
async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
|
||||
async fn drop_table(&self, name: &str) -> Result<()>;
|
||||
async fn drop_db(&self) -> Result<()>;
|
||||
|
||||
@@ -513,6 +516,19 @@ impl Connection {
|
||||
OpenTableBuilder::new(self.internal.clone(), name.into())
|
||||
}
|
||||
|
||||
/// Rename a table in the database.
|
||||
///
|
||||
/// This is only supported in LanceDB Cloud.
|
||||
pub async fn rename_table(
|
||||
&self,
|
||||
old_name: impl AsRef<str>,
|
||||
new_name: impl AsRef<str>,
|
||||
) -> Result<()> {
|
||||
self.internal
|
||||
.rename_table(old_name.as_ref(), new_name.as_ref())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Drop a table in the database.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -553,6 +569,8 @@ pub struct ConnectBuilder {
|
||||
region: Option<String>,
|
||||
/// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
|
||||
host_override: Option<String>,
|
||||
#[cfg(feature = "remote")]
|
||||
client_config: ClientConfig,
|
||||
|
||||
storage_options: HashMap<String, String>,
|
||||
|
||||
@@ -578,6 +596,8 @@ impl ConnectBuilder {
|
||||
api_key: None,
|
||||
region: None,
|
||||
host_override: None,
|
||||
#[cfg(feature = "remote")]
|
||||
client_config: Default::default(),
|
||||
read_consistency_interval: None,
|
||||
storage_options: HashMap::new(),
|
||||
embedding_registry: None,
|
||||
@@ -599,6 +619,30 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the LanceDB Cloud client configuration.
|
||||
///
|
||||
/// ```
|
||||
/// # use lancedb::connect;
|
||||
/// # use lancedb::remote::*;
|
||||
/// connect("db://my_database")
|
||||
/// .client_config(ClientConfig {
|
||||
/// timeout_config: TimeoutConfig {
|
||||
/// connect_timeout: Some(std::time::Duration::from_secs(5)),
|
||||
/// ..Default::default()
|
||||
/// },
|
||||
/// retry_config: RetryConfig {
|
||||
/// retries: Some(5),
|
||||
/// ..Default::default()
|
||||
/// },
|
||||
/// ..Default::default()
|
||||
/// });
|
||||
/// ```
|
||||
#[cfg(feature = "remote")]
|
||||
pub fn client_config(mut self, config: ClientConfig) -> Self {
|
||||
self.client_config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||
self.embedding_registry = Some(registry);
|
||||
@@ -671,12 +715,14 @@ impl ConnectBuilder {
|
||||
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
// TODO: remove this warning when the remote client is ready
|
||||
warn!("The rust implementation of the remote client is not yet ready for use.");
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
&self.uri,
|
||||
&api_key,
|
||||
®ion,
|
||||
self.host_override,
|
||||
self.client_config,
|
||||
)?);
|
||||
Ok(Connection {
|
||||
internal,
|
||||
@@ -1066,6 +1112,12 @@ impl ConnectionInternal for Database {
|
||||
Ok(Table::new(native_table))
|
||||
}
|
||||
|
||||
async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "rename_table is not supported in LanceDB OSS".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drop_table(&self, name: &str) -> Result<()> {
|
||||
let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
|
||||
let full_path = self.base_path.child(dir_name.clone());
|
||||
|
||||
@@ -46,8 +46,37 @@ pub enum Error {
|
||||
ObjectStore { source: object_store::Error },
|
||||
#[snafu(display("lance error: {source}"))]
|
||||
Lance { source: lance::Error },
|
||||
#[snafu(display("Http error: {message}"))]
|
||||
Http { message: String },
|
||||
#[cfg(feature = "remote")]
|
||||
#[snafu(display("Http error: (request_id={request_id}) {source}"))]
|
||||
Http {
|
||||
#[snafu(source(from(reqwest::Error, Box::new)))]
|
||||
source: Box<dyn std::error::Error + Send + Sync>,
|
||||
request_id: String,
|
||||
/// Status code associated with the error, if available.
|
||||
/// This is not always available, for example when the error is due to a
|
||||
/// connection failure. It may also be missing if the request was
|
||||
/// successful but there was an error decoding the response.
|
||||
status_code: Option<reqwest::StatusCode>,
|
||||
},
|
||||
#[cfg(feature = "remote")]
|
||||
#[snafu(display(
|
||||
"Hit retry limit for request_id={request_id} (\
|
||||
request_failures={request_failures}/{max_request_failures}, \
|
||||
connect_failures={connect_failures}/{max_connect_failures}, \
|
||||
read_failures={read_failures}/{max_read_failures})"
|
||||
))]
|
||||
Retry {
|
||||
request_id: String,
|
||||
request_failures: u8,
|
||||
max_request_failures: u8,
|
||||
connect_failures: u8,
|
||||
max_connect_failures: u8,
|
||||
read_failures: u8,
|
||||
max_read_failures: u8,
|
||||
#[snafu(source(from(reqwest::Error, Box::new)))]
|
||||
source: Box<dyn std::error::Error + Send + Sync>,
|
||||
status_code: Option<reqwest::StatusCode>,
|
||||
},
|
||||
#[snafu(display("Arrow error: {source}"))]
|
||||
Arrow { source: ArrowError },
|
||||
#[snafu(display("LanceDBError: not supported: {message}"))]
|
||||
@@ -98,24 +127,6 @@ impl<T> From<PoisonError<T>> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
Self::Http {
|
||||
message: e.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<url::ParseError> for Error {
|
||||
fn from(e: url::ParseError) -> Self {
|
||||
Self::Http {
|
||||
message: e.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "polars")]
|
||||
impl From<polars::prelude::PolarsError> for Error {
|
||||
fn from(source: polars::prelude::PolarsError) -> Self {
|
||||
|
||||
@@ -18,7 +18,7 @@ use scalar::FtsIndexBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
use crate::{table::TableInternal, Result};
|
||||
use crate::{table::TableInternal, DistanceType, Error, Result};
|
||||
|
||||
use self::{
|
||||
scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
|
||||
@@ -102,19 +102,61 @@ impl IndexBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub enum IndexType {
|
||||
// Vector
|
||||
#[serde(alias = "IVF_PQ")]
|
||||
IvfPq,
|
||||
#[serde(alias = "IVF_HNSW_PQ")]
|
||||
IvfHnswPq,
|
||||
#[serde(alias = "IVF_HNSW_SQ")]
|
||||
IvfHnswSq,
|
||||
// Scalar
|
||||
#[serde(alias = "BTREE")]
|
||||
BTree,
|
||||
#[serde(alias = "BITMAP")]
|
||||
Bitmap,
|
||||
#[serde(alias = "LABEL_LIST")]
|
||||
LabelList,
|
||||
// FTS
|
||||
FTS,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for IndexType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::IvfPq => write!(f, "IVF_PQ"),
|
||||
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
||||
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
|
||||
Self::BTree => write!(f, "BTREE"),
|
||||
Self::Bitmap => write!(f, "BITMAP"),
|
||||
Self::LabelList => write!(f, "LABEL_LIST"),
|
||||
Self::FTS => write!(f, "FTS"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for IndexType {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self> {
|
||||
match value.to_uppercase().as_str() {
|
||||
"BTREE" => Ok(Self::BTree),
|
||||
"BITMAP" => Ok(Self::Bitmap),
|
||||
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
|
||||
"FTS" => Ok(Self::FTS),
|
||||
"IVF_PQ" => Ok(Self::IvfPq),
|
||||
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
||||
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
|
||||
_ => Err(Error::InvalidInput {
|
||||
message: format!("the input value {} is not a valid IndexType", value),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A description of an index currently configured on a column
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct IndexConfig {
|
||||
/// The name of the index
|
||||
pub name: String,
|
||||
@@ -129,16 +171,39 @@ pub struct IndexConfig {
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IndexMetadata {
|
||||
pub metric_type: Option<String>,
|
||||
pub index_type: Option<String>,
|
||||
pub(crate) struct IndexMetadata {
|
||||
pub metric_type: Option<DistanceType>,
|
||||
// Sometimes the index type is provided at this level.
|
||||
pub index_type: Option<IndexType>,
|
||||
}
|
||||
|
||||
// This struct is used to deserialize the JSON data returned from the Lance API
|
||||
// Dataset::index_statistics().
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct IndexStatisticsImpl {
|
||||
pub num_indexed_rows: usize,
|
||||
pub num_unindexed_rows: usize,
|
||||
pub indices: Vec<IndexMetadata>,
|
||||
// Sometimes, the index type is provided at this level.
|
||||
pub index_type: Option<IndexType>,
|
||||
pub num_indices: Option<u32>,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, PartialEq)]
|
||||
pub struct IndexStatistics {
|
||||
/// The number of rows in the table that are covered by this index.
|
||||
pub num_indexed_rows: usize,
|
||||
/// The number of rows in the table that are not covered by this index.
|
||||
/// These are rows that haven't yet been added to the index.
|
||||
pub num_unindexed_rows: usize,
|
||||
pub index_type: Option<String>,
|
||||
pub indices: Vec<IndexMetadata>,
|
||||
/// The type of the index.
|
||||
pub index_type: IndexType,
|
||||
/// The distance type used by the index.
|
||||
///
|
||||
/// This is only present for vector indices.
|
||||
pub distance_type: Option<DistanceType>,
|
||||
/// The number of parts this index is split into.
|
||||
pub num_indices: Option<u32>,
|
||||
}
|
||||
|
||||
@@ -214,6 +214,11 @@ pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
|
||||
max(1, num_partitions)
|
||||
}
|
||||
|
||||
pub(crate) fn suggested_num_partitions_for_hnsw(rows: usize, dim: u32) -> u32 {
|
||||
let num_partitions = (((rows as u64) * (dim as u64)) / (256 * 5_000_000)) as u32;
|
||||
max(1, num_partitions)
|
||||
}
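
A quick sanity check of this heuristic (roughly one partition per 256 * 5,000,000 = 1.28 billion vector values), sketched in Python:

def suggested_num_partitions_for_hnsw(rows: int, dim: int) -> int:
    # Mirrors the Rust formula above: floor(rows * dim / (256 * 5_000_000)), at least 1.
    return max(1, (rows * dim) // (256 * 5_000_000))


assert suggested_num_partitions_for_hnsw(100_000, 768) == 1       # 7.68e7 values
assert suggested_num_partitions_for_hnsw(10_000_000, 1536) == 12  # 1.536e10 values
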
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||
if dim % 16 == 0 {
|
||||
// Should be more aggressive than this default.
|
||||
|
||||
@@ -213,7 +213,7 @@ pub mod ipc;
|
||||
mod polars_arrow_convertors;
|
||||
pub mod query;
|
||||
#[cfg(feature = "remote")]
|
||||
pub(crate) mod remote;
|
||||
pub mod remote;
|
||||
pub mod table;
|
||||
pub mod utils;
|
||||
|
||||
@@ -228,6 +228,7 @@ pub use table::Table;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum DistanceType {
|
||||
/// Euclidean distance. This is a very common distance metric that
|
||||
/// accounts for both magnitude and direction when determining the distance
|
||||
@@ -253,6 +254,12 @@ pub enum DistanceType {
|
||||
Hamming,
|
||||
}
|
||||
|
||||
impl Default for DistanceType {
|
||||
fn default() -> Self {
|
||||
Self::L2
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DistanceType> for LanceDistanceType {
|
||||
fn from(value: DistanceType) -> Self {
|
||||
match value {
|
||||
|
||||
@@ -402,6 +402,9 @@ pub trait QueryBase {
|
||||
///
|
||||
/// By default, it is false.
|
||||
fn fast_search(self) -> Self;
|
||||
|
||||
/// Return the `_rowid` meta column from the Table.
|
||||
fn with_row_id(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait HasQuery {
|
||||
@@ -438,6 +441,11 @@ impl<T: HasQuery> QueryBase for T {
|
||||
self.mut_query().fast_search = true;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_row_id(mut self) -> Self {
|
||||
self.mut_query().with_row_id = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for controlling the execution of a query
|
||||
@@ -548,6 +556,11 @@ pub struct Query {
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) fast_search: bool,
|
||||
|
||||
/// If set to true, the query will return the `_rowid` meta column.
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) with_row_id: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -560,6 +573,7 @@ impl Query {
|
||||
full_text_search: None,
|
||||
select: Select::All,
|
||||
fast_search: false,
|
||||
with_row_id: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1160,4 +1174,24 @@ mod tests {
|
||||
.unwrap();
|
||||
assert!(!plan.contains("Take"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_row_id() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_test_table(&tmp_dir).await;
|
||||
let results = table
|
||||
.vector_search(&[0.1, 0.2, 0.3, 0.4])
|
||||
.unwrap()
|
||||
.with_row_id()
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
for batch in results {
|
||||
assert!(batch.column_by_name("_rowid").is_some());
|
||||
}
|
||||
}
|
||||
}
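
The test above confirms that the _rowid column comes back when requested. The same switch is exposed through the Python query builders (the reranker tests earlier in this diff call with_row_id(True) on the sync API); a small sketch, with the path and data as placeholders:

import lancedb

db = lancedb.connect("/tmp/lancedb-rowid-demo")
table = db.create_table(
    "vectors",
    [{"vector": [0.1, 0.2], "item": "a"}, {"vector": [0.3, 0.4], "item": "b"}],
)

results = (
    table.search([0.1, 0.2], vector_column_name="vector")
    .limit(10)
    .with_row_id(True)
    .to_arrow()
)
assert "_rowid" in results.column_names
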
@@ -17,9 +17,12 @@
|
||||
//! building client/server applications with LanceDB or as a client for some
|
||||
//! other custom LanceDB service.
|
||||
|
||||
pub mod client;
|
||||
pub mod db;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
pub(crate) mod client;
|
||||
pub(crate) mod db;
|
||||
pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
|
||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||
const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
|
||||
|
||||
@@ -14,13 +14,152 @@
|
||||
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use log::debug;
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
RequestBuilder, Response,
|
||||
Request, RequestBuilder, Response,
|
||||
};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
|
||||
/// Configuration for the LanceDB Cloud HTTP client.
#[derive(Debug)]
pub struct ClientConfig {
    pub timeout_config: TimeoutConfig,
    pub retry_config: RetryConfig,
    /// User agent to use for requests. The default provides the library
    /// name and version.
    pub user_agent: String,
    // TODO: how to configure request ids?
}

impl Default for ClientConfig {
    fn default() -> Self {
        Self {
            timeout_config: TimeoutConfig::default(),
            retry_config: RetryConfig::default(),
            user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
        }
    }
}

/// How to handle timeouts for HTTP requests.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct TimeoutConfig {
|
||||
/// The timeout for creating a connection to the server.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_CONNECT_TIMEOUT` environment variable
|
||||
/// to set this value. Use an integer value in seconds.
|
||||
///
|
||||
/// The default is 120 seconds (2 minutes).
|
||||
pub connect_timeout: Option<Duration>,
|
||||
/// The timeout for reading a response from the server.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_READ_TIMEOUT` environment variable
|
||||
/// to set this value. Use an integer value in seconds.
|
||||
///
|
||||
/// The default is 300 seconds (5 minutes).
|
||||
pub read_timeout: Option<Duration>,
|
||||
/// The timeout for keeping idle connections alive.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_CONNECTION_TIMEOUT` environment variable
|
||||
/// to set this value. Use an integer value in seconds.
|
||||
///
|
||||
/// The default is 300 seconds (5 minutes).
|
||||
pub pool_idle_timeout: Option<Duration>,
|
||||
}
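
All three timeouts can also be driven by the environment variables named in the doc comments above, which is convenient when the config is not threaded through application code; a sketch (values are integer seconds, and explicit timeout_config values take precedence):

import os

os.environ["LANCE_CLIENT_CONNECT_TIMEOUT"] = "30"
os.environ["LANCE_CLIENT_READ_TIMEOUT"] = "120"
# Legacy name: controls the pool idle timeout despite the generic spelling.
os.environ["LANCE_CLIENT_CONNECTION_TIMEOUT"] = "300"
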
|
||||
/// How to handle retries for HTTP requests.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct RetryConfig {
|
||||
/// The number of times to retry a request if it fails.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_MAX_RETRIES` environment variable
|
||||
/// to set this value. Use an integer value.
|
||||
///
|
||||
/// The default is 3 retries.
|
||||
pub retries: Option<u8>,
|
||||
/// The number of times to retry a request if it fails to connect.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_CONNECT_RETRIES` environment variable
|
||||
/// to set this value. Use an integer value.
|
||||
///
|
||||
/// The default is 3 retries.
|
||||
pub connect_retries: Option<u8>,
|
||||
/// The number of times to retry a request if it fails to read.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_READ_RETRIES` environment variable
|
||||
/// to set this value. Use an integer value.
|
||||
///
|
||||
/// The default is 3 retries.
|
||||
pub read_retries: Option<u8>,
|
||||
/// The exponential backoff factor to use when retrying requests.
|
||||
///
|
||||
/// Between each retry, the client will wait for the amount of seconds:
|
||||
///
|
||||
/// ```text
|
||||
/// {backoff factor} * (2 ** ({number of previous retries}))
|
||||
/// ```
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_RETRY_BACKOFF_FACTOR` environment variable
|
||||
/// to set this value. Use a float value.
|
||||
///
|
||||
/// The default is 0.25. So the first retry will wait 0.25 seconds, the second
|
||||
/// retry will wait 0.5 seconds, the third retry will wait 1 second, etc.
|
||||
pub backoff_factor: Option<f32>,
|
||||
/// The backoff jitter factor to use when retrying requests.
|
||||
///
|
||||
/// The backoff jitter is a random value between 0 and the jitter factor in
|
||||
/// seconds.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_RETRY_BACKOFF_JITTER` environment variable
|
||||
/// to set this value. Use a float value.
|
||||
///
|
||||
/// The default is 0.25. So between 0 and 0.25 seconds will be added to the
|
||||
/// sleep time between retries.
|
||||
pub backoff_jitter: Option<f32>,
|
||||
/// The set of status codes to retry on.
|
||||
///
|
||||
/// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable
|
||||
/// to set this value. Use a comma-separated list of integer values.
|
||||
///
|
||||
/// The default is 429, 500, 502, 503.
|
||||
pub statuses: Option<Vec<u16>>,
|
||||
// TODO: should we allow customizing methods?
|
||||
}
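
To make the retry schedule concrete, the defaults above (backoff factor 0.25, jitter 0.25) wait roughly 0.25 s, 0.5 s, then 1.0 s before the three retries; a small sketch of the formula:

import random


def sleep_before_retry(previous_retries: int, factor: float = 0.25, jitter: float = 0.25) -> float:
    # backoff_factor * 2 ** previous_retries, plus up to `jitter` seconds of random jitter
    return factor * (2 ** previous_retries) + random.uniform(0.0, jitter)


print([0.25 * 2 ** n for n in range(3)])  # [0.25, 0.5, 1.0]
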
#[derive(Debug, Clone)]
|
||||
struct ResolvedRetryConfig {
|
||||
retries: u8,
|
||||
connect_retries: u8,
|
||||
read_retries: u8,
|
||||
backoff_factor: f32,
|
||||
backoff_jitter: f32,
|
||||
statuses: Vec<reqwest::StatusCode>,
|
||||
}
|
||||
|
||||
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(retry_config: RetryConfig) -> Result<Self> {
|
||||
Ok(Self {
|
||||
retries: retry_config.retries.unwrap_or(3),
|
||||
connect_retries: retry_config.connect_retries.unwrap_or(3),
|
||||
read_retries: retry_config.read_retries.unwrap_or(3),
|
||||
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
|
||||
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
|
||||
statuses: retry_config
|
||||
.statuses
|
||||
.unwrap_or_else(|| vec![429, 500, 502, 503])
|
||||
.into_iter()
|
||||
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
|
||||
// we can mock responses in tests. Based on the patterns from this blog post:
|
||||
// https://write.as/balrogboogie/testing-reqwest-based-clients
|
||||
@@ -28,53 +167,110 @@ use crate::error::{Error, Result};
|
||||
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
|
||||
client: reqwest::Client,
|
||||
host: String,
|
||||
retry_config: ResolvedRetryConfig,
|
||||
sender: S,
|
||||
}
|
||||
|
||||
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
|
||||
fn send(&self, req: RequestBuilder) -> impl Future<Output = Result<Response>> + Send;
|
||||
fn send(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
request: reqwest::Request,
|
||||
) -> impl Future<Output = reqwest::Result<Response>> + Send;
|
||||
}
|
||||
|
||||
// Default implementation of HttpSend which sends the request normally with reqwest
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Sender;
|
||||
impl HttpSend for Sender {
|
||||
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
|
||||
Ok(request.send().await?)
|
||||
async fn send(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
request: reqwest::Request,
|
||||
) -> reqwest::Result<reqwest::Response> {
|
||||
client.execute(request).await
|
||||
}
|
||||
}
|
||||
|
||||
impl RestfulLanceDbClient<Sender> {
|
||||
fn get_timeout(passed: Option<Duration>, env_var: &str, default: Duration) -> Result<Duration> {
|
||||
if let Some(passed) = passed {
|
||||
Ok(passed)
|
||||
} else if let Ok(timeout) = std::env::var(env_var) {
|
||||
let timeout = timeout.parse::<u64>().map_err(|_| Error::InvalidInput {
|
||||
message: format!(
|
||||
"Invalid value for {} environment variable: '{}'",
|
||||
env_var, timeout
|
||||
),
|
||||
})?;
|
||||
Ok(Duration::from_secs(timeout))
|
||||
} else {
|
||||
Ok(default)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_new(
|
||||
db_url: &str,
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
client_config: ClientConfig,
|
||||
) -> Result<Self> {
|
||||
let parsed_url = url::Url::parse(db_url)?;
|
||||
let parsed_url = url::Url::parse(db_url).map_err(|err| Error::InvalidInput {
|
||||
message: format!("db_url is not a valid URL. '{db_url}'. Error: {err}"),
|
||||
})?;
|
||||
debug_assert_eq!(parsed_url.scheme(), "db");
|
||||
if !parsed_url.has_host() {
|
||||
return Err(Error::Http {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!("Invalid database URL (missing host) '{}'", db_url),
|
||||
});
|
||||
}
|
||||
let db_name = parsed_url.host_str().unwrap();
|
||||
|
||||
// Get the timeouts
|
||||
let connect_timeout = Self::get_timeout(
|
||||
client_config.timeout_config.connect_timeout,
|
||||
"LANCE_CLIENT_CONNECT_TIMEOUT",
|
||||
Duration::from_secs(120),
|
||||
)?;
|
||||
let read_timeout = Self::get_timeout(
|
||||
client_config.timeout_config.read_timeout,
|
||||
"LANCE_CLIENT_READ_TIMEOUT",
|
||||
Duration::from_secs(300),
|
||||
)?;
|
||||
let pool_idle_timeout = Self::get_timeout(
|
||||
client_config.timeout_config.pool_idle_timeout,
|
||||
// Though it's confusing with the connect_timeout name, this is the
|
||||
// legacy name for this in the Python sync client. So we keep as-is.
|
||||
"LANCE_CLIENT_CONNECTION_TIMEOUT",
|
||||
Duration::from_secs(300),
|
||||
)?;
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect_timeout(connect_timeout)
|
||||
.read_timeout(read_timeout)
|
||||
.pool_idle_timeout(pool_idle_timeout)
|
||||
.default_headers(Self::default_headers(
|
||||
api_key,
|
||||
region,
|
||||
db_name,
|
||||
host_override.is_some(),
|
||||
)?)
|
||||
.build()?;
|
||||
.user_agent(client_config.user_agent)
|
||||
.build()
|
||||
.map_err(|err| Error::Other {
|
||||
message: "Failed to build HTTP client".into(),
|
||||
source: Some(Box::new(err)),
|
||||
})?;
|
||||
let host = match host_override {
|
||||
Some(host_override) => host_override,
|
||||
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
|
||||
};
|
||||
let retry_config = client_config.retry_config.try_into()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host,
|
||||
retry_config,
|
||||
sender: Sender,
|
||||
})
|
||||
}
|
||||
@@ -94,7 +290,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"x-api-key",
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::Http {
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
@@ -102,7 +298,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
"Host",
|
||||
HeaderValue::from_str(&host).map_err(|_| Error::Http {
|
||||
HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii database name '{}' provided", db_name),
|
||||
})?,
|
||||
);
|
||||
@@ -110,7 +306,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
if has_host_override {
|
||||
headers.insert(
|
||||
"x-lancedb-database",
|
||||
HeaderValue::from_str(db_name).map_err(|_| Error::Http {
|
||||
HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii database name '{}' provided", db_name),
|
||||
})?,
|
||||
);
|
||||
@@ -129,29 +325,209 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
self.client.post(full_uri)
|
||||
}
|
||||
|
||||
pub async fn send(&self, req: RequestBuilder) -> Result<Response> {
|
||||
self.sender.send(req).await
|
||||
pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
|
||||
let (client, request) = req.build_split();
|
||||
let mut request = request.unwrap();
|
||||
|
||||
// Set a request id.
|
||||
// TODO: allow the user to supply this, through middleware?
|
||||
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
|
||||
request_id.to_str().unwrap().to_string()
|
||||
} else {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let header = HeaderValue::from_str(&request_id).unwrap();
|
||||
request.headers_mut().insert(REQUEST_ID_HEADER, header);
|
||||
request_id
|
||||
};
|
||||
|
||||
if with_retry {
|
||||
self.send_with_retry_impl(client, request, request_id).await
|
||||
} else {
|
||||
let response = self
|
||||
.sender
|
||||
.send(&client, request)
|
||||
.await
|
||||
.err_to_http(request_id.clone())?;
|
||||
Ok((request_id, response))
|
||||
}
|
||||
}
|
||||
|
||||
    async fn rsp_to_str(response: Response) -> String {
    async fn send_with_retry_impl(
        &self,
        client: reqwest::Client,
        req: Request,
        request_id: String,
    ) -> Result<(String, Response)> {
        let mut retry_counter = RetryCounter::new(&self.retry_config, request_id);

        loop {
            // This only works if the request body is not a stream. If it is
            // a stream, we can't use the retry path. We would need to implement
            // an outer retry.
            let request = req.try_clone().ok_or_else(|| Error::Runtime {
                message: "Attempted to retry a request that cannot be cloned".to_string(),
            })?;
            let response = self
                .sender
                .send(&client, request)
                .await
                .map(|r| (r.status(), r));
            match response {
                Ok((status, response)) if status.is_success() => {
                    return Ok((retry_counter.request_id, response))
                }
                Ok((status, response)) if self.retry_config.statuses.contains(&status) => {
                    let source = self
                        .check_response(&retry_counter.request_id, response)
                        .await
                        .unwrap_err();
                    retry_counter.increment_request_failures(source)?;
                }
                Err(err) if err.is_connect() => {
                    retry_counter.increment_connect_failures(err)?;
                }
                Err(err) if err.is_timeout() || err.is_body() || err.is_decode() => {
                    retry_counter.increment_read_failures(err)?;
                }
                Err(err) => {
                    let status_code = err.status();
                    return Err(Error::Http {
                        source: Box::new(err),
                        request_id: retry_counter.request_id,
                        status_code,
                    });
                }
                Ok((_, response)) => return Ok((retry_counter.request_id, response)),
            }

            let sleep_time = retry_counter.next_sleep_time();
            tokio::time::sleep(sleep_time).await;
        }
    }

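A hedged illustration (not part of the diff) of why the loop above clones the request on each attempt: reqwest can only clone requests with buffered bodies, so streaming uploads cannot take this retry path. Assumes the reqwest and tokio crates; the URL is a placeholder and nothing is actually sent.

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    // Building the request performs no I/O.
    let request = client
        .post("http://localhost/v1/table/t/count_rows/")
        .body(r#"{}"#)
        .build()?;
    // A buffered body can be cloned, so this request would be retryable.
    assert!(request.try_clone().is_some());
    Ok(())
}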
    pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> {
        // Try to get the response text, but if that fails, just return the status code
        let status = response.status();
        response.text().await.unwrap_or_else(|_| status.to_string())
        if status.is_success() {
            Ok(response)
        } else {
            let response_text = response.text().await.ok();
            let message = if let Some(response_text) = response_text {
                format!("{}: {}", status, response_text)
            } else {
                status.to_string()
            };
            Err(Error::Http {
                source: message.into(),
                request_id: request_id.into(),
                status_code: Some(status),
            })
        }
    }
}

struct RetryCounter<'a> {
    request_failures: u8,
    connect_failures: u8,
    read_failures: u8,
    config: &'a ResolvedRetryConfig,
    request_id: String,
}

impl<'a> RetryCounter<'a> {
    fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
        Self {
            request_failures: 0,
            connect_failures: 0,
            read_failures: 0,
            config,
            request_id,
        }
    }

    pub async fn check_response(&self, response: Response) -> Result<Response> {
        let status_int: u16 = u16::from(response.status());
        if (400..500).contains(&status_int) {
            Err(Error::InvalidInput {
                message: Self::rsp_to_str(response).await,
            })
        } else if status_int != 200 {
            Err(Error::Runtime {
                message: Self::rsp_to_str(response).await,
    fn check_out_of_retries(
        &self,
        source: Box<dyn std::error::Error + Send + Sync>,
        status_code: Option<reqwest::StatusCode>,
    ) -> Result<()> {
        if self.request_failures >= self.config.retries
            || self.connect_failures >= self.config.connect_retries
            || self.read_failures >= self.config.read_retries
        {
            Err(Error::Retry {
                request_id: self.request_id.clone(),
                request_failures: self.request_failures,
                max_request_failures: self.config.retries,
                connect_failures: self.connect_failures,
                max_connect_failures: self.config.connect_retries,
                read_failures: self.read_failures,
                max_read_failures: self.config.read_retries,
                source,
                status_code,
            })
        } else {
            Ok(response)
            Ok(())
        }
    }

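The check above enforces three separate budgets (request, connect, read). A small self-contained sketch of the same idea, with illustrative names rather than the crate's types:

#[derive(Default)]
struct Budget {
    request_failures: u8,
    connect_failures: u8,
    read_failures: u8,
}

impl Budget {
    // Returns Err once any per-category limit is exhausted.
    fn check(&self, max_request: u8, max_connect: u8, max_read: u8) -> Result<(), String> {
        if self.request_failures >= max_request
            || self.connect_failures >= max_connect
            || self.read_failures >= max_read
        {
            Err(format!(
                "out of retries: {}/{} request, {}/{} connect, {}/{} read",
                self.request_failures, max_request,
                self.connect_failures, max_connect,
                self.read_failures, max_read
            ))
        } else {
            Ok(())
        }
    }
}

fn main() {
    let mut budget = Budget::default();
    budget.request_failures += 1;
    assert!(budget.check(3, 3, 3).is_ok());
    budget.request_failures = 3;
    assert!(budget.check(3, 3, 3).is_err());
}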
    fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> {
        self.request_failures += 1;
        let status_code = if let crate::Error::Http { status_code, .. } = &source {
            *status_code
        } else {
            None
        };
        self.check_out_of_retries(Box::new(source), status_code)
    }

    fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> {
        self.connect_failures += 1;
        let status_code = source.status();
        self.check_out_of_retries(Box::new(source), status_code)
    }

    fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> {
        self.read_failures += 1;
        let status_code = source.status();
        self.check_out_of_retries(Box::new(source), status_code)
    }

    fn next_sleep_time(&self) -> Duration {
        let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
        let jitter = rand::random::<f32>() * self.config.backoff_jitter;
        let sleep_time = Duration::from_secs_f32(backoff + jitter);
        debug!(
            "Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
            self.request_id,
            self.connect_failures,
            self.config.connect_retries,
            self.request_failures,
            self.config.retries,
            self.read_failures,
            self.config.read_retries,
            sleep_time
        );
        sleep_time
    }
}

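For a sense of how the schedule above grows, a hedged worked example (not part of the diff): with an illustrative backoff_factor of 0.25s and a fixed 0.1s of jitter (the real code draws jitter randomly), successive failures sleep roughly 0.6s, 1.1s, 2.1s, 4.1s and 8.1s.

use std::time::Duration;

// Same shape as next_sleep_time above: factor * 2^failures + jitter.
fn sleep_for(failures: i32, backoff_factor: f32, jitter: f32) -> Duration {
    Duration::from_secs_f32(backoff_factor * 2.0f32.powi(failures) + jitter)
}

fn main() {
    for failures in 1..=5 {
        println!("{} failure(s) -> {:?}", failures, sleep_for(failures, 0.25, 0.1));
    }
}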
pub trait RequestResultExt {
    type Output;
    fn err_to_http(self, request_id: String) -> Result<Self::Output>;
}

impl<T> RequestResultExt for reqwest::Result<T> {
    type Output = T;
    fn err_to_http(self, request_id: String) -> Result<T> {
        self.map_err(|err| {
            let status_code = err.status();
            Error::Http {
                source: Box::new(err),
                request_id,
                status_code,
            }
        })
    }
}

#[cfg(test)]
@@ -172,8 +548,11 @@ pub mod test_utils {
    }

    impl HttpSend for MockSender {
        async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
            let request = request.build().unwrap();
        async fn send(
            &self,
            _client: &reqwest::Client,
            request: reqwest::Request,
        ) -> reqwest::Result<reqwest::Response> {
            let response = (self.f)(request);
            Ok(response)
        }
@@ -193,6 +572,7 @@ pub mod test_utils {
        RestfulLanceDbClient {
            client: reqwest::Client::new(),
            host: "http://localhost".to_string(),
            retry_config: RetryConfig::default().try_into().unwrap(),
            sender: MockSender {
                f: Arc::new(wrapper),
            },

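The MockSender change above moves the trait's argument from a RequestBuilder to an already-built Request. A stripped-down sketch of the same test-seam idea, with illustrative names rather than the crate's real types:

use std::sync::Arc;

// The production implementation would perform real HTTP; tests inject a closure.
trait Dispatch {
    fn dispatch(&self, path: &str) -> (u16, String);
}

struct MockDispatch {
    handler: Arc<dyn Fn(&str) -> (u16, String) + Send + Sync>,
}

impl Dispatch for MockDispatch {
    fn dispatch(&self, path: &str) -> (u16, String) {
        (self.handler)(path)
    }
}

fn main() {
    let mock = MockDispatch {
        handler: Arc::new(|path| (200, format!("handled {}", path))),
    };
    assert_eq!(mock.dispatch("/v1/table/t/describe/").0, 200);
}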
@@ -17,6 +17,7 @@ use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use http::StatusCode;
use moka::future::Cache;
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use tokio::task::spawn_blocking;
@@ -28,7 +29,7 @@ use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;

use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::batches_to_ipc_bytes;
use super::ARROW_STREAM_CONTENT_TYPE;
@@ -41,6 +42,7 @@ struct ListTablesResponse {
#[derive(Debug)]
pub struct RemoteDatabase<S: HttpSend = Sender> {
    client: RestfulLanceDbClient<S>,
    table_cache: Cache<String, ()>,
}

impl RemoteDatabase {
@@ -49,9 +51,20 @@ impl RemoteDatabase {
        api_key: &str,
        region: &str,
        host_override: Option<String>,
        client_config: ClientConfig,
    ) -> Result<Self> {
        let client = RestfulLanceDbClient::try_new(uri, api_key, region, host_override)?;
        Ok(Self { client })
        let client =
            RestfulLanceDbClient::try_new(uri, api_key, region, host_override, client_config)?;

        let table_cache = Cache::builder()
            .time_to_live(std::time::Duration::from_secs(300))
            .max_capacity(10_000)
            .build();

        Ok(Self {
            client,
            table_cache,
        })
    }
}

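The 300-second moka cache above only records that a table is known to exist, so repeated open_table calls can skip the describe round trip. A std-only sketch of that idea (illustrative only, not the moka API):

use std::collections::HashMap;
use std::time::{Duration, Instant};

struct ExistenceCache {
    ttl: Duration,
    entries: HashMap<String, Instant>,
}

impl ExistenceCache {
    fn new(ttl: Duration) -> Self {
        Self { ttl, entries: HashMap::new() }
    }

    fn insert(&mut self, name: &str) {
        self.entries.insert(name.to_string(), Instant::now());
    }

    // True only if the table was seen within the TTL window.
    fn contains(&self, name: &str) -> bool {
        self.entries
            .get(name)
            .map(|seen| seen.elapsed() < self.ttl)
            .unwrap_or(false)
    }
}

fn main() {
    let mut cache = ExistenceCache::new(Duration::from_secs(300));
    cache.insert("table1");
    assert!(cache.contains("table1"));
    assert!(!cache.contains("table2"));
}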
@@ -68,7 +81,10 @@ mod test_utils {
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler(handler);
|
||||
Self { client }
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -89,9 +105,17 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
if let Some(start_after) = options.start_after {
|
||||
req = req.query(&[("page_token", start_after)]);
|
||||
}
|
||||
let rsp = self.client.send(req).await?;
|
||||
let rsp = self.client.check_response(rsp).await?;
|
||||
Ok(rsp.json::<ListTablesResponse>().await?.tables)
|
||||
let (request_id, rsp) = self.client.send(req, true).await?;
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let tables = rsp
|
||||
.json::<ListTablesResponse>()
|
||||
.await
|
||||
.err_to_http(request_id)?
|
||||
.tables;
|
||||
for table in &tables {
|
||||
self.table_cache.insert(table.clone(), ()).await;
|
||||
}
|
||||
Ok(tables)
|
||||
}
|
||||
|
||||
async fn do_create_table(
|
||||
@@ -110,13 +134,11 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/create/", options.name))
|
||||
.body(data_buffer)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
// This is currently expected by LanceDb cloud but will be removed soon.
|
||||
.header("x-request-id", "na");
|
||||
let rsp = self.client.send(req).await?;
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
let (request_id, rsp) = self.client.send(req, false).await?;
|
||||
|
||||
if rsp.status() == StatusCode::BAD_REQUEST {
|
||||
let body = rsp.text().await?;
|
||||
let body = rsp.text().await.err_to_http(request_id.clone())?;
|
||||
if body.contains("already exists") {
|
||||
return Err(crate::Error::TableAlreadyExists { name: options.name });
|
||||
} else {
|
||||
@@ -124,7 +146,9 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
}
|
||||
}
|
||||
|
||||
self.client.check_response(rsp).await?;
|
||||
self.client.check_response(&request_id, rsp).await?;
|
||||
|
||||
self.table_cache.insert(options.name.clone(), ()).await;
|
||||
|
||||
Ok(Table::new(Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
@@ -134,25 +158,40 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
|
||||
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
|
||||
// We describe the table to confirm it exists before moving on.
|
||||
// TODO: a TTL cache of table existence
|
||||
let req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/describe/", options.name));
|
||||
let resp = self.client.send(req).await?;
|
||||
if resp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: options.name });
|
||||
if self.table_cache.get(&options.name).is_none() {
|
||||
let req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/describe/", options.name));
|
||||
let (request_id, resp) = self.client.send(req, true).await?;
|
||||
if resp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: options.name });
|
||||
}
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
}
|
||||
self.client.check_response(resp).await?;
|
||||
|
||||
Ok(Table::new(Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
options.name,
|
||||
))))
|
||||
}
|
||||
|
||||
async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> {
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/rename/", current_name));
|
||||
let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
|
||||
let (request_id, resp) = self.client.send(req, false).await?;
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
self.table_cache.remove(current_name).await;
|
||||
self.table_cache.insert(new_name.into(), ()).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_table(&self, name: &str) -> Result<()> {
|
||||
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
|
||||
let resp = self.client.send(req).await?;
|
||||
self.client.check_response(resp).await?;
|
||||
let (request_id, resp) = self.client.send(req, true).await?;
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
self.table_cache.remove(name).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -169,12 +208,56 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
use crate::{remote::db::ARROW_STREAM_CONTENT_TYPE, Connection};
|
||||
use crate::{
|
||||
remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
|
||||
Connection, Error,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retries() {
|
||||
// We'll record the request_id here, to check it matches the one in the error.
|
||||
let seen_request_id = Arc::new(OnceLock::new());
|
||||
let seen_request_id_ref = seen_request_id.clone();
|
||||
let conn = Connection::new_with_handler(move |request| {
|
||||
// Request id should be the same on each retry.
|
||||
let request_id = request.headers()["x-request-id"]
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let seen_id = seen_request_id_ref.get_or_init(|| request_id.clone());
|
||||
assert_eq!(&request_id, seen_id);
|
||||
|
||||
http::Response::builder()
|
||||
.status(500)
|
||||
.body("internal server error")
|
||||
.unwrap()
|
||||
});
|
||||
let result = conn.table_names().execute().await;
|
||||
if let Err(Error::Retry {
|
||||
request_id,
|
||||
request_failures,
|
||||
max_request_failures,
|
||||
source,
|
||||
..
|
||||
}) = result
|
||||
{
|
||||
let expected_id = seen_request_id.get().unwrap();
|
||||
assert_eq!(&request_id, expected_id);
|
||||
assert_eq!(request_failures, max_request_failures);
|
||||
assert!(
|
||||
source.to_string().contains("internal server error"),
|
||||
"source: {:?}",
|
||||
source
|
||||
);
|
||||
} else {
|
||||
panic!("unexpected result: {:?}", result);
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names() {
|
||||
@@ -334,4 +417,23 @@ mod tests {
|
||||
conn.drop_table("table1").await.unwrap();
|
||||
// NOTE: the API will return 200 even if the table does not exist. So we shouldn't expect 404.
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rename_table() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/table1/rename/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(body["new_table_name"], "table2");
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
conn.rename_table("table1", "table2").await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::table::dataset::DatasetReadGuard;
|
||||
use crate::index::Index;
|
||||
use crate::index::IndexStatistics;
|
||||
use crate::query::Select;
|
||||
use crate::table::AddDataMode;
|
||||
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
|
||||
use crate::Error;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::SchemaRef;
|
||||
use arrow_ipc::reader::StreamReader;
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use bytes::Buf;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
|
||||
use futures::TryStreamExt;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use http::StatusCode;
|
||||
use lance::arrow::json::JsonSchema;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::scanner::DatasetRecordBatchStream;
|
||||
use lance::dataset::{ColumnAlteration, NewColumnTransform};
|
||||
use lance_datafusion::exec::OneShotExec;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -25,8 +34,9 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::client::RequestResultExt;
|
||||
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RemoteTable<S: HttpSend = Sender> {
|
||||
@@ -41,16 +51,28 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
}
|
||||
|
||||
async fn describe(&self) -> Result<TableDescription> {
|
||||
let request = self.client.post(&format!("/table/{}/describe/", self.name));
|
||||
let response = self.client.send(request).await?;
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/describe/", self.name));
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
|
||||
let response = self.check_table_response(response).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
let body = response.text().await?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
message: format!("Failed to parse table description: {}", e),
|
||||
})
|
||||
match response.text().await {
|
||||
Ok(body) => serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
source: format!("Failed to parse table description: {}", e).into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
}),
|
||||
Err(err) => {
|
||||
let status_code = err.status();
|
||||
Err(Error::Http {
|
||||
source: Box::new(err),
|
||||
request_id,
|
||||
status_code,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reader_as_body(data: Box<dyn RecordBatchReader + Send>) -> Result<reqwest::Body> {
|
||||
@@ -76,14 +98,113 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
Ok(reqwest::Body::wrap_stream(body_stream))
|
||||
}
|
||||
|
||||
async fn check_table_response(&self, response: reqwest::Response) -> Result<reqwest::Response> {
|
||||
async fn check_table_response(
|
||||
&self,
|
||||
request_id: &str,
|
||||
response: reqwest::Response,
|
||||
) -> Result<reqwest::Response> {
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Err(Error::TableNotFound {
|
||||
name: self.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
self.client.check_response(response).await
|
||||
self.client.check_response(request_id, response).await
|
||||
}
|
||||
|
||||
async fn read_arrow_stream(
|
||||
&self,
|
||||
request_id: &str,
|
||||
body: reqwest::Response,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
// Assert that the content type is correct
|
||||
let content_type = body
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.ok_or_else(|| Error::Http {
|
||||
source: "Missing content type".into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
})?
|
||||
.to_str()
|
||||
.map_err(|e| Error::Http {
|
||||
source: format!("Failed to parse content type: {}", e).into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
})?;
|
||||
if content_type != ARROW_STREAM_CONTENT_TYPE {
|
||||
return Err(Error::Http {
|
||||
source: format!(
|
||||
"Expected content type {}, got {}",
|
||||
ARROW_STREAM_CONTENT_TYPE, content_type
|
||||
)
|
||||
.into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
});
|
||||
}
|
||||
|
||||
// There isn't a way to actually stream this data yet. I have an upstream issue:
|
||||
// https://github.com/apache/arrow-rs/issues/6420
|
||||
let body = body.bytes().await.err_to_http(request_id.into())?;
|
||||
let reader = StreamReader::try_new(body.reader(), None)?;
|
||||
let schema = reader.schema();
|
||||
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
|
||||
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
|
||||
}
|
||||
|
||||
fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
|
||||
if params.offset.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Offset is not yet supported in LanceDB Cloud".into(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(limit) = params.limit {
|
||||
body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
|
||||
}
|
||||
|
||||
if let Some(filter) = ¶ms.filter {
|
||||
body["filter"] = serde_json::Value::String(filter.clone());
|
||||
}
|
||||
|
||||
match ¶ms.select {
|
||||
Select::All => {}
|
||||
Select::Columns(columns) => {
|
||||
body["columns"] = serde_json::Value::Array(
|
||||
columns
|
||||
.iter()
|
||||
.map(|s| serde_json::Value::String(s.clone()))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
Select::Dynamic(pairs) => {
|
||||
body["columns"] = serde_json::Value::Array(
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(name, expr)| serde_json::json!([name, expr]))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if params.fast_search {
|
||||
body["fast_search"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
if let Some(full_text_search) = ¶ms.full_text_search {
|
||||
if full_text_search.wand_factor.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Wand factor is not yet supported in LanceDB Cloud".into(),
|
||||
});
|
||||
}
|
||||
body["full_text_query"] = serde_json::json!({
|
||||
"columns": full_text_search.columns,
|
||||
"query": full_text_search.query,
|
||||
})
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
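apply_query_params above translates the query builder state into the JSON body the /v1/table/{name}/query/ endpoint expects. A hedged sketch of the resulting shape, using serde_json only (field names taken from the tests later in this diff):

use serde_json::json;

fn main() {
    let mut body = json!({});
    // k is the limit, filter is an SQL-like predicate, columns selects projections.
    body["k"] = json!(42);
    body["filter"] = json!("a > 10");
    body["columns"] = json!(["a", "b"]);
    // Vector-search specific fields added by create_plan.
    body["vector"] = json!([0.1f32, 0.2, 0.3]);
    body["prefilter"] = json!(true);
    body["distance_type"] = json!("l2");
    body["nprobes"] = json!(20);

    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}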
|
||||
|
||||
@@ -153,7 +274,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/table/{}/count_rows/", self.name));
|
||||
.post(&format!("/v1/table/{}/count_rows/", self.name));
|
||||
|
||||
if let Some(filter) = filter {
|
||||
request = request.json(&serde_json::json!({ "filter": filter }));
|
||||
@@ -161,14 +282,16 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
request = request.json(&serde_json::json!({}));
|
||||
}
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
|
||||
let response = self.check_table_response(response).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
let body = response.text().await?;
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
message: format!("Failed to parse row count: {}", e),
|
||||
source: format!("Failed to parse row count: {}", e).into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
})
|
||||
}
|
||||
async fn add(
|
||||
@@ -179,7 +302,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
let body = Self::reader_as_body(data)?;
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/table/{}/insert/", self.name))
|
||||
.post(&format!("/v1/table/{}/insert/", self.name))
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(body);
|
||||
|
||||
@@ -190,47 +313,89 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
}
|
||||
}
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
self.check_table_response(response).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
async fn build_plan(
|
||||
&self,
|
||||
_ds_ref: &DatasetReadGuard,
|
||||
_query: &VectorQuery,
|
||||
_options: Option<QueryExecutionOptions>,
|
||||
) -> Result<Scanner> {
|
||||
Err(Error::NotSupported {
|
||||
message: "build_plan is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
&self,
|
||||
_query: &VectorQuery,
|
||||
query: &VectorQuery,
|
||||
_options: QueryExecutionOptions,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
Err(Error::NotSupported {
|
||||
message: "create_plan is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
|
||||
Err(Error::NotSupported {
|
||||
message: "explain_plan is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
|
||||
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, &query.base)?;
|
||||
|
||||
body["prefilter"] = query.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
|
||||
if let Some(vector) = query.query_vector.as_ref() {
|
||||
let vector: Vec<f32> = match vector.data_type() {
|
||||
DataType::Float32 => vector
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::Float32Array>()
|
||||
.unwrap()
|
||||
.values()
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect(),
|
||||
_ => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "VectorQuery vector must be of type Float32".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
body["vector"] = serde_json::json!(vector);
|
||||
}
|
||||
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
body["vector_column"] = serde_json::Value::String(vector_column.clone());
|
||||
}
|
||||
|
||||
if !query.use_index {
|
||||
body["bypass_vector_index"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
|
||||
let stream = self.read_arrow_stream(&request_id, response).await?;
|
||||
|
||||
Ok(Arc::new(OneShotExec::new(stream)))
|
||||
}
|
||||
|
||||
async fn plain_query(
|
||||
&self,
|
||||
_query: &Query,
|
||||
query: &Query,
|
||||
_options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
Err(Error::NotSupported {
|
||||
message: "plain_query is not yet supported on LanceDB cloud.".into(),
|
||||
})
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/query/", self.name))
|
||||
.header(CONTENT_TYPE, JSON_CONTENT_TYPE);
|
||||
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, query)?;
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
|
||||
let stream = self.read_arrow_stream(&request_id, response).await?;
|
||||
|
||||
Ok(DatasetRecordBatchStream::new(stream))
|
||||
}
|
||||
async fn update(&self, update: UpdateBuilder) -> Result<u64> {
|
||||
let request = self.client.post(&format!("/table/{}/update/", self.name));
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/update/", self.name));
|
||||
|
||||
let mut updates = Vec::new();
|
||||
for (column, expression) in update.columns {
|
||||
@@ -243,34 +408,107 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
"only_if": update.filter,
|
||||
}));
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
let response = self.check_table_response(response).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
let body = response.text().await?;
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
message: format!(
|
||||
source: format!(
|
||||
"Failed to parse updated rows result from response {}: {}",
|
||||
body, e
|
||||
),
|
||||
)
|
||||
.into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
})
|
||||
}
|
||||
async fn delete(&self, predicate: &str) -> Result<()> {
|
||||
let body = serde_json::json!({ "predicate": predicate });
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/table/{}/delete/", self.name))
|
||||
.post(&format!("/v1/table/{}/delete/", self.name))
|
||||
.json(&body);
|
||||
let response = self.client.send(request).await?;
|
||||
self.check_table_response(response).await?;
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
Ok(())
|
||||
}
|
||||
async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "create_index is not yet supported on LanceDB cloud.".into(),
|
||||
})
|
||||
|
||||
async fn create_index(&self, mut index: IndexBuilder) -> Result<()> {
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/create_index/", self.name));
|
||||
|
||||
let column = match index.columns.len() {
|
||||
0 => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "No columns specified".into(),
|
||||
})
|
||||
}
|
||||
1 => index.columns.pop().unwrap(),
|
||||
_ => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Indices over multiple columns not yet supported".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
let mut body = serde_json::json!({
|
||||
"column": column
|
||||
});
|
||||
|
||||
let (index_type, distance_type) = match index.index {
|
||||
// TODO: Should we pass the actual index parameters? SaaS does not
|
||||
// yet support them.
|
||||
Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
|
||||
Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
|
||||
Index::BTree(_) => ("BTREE", None),
|
||||
Index::Bitmap(_) => ("BITMAP", None),
|
||||
Index::LabelList(_) => ("LABEL_LIST", None),
|
||||
Index::FTS(_) => ("FTS", None),
|
||||
Index::Auto => {
|
||||
let schema = self.schema().await?;
|
||||
let field = schema
|
||||
.field_with_name(&column)
|
||||
.map_err(|_| Error::InvalidInput {
|
||||
message: format!("Column {} not found in schema", column),
|
||||
})?;
|
||||
if supported_vector_data_type(field.data_type()) {
|
||||
("IVF_PQ", None)
|
||||
} else if supported_btree_data_type(field.data_type()) {
|
||||
("BTREE", None)
|
||||
} else {
|
||||
return Err(Error::NotSupported {
|
||||
message: format!(
|
||||
"there are no indices supported for the field `{}` with the data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Index type not supported".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
body["index_type"] = serde_json::Value::String(index_type.into());
|
||||
if let Some(distance_type) = distance_type {
|
||||
// Phalanx expects this to be lowercase right now.
|
||||
body["metric_type"] =
|
||||
serde_json::Value::String(distance_type.to_string().to_lowercase());
|
||||
}
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
Ok(())
|
||||
}
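The create_index body above boils down to one (index_type, optional metric_type) pair per index variant. A small self-contained sketch of that mapping (an illustrative enum, not the crate's Index type):

use serde_json::json;

enum IndexKind {
    IvfPq { distance_type: String },
    BTree,
    Fts,
}

fn index_body(column: &str, kind: &IndexKind) -> serde_json::Value {
    let mut body = json!({ "column": column });
    let (index_type, metric) = match kind {
        IndexKind::IvfPq { distance_type } => ("IVF_PQ", Some(distance_type.as_str())),
        IndexKind::BTree => ("BTREE", None),
        IndexKind::Fts => ("FTS", None),
    };
    body["index_type"] = json!(index_type);
    if let Some(metric) = metric {
        // The server currently expects the metric name lowercased.
        body["metric_type"] = json!(metric.to_lowercase());
    }
    body
}

fn main() {
    let body = index_body("vector", &IndexKind::IvfPq { distance_type: "Cosine".into() });
    assert_eq!(body["index_type"], "IVF_PQ");
    assert_eq!(body["metric_type"], "cosine");
}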
|
||||
|
||||
async fn merge_insert(
|
||||
&self,
|
||||
params: MergeInsertBuilder,
|
||||
@@ -280,14 +518,14 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
let body = Self::reader_as_body(new_data)?;
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/table/{}/merge_insert/", self.name))
|
||||
.post(&format!("/v1/table/{}/merge_insert/", self.name))
|
||||
.query(&query)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(body);
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
self.check_table_response(response).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -315,16 +553,91 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
message: "drop_columns is not yet supported.".into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
|
||||
Err(Error::NotSupported {
|
||||
message: "list_indices is not yet supported.".into(),
|
||||
})
|
||||
// Make request to list the indices
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/index/list/", self.name));
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ListIndicesResponse {
|
||||
indexes: Vec<IndexConfigResponse>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct IndexConfigResponse {
|
||||
index_name: String,
|
||||
columns: Vec<String>,
|
||||
}
|
||||
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
let body: ListIndicesResponse = serde_json::from_str(&body).map_err(|err| Error::Http {
|
||||
source: format!(
|
||||
"Failed to parse list_indices response: {}, body: {}",
|
||||
err, body
|
||||
)
|
||||
.into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
})?;
|
||||
|
||||
// Make request to get stats for each index, so we get the index type.
|
||||
// This is a bit inefficient, but it's the only way to get the index type.
|
||||
let mut futures = Vec::with_capacity(body.indexes.len());
|
||||
for index in body.indexes {
|
||||
let future = async move {
|
||||
match self.index_stats(&index.index_name).await {
|
||||
Ok(Some(stats)) => Ok(Some(IndexConfig {
|
||||
name: index.index_name,
|
||||
index_type: stats.index_type,
|
||||
columns: index.columns,
|
||||
})),
|
||||
Ok(None) => Ok(None), // The index must have been deleted since we listed it.
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
};
|
||||
futures.push(future);
|
||||
}
|
||||
let results = futures::future::try_join_all(futures).await?;
|
||||
let index_configs = results.into_iter().flatten().collect();
|
||||
|
||||
Ok(index_configs)
|
||||
}
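list_indices above parses a JSON body shaped like the fixture in the tests below. A hedged standalone sketch of that deserialization with serde (unknown fields such as index_uuid and index_status are ignored by default):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct ListIndicesResponse {
    indexes: Vec<IndexEntry>,
}

#[derive(Deserialize, Debug)]
struct IndexEntry {
    index_name: String,
    columns: Vec<String>,
}

fn main() {
    let body = r#"{
        "indexes": [
            { "index_name": "vector_idx", "index_uuid": "3fa85f64", "columns": ["vector"], "index_status": "done" }
        ]
    }"#;
    let parsed: ListIndicesResponse = serde_json::from_str(body).unwrap();
    assert_eq!(parsed.indexes[0].index_name, "vector_idx");
    assert_eq!(parsed.indexes[0].columns, vec!["vector"]);
}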
|
||||
|
||||
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
|
||||
let request = self.client.post(&format!(
|
||||
"/v1/table/{}/index/{}/stats/",
|
||||
self.name, index_name
|
||||
));
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
let stats = serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
source: format!("Failed to parse index statistics: {}", e).into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
})?;
|
||||
|
||||
Ok(Some(stats))
|
||||
}
|
||||
async fn table_definition(&self) -> Result<TableDefinition> {
|
||||
Err(Error::NotSupported {
|
||||
message: "table_definition is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
fn dataset_uri(&self) -> &str {
|
||||
"NOT_SUPPORTED"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -372,9 +685,14 @@ mod tests {
|
||||
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use reqwest::Body;
|
||||
|
||||
use crate::{Error, Table};
|
||||
use crate::{
|
||||
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
DistanceType, Error, Table,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_not_found() {
|
||||
@@ -419,7 +737,7 @@ mod tests {
|
||||
async fn test_version() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/describe/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
@@ -435,7 +753,7 @@ mod tests {
|
||||
async fn test_schema() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/describe/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
@@ -464,7 +782,11 @@ mod tests {
|
||||
async fn test_count_rows() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/count_rows/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/count_rows/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);
|
||||
|
||||
http::Response::builder().status(200).body("42").unwrap()
|
||||
@@ -475,7 +797,11 @@ mod tests {
|
||||
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/count_rows/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/count_rows/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
assert_eq!(
|
||||
request.body().unwrap().as_bytes().unwrap(),
|
||||
br#"{"filter":"a > 10"}"#
|
||||
@@ -524,7 +850,7 @@ mod tests {
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let table = Table::new_with_handler("my_table", move |mut request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/insert/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/insert/");
|
||||
// If mode is specified, it should be "append". Append is default
|
||||
// so it's not required.
|
||||
assert!(request
|
||||
@@ -568,7 +894,7 @@ mod tests {
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let table = Table::new_with_handler("my_table", move |mut request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/insert/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/insert/");
|
||||
assert_eq!(
|
||||
request
|
||||
.url()
|
||||
@@ -609,7 +935,11 @@ mod tests {
|
||||
async fn test_update() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/update/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/update/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
if let Some(body) = request.body().unwrap().as_bytes() {
|
||||
let body = std::str::from_utf8(body).unwrap();
|
||||
@@ -653,7 +983,7 @@ mod tests {
|
||||
// Default parameters
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/");
|
||||
|
||||
let params = request.url().query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(params["on"], "some_col");
|
||||
@@ -676,7 +1006,7 @@ mod tests {
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let table = Table::new_with_handler("my_table", move |mut request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
ARROW_STREAM_CONTENT_TYPE
|
||||
@@ -716,7 +1046,11 @@ mod tests {
|
||||
async fn test_delete() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/delete/");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/delete/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
@@ -728,4 +1062,309 @@ mod tests {
|
||||
|
||||
table.delete("id in (1, 2, 3)").await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_vector_default_values() {
|
||||
let expected_data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let expected_data_ref = expected_data.clone();
|
||||
|
||||
let table = Table::new_with_handler("my_table", move |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
let mut expected_body = serde_json::json!({
|
||||
"prefilter": true,
|
||||
"distance_type": "l2",
|
||||
"nprobes": 20,
|
||||
"refine_factor": null,
|
||||
});
|
||||
// Pass vector separately to make sure it matches f32 precision.
|
||||
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
let response_body = write_ipc_stream(&expected_data_ref);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let data = table
|
||||
.query()
|
||||
.nearest_to(vec![0.1, 0.2, 0.3])
|
||||
.unwrap()
|
||||
.execute()
|
||||
.await;
|
||||
let data = data.unwrap().collect::<Vec<_>>().await;
|
||||
assert_eq!(data.len(), 1);
|
||||
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_vector_all_params() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
let mut expected_body = serde_json::json!({
|
||||
"vector_column": "my_vector",
|
||||
"prefilter": false,
|
||||
"k": 42,
|
||||
"distance_type": "cosine",
|
||||
"bypass_vector_index": true,
|
||||
"columns": ["a", "b"],
|
||||
"nprobes": 12,
|
||||
"refine_factor": 2,
|
||||
});
|
||||
// Pass vector separately to make sure it matches f32 precision.
|
||||
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
let data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_stream(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let _ = table
|
||||
.query()
|
||||
.limit(42)
|
||||
.select(Select::columns(&["a", "b"]))
|
||||
.nearest_to(vec![0.1, 0.2, 0.3])
|
||||
.unwrap()
|
||||
.column("my_vector")
|
||||
.postfilter()
|
||||
.distance_type(crate::DistanceType::Cosine)
|
||||
.nprobes(12)
|
||||
.refine_factor(2)
|
||||
.bypass_vector_index()
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_fts() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
let expected_body = serde_json::json!({
|
||||
"full_text_query": {
|
||||
"columns": ["a", "b"],
|
||||
"query": "hello world",
|
||||
},
|
||||
"k": 10,
|
||||
});
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
let data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_stream(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let _ = table
|
||||
.query()
|
||||
.full_text_search(
|
||||
FullTextSearchQuery::new("hello world".into())
|
||||
.columns(Some(vec!["a".into(), "b".into()])),
|
||||
)
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_index() {
|
||||
let cases = [
|
||||
("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
|
||||
(
|
||||
"IVF_PQ",
|
||||
Some("cosine"),
|
||||
Index::IvfPq(IvfPqIndexBuilder::default().distance_type(DistanceType::Cosine)),
|
||||
),
|
||||
(
|
||||
"IVF_HNSW_SQ",
|
||||
Some("l2"),
|
||||
Index::IvfHnswSq(Default::default()),
|
||||
),
|
||||
// HNSW_PQ isn't yet supported on SaaS
|
||||
("BTREE", None, Index::BTree(Default::default())),
|
||||
("BITMAP", None, Index::Bitmap(Default::default())),
|
||||
("LABEL_LIST", None, Index::LabelList(Default::default())),
|
||||
("FTS", None, Index::FTS(Default::default())),
|
||||
];
|
||||
|
||||
for (index_type, distance_type, index) in cases {
|
||||
let table = Table::new_with_handler("my_table", move |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/create_index/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
let mut expected_body = serde_json::json!({
|
||||
"column": "a",
|
||||
"index_type": index_type,
|
||||
});
|
||||
if let Some(distance_type) = distance_type {
|
||||
expected_body["metric_type"] = distance_type.to_lowercase().into();
|
||||
}
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
|
||||
table.create_index(&["a"], index).execute().await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_indices() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
|
||||
let response_body = match request.url().path() {
|
||||
"/v1/table/my_table/index/list/" => {
|
||||
serde_json::json!({
|
||||
"indexes": [
|
||||
{
|
||||
"index_name": "vector_idx",
|
||||
"index_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
|
||||
"columns": ["vector"],
|
||||
"index_status": "done",
|
||||
},
|
||||
{
|
||||
"index_name": "my_idx",
|
||||
"index_uuid": "34255f64-5717-4562-b3fc-2c963f66afa6",
|
||||
"columns": ["my_column"],
|
||||
"index_status": "done",
|
||||
},
|
||||
]
|
||||
})
|
||||
}
|
||||
"/v1/table/my_table/index/vector_idx/stats/" => {
|
||||
serde_json::json!({
|
||||
"num_indexed_rows": 100000,
|
||||
"num_unindexed_rows": 0,
|
||||
"index_type": "IVF_PQ",
|
||||
"distance_type": "l2"
|
||||
})
|
||||
}
|
||||
"/v1/table/my_table/index/my_idx/stats/" => {
|
||||
serde_json::json!({
|
||||
"num_indexed_rows": 100000,
|
||||
"num_unindexed_rows": 0,
|
||||
"index_type": "LABEL_LIST"
|
||||
})
|
||||
}
|
||||
path => panic!("Unexpected path: {}", path),
|
||||
};
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(serde_json::to_string(&response_body).unwrap())
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let indices = table.list_indices().await.unwrap();
|
||||
let expected = vec![
|
||||
IndexConfig {
|
||||
name: "vector_idx".into(),
|
||||
index_type: IndexType::IvfPq,
|
||||
columns: vec!["vector".into()],
|
||||
},
|
||||
IndexConfig {
|
||||
name: "my_idx".into(),
|
||||
index_type: IndexType::LabelList,
|
||||
columns: vec!["my_column".into()],
|
||||
},
|
||||
];
|
||||
assert_eq!(indices, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_index_stats() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/index/my_index/stats/"
|
||||
);
|
||||
|
||||
let response_body = serde_json::json!({
|
||||
"num_indexed_rows": 100000,
|
||||
"num_unindexed_rows": 0,
|
||||
"index_type": "IVF_PQ",
|
||||
"distance_type": "l2"
|
||||
});
|
||||
let response_body = serde_json::to_string(&response_body).unwrap();
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
let indices = table.index_stats("my_index").await.unwrap().unwrap();
|
||||
let expected = IndexStatistics {
|
||||
num_indexed_rows: 100000,
|
||||
num_unindexed_rows: 0,
|
||||
index_type: IndexType::IvfPq,
|
||||
distance_type: Some(DistanceType::L2),
|
||||
num_indices: None,
|
||||
};
|
||||
assert_eq!(indices, expected);
|
||||
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/my_table/index/my_index/stats/"
|
||||
);
|
||||
|
||||
http::Response::builder().status(404).body("").unwrap()
|
||||
});
|
||||
let indices = table.index_stats("my_index").await.unwrap();
|
||||
assert!(indices.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>>
    let buf = Vec::with_capacity(WRITE_BUF_SIZE);
    let mut buf = Cursor::new(buf);
    {
        let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batches.schema())?;
        let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut buf, &batches.schema())?;

        for batch in batches {
            let batch = batch?;

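The switch above from FileWriter to StreamWriter matters because the HTTP endpoints exchange the Arrow IPC stream format. A hedged round-trip sketch with arrow-rs (arrow_array, arrow_schema and arrow_ipc assumed as dependencies):

use std::io::Cursor;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::reader::StreamReader;
use arrow_ipc::writer::StreamWriter;
use arrow_schema::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Encode to the IPC stream format, the content type the remote table endpoints use.
    let mut buf = Cursor::new(Vec::new());
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // Decode it back; the reader yields Result<RecordBatch, _> items.
    buf.set_position(0);
    let reader = StreamReader::try_new(buf, None)?;
    let batches: Vec<_> = reader.collect::<Result<_, _>>()?;
    assert_eq!(batches[0], batch);
    Ok(())
}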
@@ -21,9 +21,11 @@ use std::sync::Arc;
|
||||
use arrow::array::AsArray;
|
||||
use arrow::datatypes::Float32Type;
|
||||
use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use arrow_schema::{Field, Schema, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::display::DisplayableExecutionPlan;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::builder::DatasetBuilder;
|
||||
use lance::dataset::cleanup::RemovalStats;
|
||||
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
|
||||
@@ -46,7 +48,6 @@ use lance_index::IndexType;
|
||||
use lance_table::io::commit::ManifestNamingScheme;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::whatever;
|
||||
|
||||
use crate::arrow::IntoArrow;
|
||||
use crate::connection::NoData;
|
||||
@@ -54,20 +55,25 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::scalar::FtsIndexBuilder;
|
||||
use crate::index::vector::{
|
||||
IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
|
||||
suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder,
|
||||
IvfPqIndexBuilder, VectorIndex,
|
||||
};
|
||||
use crate::index::IndexConfig;
|
||||
use crate::index::IndexStatistics;
|
||||
use crate::index::{
|
||||
vector::{suggested_num_partitions, suggested_num_sub_vectors},
|
||||
Index, IndexBuilder,
|
||||
};
|
||||
use crate::index::{IndexConfig, IndexStatisticsImpl};
|
||||
use crate::query::{
|
||||
IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
|
||||
};
|
||||
use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
|
||||
use crate::utils::{
|
||||
default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
|
||||
supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
|
||||
PatchReadParam, PatchWriteParam,
|
||||
};
|
||||
|
||||
use self::dataset::{DatasetConsistencyWrapper, DatasetReadGuard};
|
||||
use self::dataset::DatasetConsistencyWrapper;
|
||||
use self::merge::MergeInsertBuilder;
|
||||
|
||||
pub(crate) mod dataset;
|
||||
@@ -374,12 +380,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
async fn schema(&self) -> Result<SchemaRef>;
|
||||
/// Count the number of rows in this table.
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
|
||||
async fn build_plan(
|
||||
&self,
|
||||
ds_ref: &DatasetReadGuard,
|
||||
query: &VectorQuery,
|
||||
options: Option<QueryExecutionOptions>,
|
||||
) -> Result<Scanner>;
|
||||
async fn create_plan(
|
||||
&self,
|
||||
query: &VectorQuery,
|
||||
@@ -390,7 +390,12 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
query: &Query,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<DatasetRecordBatchStream>;
|
||||
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String>;
|
||||
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
|
||||
let plan = self.create_plan(query, Default::default()).await?;
|
||||
let display = DisplayableExecutionPlan::new(plan.as_ref());
|
||||
|
||||
Ok(format!("{}", display.indent(verbose)))
|
||||
}
|
||||
async fn add(
|
||||
&self,
|
||||
add: AddDataBuilder<NoData>,
|
||||
@@ -400,6 +405,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
async fn update(&self, update: UpdateBuilder) -> Result<u64>;
|
||||
async fn create_index(&self, index: IndexBuilder) -> Result<()>;
|
||||
async fn list_indices(&self) -> Result<Vec<IndexConfig>>;
|
||||
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>>;
|
||||
async fn merge_insert(
|
||||
&self,
|
||||
params: MergeInsertBuilder,
|
||||
@@ -418,6 +424,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
|
||||
async fn checkout_latest(&self) -> Result<()>;
|
||||
async fn restore(&self) -> Result<()>;
|
||||
async fn table_definition(&self) -> Result<TableDefinition>;
|
||||
fn dataset_uri(&self) -> &str;
|
||||
}
|
||||
|
||||
/// A Table is a collection of strong typed Rows.
|
||||
@@ -949,6 +956,22 @@ impl Table {
|
||||
pub async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
|
||||
self.inner.list_indices().await
|
||||
}
|
||||
|
||||
/// Get the underlying dataset URI
|
||||
///
|
||||
/// Warning: This is an internal API and the return value is subject to change.
|
||||
pub fn dataset_uri(&self) -> &str {
|
||||
self.inner.dataset_uri()
|
||||
}
|
||||
|
||||
/// Get statistics about an index.
|
||||
/// Returns None if the index does not exist.
|
||||
pub async fn index_stats(
|
||||
&self,
|
||||
index_name: impl AsRef<str>,
|
||||
) -> Result<Option<IndexStatistics>> {
|
||||
self.inner.index_stats(index_name.as_ref()).await
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NativeTable> for Table {
|
||||
@@ -1079,46 +1102,6 @@ impl NativeTable {
Ok(name.to_string())
}

fn supported_btree_data_type(dtype: &DataType) -> bool {
dtype.is_integer()
|| dtype.is_floating()
|| matches!(
dtype,
DataType::Boolean
| DataType::Utf8
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Date32
| DataType::Date64
| DataType::Timestamp(_, _)
)
}

fn supported_bitmap_data_type(dtype: &DataType) -> bool {
dtype.is_integer() || matches!(dtype, DataType::Utf8)
}

fn supported_label_list_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::List(field) => Self::supported_bitmap_data_type(field.data_type()),
DataType::FixedSizeList(field, _) => {
Self::supported_bitmap_data_type(field.data_type())
}
_ => false,
}
}

fn supported_fts_data_type(dtype: &DataType) -> bool {
matches!(dtype, DataType::Utf8 | DataType::LargeUtf8)
}

fn supported_vector_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
_ => false,
}
}

/// Creates a new Table
///
/// # Arguments
@@ -1277,91 +1260,6 @@ impl NativeTable {
.await)
}

#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.num_indexed_rows)),
None => Ok(None),
}
}

#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn count_unindexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.num_unindexed_rows)),
None => Ok(None),
}
}

#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.index_type.unwrap_or_default())),
None => Ok(None),
}
}

#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn get_distance_type(&self, index_uuid: &str) -> Result<Option<String>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(
stats
.indices
.iter()
.filter_map(|i| i.metric_type.clone())
.collect(),
)),
None => Ok(None),
}
}

#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn load_index_stats(&self, index_uuid: &str) -> Result<Option<IndexStatistics>> {
let index = self
.load_indices()
.await?
.into_iter()
.find(|i| i.index_uuid == index_uuid);
if index.is_none() {
return Ok(None);
}
let dataset = self.dataset.get().await?;
let index_stats = dataset.index_statistics(&index.unwrap().index_name).await?;
let index_stats: IndexStatistics = whatever!(
serde_json::from_str(&index_stats),
"error deserializing index statistics {index_stats}",
);

Ok(Some(index_stats))
}

/// Get statistics about an index.
/// Returns an error if the index does not exist.
pub async fn index_stats(
&self,
index_name: impl AsRef<str>,
) -> Result<Option<IndexStatistics>> {
let stats = match self
.dataset
.get()
.await?
.index_statistics(index_name.as_ref())
.await
{
Ok(stats) => stats,
Err(lance::error::Error::IndexNotFound { .. }) => return Ok(None),
Err(e) => return Err(Error::from(e)),
};

serde_json::from_str(&stats).map_err(|e| Error::InvalidInput {
message: format!("error deserializing index statistics: {}", e),
})
}

pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
let dataset = self.dataset.get().await?;
let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
@@ -1377,7 +1275,7 @@ impl NativeTable {
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF PQ index cannot be created on the column `{}` which has data type {}",
@@ -1430,7 +1328,7 @@ impl NativeTable {
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF HNSW PQ index cannot be created on the column `{}` which has data type {}",
@@ -1440,11 +1338,19 @@ impl NativeTable {
});
}

let num_partitions = if let Some(n) = index.num_partitions {
let num_partitions: u32 = if let Some(n) = index.num_partitions {
n
} else {
suggested_num_partitions(self.count_rows(None).await?)
match field.data_type() {
arrow_schema::DataType::FixedSizeList(_, n) => Ok::<u32, Error>(
suggested_num_partitions_for_hnsw(self.count_rows(None).await?, *n as u32),
),
_ => Err(Error::Schema {
message: format!("Column '{}' is not a FixedSizeList", field.name()),
}),
}?
};

let num_sub_vectors: u32 = if let Some(n) = index.num_sub_vectors {
n
} else {
@@ -1493,7 +1399,7 @@ impl NativeTable {
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",
@@ -1503,10 +1409,17 @@ impl NativeTable {
});
}

let num_partitions = if let Some(n) = index.num_partitions {
let num_partitions: u32 = if let Some(n) = index.num_partitions {
n
} else {
suggested_num_partitions(self.count_rows(None).await?)
match field.data_type() {
arrow_schema::DataType::FixedSizeList(_, n) => Ok::<u32, Error>(
suggested_num_partitions_for_hnsw(self.count_rows(None).await?, *n as u32),
),
_ => Err(Error::Schema {
message: format!("Column '{}' is not a FixedSizeList", field.name()),
}),
}?
};

let mut dataset = self.dataset.get_mut().await?;
@@ -1539,10 +1452,10 @@ impl NativeTable {
}

async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if Self::supported_vector_data_type(field.data_type()) {
if supported_vector_data_type(field.data_type()) {
self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
.await
} else if Self::supported_btree_data_type(field.data_type()) {
} else if supported_btree_data_type(field.data_type()) {
self.create_btree_index(field, opts).await
} else {
Err(Error::InvalidInput {
@@ -1556,7 +1469,7 @@ impl NativeTable {
}

async fn create_btree_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if !Self::supported_btree_data_type(field.data_type()) {
if !supported_btree_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A BTree index cannot be created on the field `{}` which has data type {}",
@@ -1583,7 +1496,7 @@ impl NativeTable {
}

async fn create_bitmap_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if !Self::supported_bitmap_data_type(field.data_type()) {
if !supported_bitmap_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A Bitmap index cannot be created on the field `{}` which has data type {}",
@@ -1610,7 +1523,7 @@ impl NativeTable {
}

async fn create_label_list_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if !Self::supported_label_list_data_type(field.data_type()) {
if !supported_label_list_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A LabelList index cannot be created on the field `{}` which has data type {}",
@@ -1642,7 +1555,7 @@ impl NativeTable {
fts_opts: FtsIndexBuilder,
replace: bool,
) -> Result<()> {
if !Self::supported_fts_data_type(field.data_type()) {
if !supported_fts_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A FTS index cannot be created on the field `{}` which has data type {}",
@@ -1863,12 +1776,13 @@ impl TableInternal for NativeTable {
Ok(res.rows_updated)
}

async fn build_plan(
async fn create_plan(
&self,
ds_ref: &DatasetReadGuard,
query: &VectorQuery,
options: Option<QueryExecutionOptions>,
) -> Result<Scanner> {
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;

let mut scanner: Scanner = ds_ref.scan();

if let Some(query_vector) = query.query_vector.as_ref() {
@@ -1938,25 +1852,16 @@ impl TableInternal for NativeTable {
Select::All => {}
}

if let Some(opts) = options {
scanner.batch_size(opts.max_batch_length as usize);
if query.base.with_row_id {
scanner.with_row_id();
}

scanner.batch_size(options.max_batch_length as usize);

if query.base.fast_search {
scanner.fast_search();
}

Ok(scanner)
}

async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;

let mut scanner = self.build_plan(&ds_ref, query, Some(options)).await?;

match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
@@ -1995,16 +1900,6 @@ impl TableInternal for NativeTable {
.await
}

async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
let ds_ref = self.dataset.get().await?;

let scanner = self.build_plan(&ds_ref, query, None).await?;

let plan = scanner.explain_plan(verbose).await?;

Ok(plan)
}

async fn merge_insert(
&self,
params: MergeInsertBuilder,
@@ -2129,28 +2024,70 @@ impl TableInternal for NativeTable {
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
let dataset = self.dataset.get().await?;
let indices = dataset.load_indices().await?;
indices.iter().map(|idx| {
let mut is_vector = false;
futures::stream::iter(indices.as_slice()).then(|idx| async {
let stats = dataset.index_statistics(idx.name.as_str()).await?;
let stats: serde_json::Value = serde_json::from_str(&stats).map_err(|e| Error::Runtime {
message: format!("error deserializing index statistics: {}", e),
})?;
let index_type = stats.get("index_type").and_then(|v| v.as_str())
.ok_or_else(|| Error::Runtime {
message: "index statistics was missing index type".to_string(),
})?;
let index_type: crate::index::IndexType = index_type.parse().map_err(|e| Error::Runtime {
message: format!("error parsing index type: {}", e),
})?;

let mut columns = Vec::with_capacity(idx.fields.len());
for field_id in &idx.fields {
let field = dataset.schema().field_by_id(*field_id).ok_or_else(|| Error::Runtime { message: format!("The index with name {} and uuid {} referenced a field with id {} which does not exist in the schema", idx.name, idx.uuid, field_id) })?;
if field.data_type().is_nested() {
// Temporary hack to determine if an index is scalar or vector
// Should be removed in https://github.com/lancedb/lance/issues/2039
is_vector = true;
}
columns.push(field.name.clone());
}

let index_type = if is_vector {
crate::index::IndexType::IvfPq
} else {
crate::index::IndexType::BTree
};

let name = idx.name.clone();
Ok(IndexConfig { index_type, columns, name })
}).collect::<Result<Vec<_>>>()
}).try_collect::<Vec<_>>().await
}

fn dataset_uri(&self) -> &str {
self.uri.as_str()
}

async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
let stats = match self
.dataset
.get()
.await?
.index_statistics(index_name.as_ref())
.await
{
Ok(stats) => stats,
Err(lance::error::Error::IndexNotFound { .. }) => return Ok(None),
Err(e) => return Err(Error::from(e)),
};

let mut stats: IndexStatisticsImpl =
serde_json::from_str(&stats).map_err(|e| Error::InvalidInput {
message: format!("error deserializing index statistics: {}", e),
})?;

let first_index = stats.indices.pop().ok_or_else(|| Error::InvalidInput {
message: "index statistics is empty".to_string(),
})?;
// Index type should be present at one of the levels.
let index_type =
stats
.index_type
.or(first_index.index_type)
.ok_or_else(|| Error::InvalidInput {
message: "index statistics was missing index type".to_string(),
})?;
Ok(Some(IndexStatistics {
num_indexed_rows: stats.num_indexed_rows,
num_unindexed_rows: stats.num_unindexed_rows,
index_type,
distance_type: first_index.metric_type,
num_indices: stats.num_indices,
}))
}
}

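An illustrative sketch (not from the diff) of the lookup the reworked `list_indices` now performs: the index type is read from the per-index statistics JSON rather than inferred from the column's data type. The `stats_json` input and the plain-`String` return type are simplifications for illustration:

use serde_json::Value;

fn index_type_from_stats(stats_json: &str) -> Result<String, String> {
    let stats: Value = serde_json::from_str(stats_json)
        .map_err(|e| format!("error deserializing index statistics: {}", e))?;
    stats
        .get("index_type")
        .and_then(|v| v.as_str())
        .map(str::to_string)
        .ok_or_else(|| "index statistics was missing index type".to_string())
}

// For example, r#"{"index_type": "IVF_PQ", "num_indexed_rows": 512}"# yields "IVF_PQ".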
@@ -2789,24 +2726,7 @@ mod tests {

let table = conn.create_table("test", batches).execute().await.unwrap();

assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows("my_index")
.await
.unwrap(),
None
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows("my_index")
.await
.unwrap(),
None
);
assert_eq!(table.index_stats("my_index").await.unwrap(), None);

table
.create_index(&["embeddings"], Index::Auto)
@@ -2823,43 +2743,12 @@ mod tests {
assert_eq!(table.name(), "test");

let indices = table.as_native().unwrap().load_indices().await.unwrap();
let index_uuid = &indices[0].index_uuid;
assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows(index_uuid)
.await
.unwrap(),
Some(512)
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows(index_uuid)
.await
.unwrap(),
Some(0)
);
assert_eq!(
table
.as_native()
.unwrap()
.get_index_type(index_uuid)
.await
.unwrap(),
Some("IVF_PQ".to_string())
);
assert_eq!(
table
.as_native()
.unwrap()
.get_distance_type(index_uuid)
.await
.unwrap(),
Some(crate::DistanceType::L2.to_string())
);
let index_name = &indices[0].index_name;
let stats = table.index_stats(index_name).await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, 512);
assert_eq!(stats.num_unindexed_rows, 0);
assert_eq!(stats.index_type, crate::index::IndexType::IvfPq);
assert_eq!(stats.distance_type, Some(crate::DistanceType::L2));
}

#[tokio::test]
@@ -2902,24 +2791,8 @@ mod tests {

let table = conn.create_table("test", batches).execute().await.unwrap();

assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows("my_index")
.await
.unwrap(),
None
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows("my_index")
.await
.unwrap(),
None
);
let stats = table.index_stats("my_index").await.unwrap();
assert!(stats.is_none());

let index = IvfHnswSqIndexBuilder::default();
table
@@ -2931,31 +2804,16 @@ mod tests {
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
assert_eq!(index.index_type, crate::index::IndexType::IvfPq);
assert_eq!(index.index_type, crate::index::IndexType::IvfHnswSq);
assert_eq!(index.columns, vec!["embeddings".to_string()]);
assert_eq!(table.count_rows(None).await.unwrap(), 512);
assert_eq!(table.name(), "test");

let indices = table.as_native().unwrap().load_indices().await.unwrap();
let index_uuid = &indices[0].index_uuid;
assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows(index_uuid)
.await
.unwrap(),
Some(512)
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows(index_uuid)
.await
.unwrap(),
Some(0)
);
let index_name = &indices[0].index_name;
let stats = table.index_stats(index_name).await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, 512);
assert_eq!(stats.num_unindexed_rows, 0);
}

#[tokio::test]
@@ -2997,25 +2855,8 @@ mod tests {
);

let table = conn.create_table("test", batches).execute().await.unwrap();

assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows("my_index")
.await
.unwrap(),
None
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows("my_index")
.await
.unwrap(),
None
);
let stats = table.index_stats("my_index").await.unwrap();
assert!(stats.is_none());

let index = IvfHnswPqIndexBuilder::default();
table
@@ -3027,31 +2868,16 @@ mod tests {
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
assert_eq!(index.index_type, crate::index::IndexType::IvfPq);
assert_eq!(index.index_type, crate::index::IndexType::IvfHnswPq);
assert_eq!(index.columns, vec!["embeddings".to_string()]);
assert_eq!(table.count_rows(None).await.unwrap(), 512);
assert_eq!(table.name(), "test");

let indices = table.as_native().unwrap().load_indices().await.unwrap();
let index_uuid = &indices[0].index_uuid;
assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows(index_uuid)
.await
.unwrap(),
Some(512)
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows(index_uuid)
.await
.unwrap(),
Some(0)
);
let indices: Vec<VectorIndex> = table.as_native().unwrap().load_indices().await.unwrap();
let index_name = &indices[0].index_name;
let stats = table.index_stats(index_name).await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, 512);
assert_eq!(stats.num_unindexed_rows, 0);
}

fn create_fixed_size_list<T: Array>(values: T, list_size: i32) -> Result<FixedSizeListArray> {
@@ -3127,25 +2953,10 @@ mod tests {
assert_eq!(index.columns, vec!["i".to_string()]);

let indices = table.as_native().unwrap().load_indices().await.unwrap();
let index_uuid = &indices[0].index_uuid;
assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows(index_uuid)
.await
.unwrap(),
Some(1)
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows(index_uuid)
.await
.unwrap(),
Some(0)
);
let index_name = &indices[0].index_name;
let stats = table.index_stats(index_name).await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, 1);
assert_eq!(stats.num_unindexed_rows, 0);
}

#[tokio::test]
@@ -3230,7 +3041,7 @@ mod tests {
let values_builder = StringBuilder::new();
let mut builder = ListBuilder::new(values_builder);
for i in 0..120 {
builder.values().append_value(TAGS[i % 3].to_string());
builder.values().append_value(TAGS[i % 3]);
if i % 3 == 0 {
builder.append(true)
}

@@ -14,7 +14,7 @@

use std::sync::Arc;

use arrow_schema::Schema;
use arrow_schema::{DataType, Schema};
use lance::dataset::{ReadParams, WriteParams};
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;
@@ -137,6 +137,44 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
}
}

pub fn supported_btree_data_type(dtype: &DataType) -> bool {
dtype.is_integer()
|| dtype.is_floating()
|| matches!(
dtype,
DataType::Boolean
| DataType::Utf8
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Date32
| DataType::Date64
| DataType::Timestamp(_, _)
)
}

pub fn supported_bitmap_data_type(dtype: &DataType) -> bool {
dtype.is_integer() || matches!(dtype, DataType::Utf8)
}

pub fn supported_label_list_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::List(field) => supported_bitmap_data_type(field.data_type()),
DataType::FixedSizeList(field, _) => supported_bitmap_data_type(field.data_type()),
_ => false,
}
}

pub fn supported_fts_data_type(dtype: &DataType) -> bool {
matches!(dtype, DataType::Utf8 | DataType::LargeUtf8)
}

pub fn supported_vector_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
_ => false,
}
}

#[cfg(test)]
mod tests {
use super::*;

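A small sketch (assumed, not part of the diff) showing how the relocated `supported_*_data_type` helpers behave for a few Arrow types; the field names and dimension below are illustrative:

use arrow_schema::{DataType, Field};

fn demo_predicates() {
    // Integers, floats, strings, and temporal types qualify for BTree indices.
    assert!(supported_btree_data_type(&DataType::Int32));
    // Both Utf8 and LargeUtf8 qualify for full-text search indices.
    assert!(supported_fts_data_type(&DataType::LargeUtf8));
    // A list of a bitmap-compatible type qualifies for a LabelList index.
    assert!(supported_label_list_data_type(&DataType::List(
        Field::new("item", DataType::Utf8, true).into(),
    )));
    // A fixed-size list of floats is the only shape accepted for vector indices.
    assert!(supported_vector_data_type(&DataType::FixedSizeList(
        Field::new("item", DataType::Float32, true).into(),
        128,
    )));
    assert!(!supported_vector_data_type(&DataType::Utf8));
}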