mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-02 20:00:46 +00:00
Compare commits
4 Commits
python-v0.
...
yang/fix-q
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33bf57570e | ||
|
|
7b874905fd | ||
|
|
a327044e2f | ||
|
|
f20ec99dec |
7
.agents/skills/README.md
Normal file
7
.agents/skills/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Agent Skills
|
||||
|
||||
This directory contains repo-scoped code agent skills for the LanceDB project.
|
||||
|
||||
Each skill is a folder that contains a required `SKILL.md` and optional bundled resources.
|
||||
|
||||
Codex discovers skills from `.agents/skills` in the current working directory and parent directories.
|
||||
98
.agents/skills/lancedb-update-lance-dependency/SKILL.md
Normal file
98
.agents/skills/lancedb-update-lance-dependency/SKILL.md
Normal file
@@ -0,0 +1,98 @@
|
||||
---
|
||||
name: lancedb-update-lance-dependency
|
||||
description: Update LanceDB to a specific Lance release or tag. Use when bumping Lance dependencies in the lancedb repository, including Rust workspace Lance crates, Java lance-core, validation, branch creation, commit, push, and PR creation when requested.
|
||||
---
|
||||
|
||||
# LanceDB Update Lance Dependency
|
||||
|
||||
## Scope
|
||||
|
||||
Use this skill in the `lancedb/lancedb` repository when updating the Lance dependency to a specific Lance version or tag.
|
||||
|
||||
Inputs can be a version (`7.2.0-beta.1`), a tag (`v7.2.0-beta.1`), a tag ref (`refs/tags/v7.2.0-beta.1`), or `latest`.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Confirm the worktree status with `git status --short`.
|
||||
2. Resolve the target Lance version:
|
||||
|
||||
- If the input is `latest`, empty, or omitted, run:
|
||||
|
||||
```bash
|
||||
python3 ci/check_lance_release.py
|
||||
```
|
||||
|
||||
Parse the JSON output. If `needs_update` is not `true`, stop without creating a PR. Otherwise use `latest_tag`.
|
||||
|
||||
- If the input is explicit, use it directly.
|
||||
|
||||
3. Compute update metadata without changing files:
|
||||
|
||||
```bash
|
||||
python3 ci/update_lance_dependency.py "$TAG_OR_VERSION" --metadata-only
|
||||
```
|
||||
|
||||
Before making changes, check for an existing open PR with the emitted `pr_title`:
|
||||
|
||||
```bash
|
||||
gh pr list --search "\"$PR_TITLE\" in:title" --state open --limit 1 --json number,url,title
|
||||
```
|
||||
|
||||
If a matching open PR exists, stop and report it instead of creating a duplicate.
|
||||
|
||||
4. Run the deterministic update entrypoint:
|
||||
|
||||
```bash
|
||||
python3 ci/update_lance_dependency.py "$TAG_OR_VERSION"
|
||||
```
|
||||
|
||||
This updates the Rust workspace Lance dependencies through `ci/set_lance_version.py`, updates `java/pom.xml`, refreshes Cargo metadata, and prints JSON metadata containing `branch_name`, `commit_message`, and `pr_title`.
|
||||
|
||||
5. Run validation:
|
||||
|
||||
```bash
|
||||
cargo clippy --quiet --workspace --tests --all-features -- -D warnings
|
||||
cargo fmt --all --quiet
|
||||
```
|
||||
|
||||
Fix real diagnostics and rerun clippy until it succeeds. Do not skip warnings.
|
||||
|
||||
6. Inspect `git status --short` and `git diff` to ensure only the Lance dependency update and required compatibility fixes are present.
|
||||
|
||||
7. If the task only asks to prepare local changes, stop here and report the changed files and validation result.
|
||||
|
||||
8. If the task asks to publish the update, create a branch using the printed `branch_name`, stage all relevant files, and commit using the printed `commit_message`. Do not amend or rewrite existing commits.
|
||||
|
||||
9. Push to `origin`. Before creating the PR, check that the current token has push permission:
|
||||
|
||||
```bash
|
||||
gh api repos/lancedb/lancedb --jq .permissions.push
|
||||
```
|
||||
|
||||
If the remote branch already exists for the same generated branch name, delete the remote ref with `gh api -X DELETE repos/lancedb/lancedb/git/refs/heads/$BRANCH_NAME`, then push. Do not force-push.
|
||||
|
||||
10. Create a PR targeting `main` with the printed `pr_title`. If there is no PR template, keep the body to two or three concise sentences: state the Lance dependency bump, note any required compatibility fixes, and link the triggering Lance tag or release.
|
||||
|
||||
11. Read back the remote PR title after creation. If it is not a Conventional Commit title, fix it immediately.
|
||||
|
||||
12. When running in GitHub Actions after creating the LanceDB PR, trigger the Sophon dependency update:
|
||||
|
||||
```bash
|
||||
gh workflow run codex-bump-lancedb-lance.yml \
|
||||
--repo lancedb/sophon \
|
||||
-f lance_ref="$LANCE_TAG" \
|
||||
-f lancedb_ref="$BRANCH_NAME"
|
||||
gh run list --repo lancedb/sophon --workflow codex-bump-lancedb-lance.yml --limit 1 --json databaseId,url,displayTitle
|
||||
```
|
||||
|
||||
Use the emitted metadata `tag` value as `LANCE_TAG`. Do this only after a new LanceDB PR has been created. If the update was skipped because no update is needed or an open PR already exists, do not trigger Sophon.
|
||||
|
||||
## GitHub Actions
|
||||
|
||||
When this skill is used from GitHub Actions, `TAG`, `GH_TOKEN`, and `GITHUB_TOKEN` may already be set. Resolve `latest` first when `TAG` is empty. Once an explicit tag or version is known, use:
|
||||
|
||||
```bash
|
||||
python3 ci/update_lance_dependency.py "$TAG" --github-output "$GITHUB_OUTPUT"
|
||||
```
|
||||
|
||||
Then use the emitted `branch_name`, `commit_message`, and `pr_title` values for branch, commit, and PR creation.
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.30.0-beta.1"
|
||||
current_version = "0.30.1-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -4,14 +4,16 @@ on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag name from Lance"
|
||||
required: true
|
||||
description: "Tag name from Lance. If omitted, the skill will use the latest Lance release that needs an update."
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Tag name from Lance"
|
||||
required: true
|
||||
description: "Tag name from Lance. Leave empty to use the latest Lance release that needs an update."
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
@@ -25,7 +27,7 @@ jobs:
|
||||
steps:
|
||||
- name: Show inputs
|
||||
run: |
|
||||
echo "tag = ${{ inputs.tag }}"
|
||||
echo "tag = ${{ inputs.tag || 'latest' }}"
|
||||
|
||||
- name: Checkout Repo LanceDB
|
||||
uses: actions/checkout@v4
|
||||
@@ -71,65 +73,21 @@ jobs:
|
||||
OPENAI_API_KEY: ${{ secrets.CODEX_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
VERSION="${TAG#refs/tags/}"
|
||||
VERSION="${VERSION#v}"
|
||||
BRANCH_NAME="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
|
||||
|
||||
# Use "chore" for beta/rc versions, "feat" for stable releases
|
||||
if [[ "${VERSION}" == *beta* ]] || [[ "${VERSION}" == *rc* ]]; then
|
||||
COMMIT_TYPE="chore"
|
||||
else
|
||||
COMMIT_TYPE="feat"
|
||||
fi
|
||||
TARGET_TAG="${TAG:-latest}"
|
||||
|
||||
cat <<EOF >/tmp/codex-prompt.txt
|
||||
You are running inside the lancedb repository on a GitHub Actions runner. Update the Lance dependency to version ${VERSION} and prepare a pull request for maintainers to review.
|
||||
You are running inside the lancedb repository on a GitHub Actions runner.
|
||||
|
||||
Follow these steps exactly:
|
||||
1. Use script "ci/set_lance_version.py" to update Lance Rust dependencies. The script already refreshes Cargo metadata, so allow it to finish even if it takes time.
|
||||
2. Update the Java lance-core dependency version in "java/pom.xml": change the "<lance-core.version>...</lance-core.version>" property to "${VERSION}".
|
||||
3. Run "cargo clippy --workspace --tests --all-features -- -D warnings". If diagnostics appear, fix them yourself and rerun clippy until it exits cleanly. Do not skip any warnings.
|
||||
4. After clippy succeeds, run "cargo fmt --all" to format the workspace.
|
||||
5. Ensure the repository is clean except for intentional changes. Inspect "git status --short" and "git diff" to confirm the dependency update and any required fixes.
|
||||
6. Create and switch to a new branch named "${BRANCH_NAME}" (replace any duplicated hyphens if necessary).
|
||||
7. Stage all relevant files with "git add -A". Commit using the message "${COMMIT_TYPE}: update lance dependency to v${VERSION}".
|
||||
8. Push the branch to origin. If the remote branch already exists, delete it first with "gh api -X DELETE repos/lancedb/lancedb/git/refs/heads/${BRANCH_NAME}" then push with "git push origin ${BRANCH_NAME}". Do NOT use "git push --force" or "git push -f".
|
||||
9. env "GH_TOKEN" is available, use "gh" tools for github related operations like creating pull request.
|
||||
10. Create a pull request targeting "main" with title "${COMMIT_TYPE}: update lance dependency to v${VERSION}". First, write the PR body to /tmp/pr-body.md using a heredoc (cat <<'EOF' > /tmp/pr-body.md). The body should summarize the dependency bump, clippy/fmt verification, and link the triggering tag (${TAG}). Then run "gh pr create --body-file /tmp/pr-body.md".
|
||||
11. After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
|
||||
Use \$lancedb-update-lance-dependency with target "${TARGET_TAG}".
|
||||
|
||||
Constraints:
|
||||
- Use bash commands; avoid modifying GitHub workflow files other than through the scripted task above.
|
||||
- Do not merge the PR.
|
||||
- If any command fails, diagnose and fix the issue instead of aborting.
|
||||
- Use env "GH_TOKEN" for GitHub operations.
|
||||
- Do not merge the pull request.
|
||||
- Do not force-push.
|
||||
- Do not create a duplicate pull request if an open PR already exists for the target Lance version.
|
||||
- If any command fails, diagnose and fix the root cause instead of aborting.
|
||||
- After creating the PR, display the PR URL, "git status --short", and a concise summary of the commands run and their results.
|
||||
EOF
|
||||
|
||||
printenv OPENAI_API_KEY | codex login --with-api-key
|
||||
codex --config shell_environment_policy.ignore_default_excludes=true exec --dangerously-bypass-approvals-and-sandbox "$(cat /tmp/codex-prompt.txt)"
|
||||
|
||||
- name: Trigger sophon dependency update
|
||||
env:
|
||||
TAG: ${{ inputs.tag }}
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
VERSION="${TAG#refs/tags/}"
|
||||
VERSION="${VERSION#v}"
|
||||
LANCEDB_BRANCH="codex/update-lance-${VERSION//[^a-zA-Z0-9]/-}"
|
||||
|
||||
echo "Triggering sophon workflow with:"
|
||||
echo " lance_ref: ${TAG#refs/tags/}"
|
||||
echo " lancedb_ref: ${LANCEDB_BRANCH}"
|
||||
|
||||
gh workflow run codex-bump-lancedb-lance.yml \
|
||||
--repo lancedb/sophon \
|
||||
-f lance_ref="${TAG#refs/tags/}" \
|
||||
-f lancedb_ref="${LANCEDB_BRANCH}"
|
||||
|
||||
- name: Show latest sophon workflow run
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "Latest sophon workflow run:"
|
||||
gh run list --repo lancedb/sophon --workflow codex-bump-lancedb-lance.yml --limit 1 --json databaseId,url,displayTitle
|
||||
|
||||
62
.github/workflows/lance-release-timer.yml
vendored
62
.github/workflows/lance-release-timer.yml
vendored
@@ -1,62 +0,0 @@
|
||||
name: Lance Release Timer
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "*/10 * * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
|
||||
concurrency:
|
||||
group: lance-release-timer
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
trigger-update:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check for new Lance tag
|
||||
id: check
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
python3 ci/check_lance_release.py --github-output "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Look for existing PR
|
||||
if: steps.check.outputs.needs_update == 'true'
|
||||
id: pr
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TITLE="chore: update lance dependency to v${{ steps.check.outputs.latest_version }}"
|
||||
COUNT=$(gh pr list --search "\"$TITLE\" in:title" --state open --limit 1 --json number --jq 'length')
|
||||
if [ "$COUNT" -gt 0 ]; then
|
||||
echo "Open PR already exists for $TITLE"
|
||||
echo "pr_exists=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "No existing PR for $TITLE"
|
||||
echo "pr_exists=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Trigger codex update workflow
|
||||
if: steps.check.outputs.needs_update == 'true' && steps.pr.outputs.pr_exists != 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG=${{ steps.check.outputs.latest_tag }}
|
||||
gh workflow run codex-update-lance-dependency.yml -f tag=refs/tags/$TAG
|
||||
|
||||
- name: Show latest codex workflow run
|
||||
if: steps.check.outputs.needs_update == 'true' && steps.pr.outputs.pr_exists != 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.ROBOT_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
gh run list --workflow codex-update-lance-dependency.yml --limit 1 --json databaseId,url,displayTitle
|
||||
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -5128,7 +5128,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.30.0-beta.1"
|
||||
version = "0.30.1-beta.0"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -5211,7 +5211,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.30.0-beta.1"
|
||||
version = "0.30.1-beta.0"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5234,7 +5234,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.33.0-beta.1"
|
||||
version = "0.33.1-beta.0"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
|
||||
126
ci/update_lance_dependency.py
Normal file
126
ci/update_lance_dependency.py
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare a Lance dependency update for LanceDB."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
try:
|
||||
from check_lance_release import parse_semver
|
||||
except ModuleNotFoundError:
|
||||
# Supports importing as ci.update_lance_dependency from tests or ad hoc checks.
|
||||
from ci.check_lance_release import parse_semver # type: ignore
|
||||
|
||||
|
||||
def normalize_version(raw: str) -> str:
|
||||
value = raw.strip()
|
||||
value = value.removeprefix("refs/tags/")
|
||||
value = value.removeprefix("v")
|
||||
try:
|
||||
parse_semver(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported Lance version or tag: {raw}")
|
||||
return value
|
||||
|
||||
|
||||
def normalized_tag(version: str) -> str:
|
||||
return f"v{version}"
|
||||
|
||||
|
||||
def branch_name(version: str) -> str:
|
||||
suffix = re.sub(r"[^a-zA-Z0-9]+", "-", version).strip("-")
|
||||
suffix = re.sub(r"-+", "-", suffix)
|
||||
return f"codex/update-lance-{suffix}"
|
||||
|
||||
|
||||
def commit_type(version: str) -> str:
|
||||
prerelease = version.split("-", maxsplit=1)[1] if "-" in version else ""
|
||||
return "chore" if "beta" in prerelease or "rc" in prerelease else "feat"
|
||||
|
||||
|
||||
def metadata_for(version: str) -> dict[str, str]:
|
||||
kind = commit_type(version)
|
||||
message = f"{kind}: update lance dependency to v{version}"
|
||||
return {
|
||||
"version": version,
|
||||
"tag": normalized_tag(version),
|
||||
"branch_name": branch_name(version),
|
||||
"commit_type": kind,
|
||||
"commit_message": message,
|
||||
"pr_title": message,
|
||||
}
|
||||
|
||||
|
||||
def run_command(cmd: Sequence[str], *, cwd: Path) -> None:
|
||||
subprocess.run(cmd, cwd=cwd, check=True)
|
||||
|
||||
|
||||
def update_java_lance_core_version(repo_root: Path, version: str) -> None:
|
||||
pom_path = repo_root / "java" / "pom.xml"
|
||||
contents = pom_path.read_text(encoding="utf-8")
|
||||
updated, count = re.subn(
|
||||
r"(<lance-core\.version>)[^<]+(</lance-core\.version>)",
|
||||
rf"\g<1>{version}\g<2>",
|
||||
contents,
|
||||
count=1,
|
||||
)
|
||||
if count != 1:
|
||||
raise RuntimeError(
|
||||
"Expected exactly one <lance-core.version> entry in java/pom.xml"
|
||||
)
|
||||
pom_path.write_text(updated, encoding="utf-8")
|
||||
|
||||
|
||||
def write_github_outputs(path: str | None, payload: dict[str, str]) -> None:
|
||||
if not path:
|
||||
return
|
||||
with open(path, "a", encoding="utf-8") as output:
|
||||
for key, value in payload.items():
|
||||
output.write(f"{key}={value}\n")
|
||||
|
||||
|
||||
def main(argv: Sequence[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"tag_or_version",
|
||||
help="Lance tag or version, for example refs/tags/v7.2.0-beta.1 or 7.2.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-root",
|
||||
type=Path,
|
||||
default=Path(__file__).resolve().parents[1],
|
||||
help="Path to the lancedb repository root",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--github-output",
|
||||
default=None,
|
||||
help="Optional GitHub Actions output file to receive metadata fields",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata-only",
|
||||
action="store_true",
|
||||
help="Only print derived metadata; do not modify dependency files",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
repo_root = args.repo_root.resolve()
|
||||
version = normalize_version(args.tag_or_version)
|
||||
payload = metadata_for(version)
|
||||
|
||||
if not args.metadata_only:
|
||||
run_command([sys.executable, "ci/set_lance_version.py", version], cwd=repo_root)
|
||||
update_java_lance_core_version(repo_root, version)
|
||||
|
||||
write_github_outputs(args.github_output, payload)
|
||||
print(json.dumps(payload, sort_keys=True))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.30.0-beta.1</version>
|
||||
<version>0.30.1-beta.0</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.0-beta.1</version>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.30.0-beta.1</version>
|
||||
<version>0.30.1-beta.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.30.0-beta.1"
|
||||
version = "0.30.1-beta.0"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.30.0-beta.1",
|
||||
"version": "0.30.1-beta.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -315,6 +315,15 @@ def deserialize_conn(
|
||||
manifest_enabled=parsed.get("manifest_enabled", False),
|
||||
namespace_client_properties=parsed.get("namespace_client_properties"),
|
||||
)
|
||||
elif connection_type == "remote":
|
||||
return RemoteDBConnection(
|
||||
parsed["db_url"],
|
||||
parsed["api_key"],
|
||||
parsed.get("region", "us-east-1"),
|
||||
host_override=parsed.get("host_override"),
|
||||
client_config=parsed.get("client_config"),
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown connection_type: {connection_type}")
|
||||
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
from deprecation import deprecated
|
||||
import pyarrow as pa
|
||||
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .table import LanceTable, Table
|
||||
from .background_loop import LOOP
|
||||
from .util import batch_to_tensor, batch_to_tensor_rows
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
@@ -354,6 +355,49 @@ class Transforms:
|
||||
DEFAULT_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _table_to_pickle_state(table: Table) -> dict[str, Any]:
|
||||
from .remote.table import RemoteTable
|
||||
|
||||
if isinstance(table, RemoteTable):
|
||||
return {
|
||||
"kind": "remote",
|
||||
"table": table,
|
||||
}
|
||||
|
||||
if not isinstance(table, LanceTable):
|
||||
raise ValueError(f"Cannot pickle table of type {type(table)!r}")
|
||||
|
||||
base_uri = table._conn.uri
|
||||
if base_uri.startswith("memory://"):
|
||||
return {
|
||||
"kind": "memory",
|
||||
"name": table.name,
|
||||
"data": table.to_arrow(),
|
||||
}
|
||||
|
||||
return {
|
||||
"kind": "local",
|
||||
"name": table.name,
|
||||
"uri": base_uri,
|
||||
"namespace": table._namespace_path,
|
||||
"storage_options": table._conn.storage_options,
|
||||
}
|
||||
|
||||
|
||||
def _table_from_pickle_state(state: dict[str, Any]) -> Table:
|
||||
from . import connect
|
||||
|
||||
kind = state["kind"]
|
||||
if kind == "remote":
|
||||
return state["table"]
|
||||
if kind == "memory":
|
||||
return connect("memory://").create_table(state["name"], state["data"])
|
||||
if kind == "local":
|
||||
db = connect(state["uri"], storage_options=state["storage_options"])
|
||||
return db.open_table(state["name"], namespace_path=state["namespace"] or None)
|
||||
raise ValueError(f"Unknown table pickle state kind: {kind}")
|
||||
|
||||
|
||||
class Permutation:
|
||||
"""
|
||||
A Permutation is a view of a dataset that can be used as input to model training
|
||||
@@ -369,15 +413,15 @@ class Permutation:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_table: LanceTable,
|
||||
permutation_table: Optional[LanceTable],
|
||||
base_table: Table,
|
||||
permutation_table: Optional[Table],
|
||||
split: int,
|
||||
selection: dict[str, str],
|
||||
batch_size: int,
|
||||
transform_fn: Callable[pa.RecordBatch, Any],
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
connection_factory: Optional[Callable[[str], LanceTable]] = None,
|
||||
connection_factory: Optional[Callable[[str], Table]] = None,
|
||||
_reader: Optional[PermutationReader] = None,
|
||||
):
|
||||
"""
|
||||
@@ -397,6 +441,7 @@ class Permutation:
|
||||
if _reader is None:
|
||||
_reader = LOOP.run(self._build_reader())
|
||||
self.reader: PermutationReader = _reader
|
||||
self._pid = os.getpid()
|
||||
|
||||
async def _build_reader(self) -> PermutationReader:
|
||||
reader = await PermutationReader.from_tables(
|
||||
@@ -428,29 +473,25 @@ class Permutation:
|
||||
return new
|
||||
|
||||
def with_connection_factory(
|
||||
self, connection_factory: Callable[[str], LanceTable]
|
||||
self, connection_factory: Callable[[str], Table]
|
||||
) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation that will use ``connection_factory`` to reopen
|
||||
the base table when this permutation is unpickled in a worker process.
|
||||
|
||||
The factory is a callable that takes a single argument — the base table
|
||||
name — and returns a [LanceTable]. It must be picklable; the worker
|
||||
The factory is a callable that takes a single argument, the base table
|
||||
name, and returns a LanceDB table. It must be picklable; the worker
|
||||
will pickle it via standard ``pickle`` and call it to recover the base
|
||||
table. Picklable callables in practice means top-level (module-level)
|
||||
functions, ``functools.partial`` of such functions, or instances of
|
||||
picklable classes implementing ``__call__``. Lambdas and closures over
|
||||
local variables don't pickle with the default protocol.
|
||||
|
||||
Setting a factory is necessary when the URI alone is not enough to
|
||||
re-open the connection — most importantly for LanceDB Cloud (``db://``)
|
||||
connections, where ``api_key`` and ``region`` aren't recoverable from
|
||||
the connection object after construction.
|
||||
|
||||
For local file or cloud-storage paths the factory is optional: if not
|
||||
set, ``__getstate__`` falls back to capturing
|
||||
``(uri, storage_options, namespace_path)`` and re-opening via
|
||||
``lancedb.connect(uri, storage_options=...)``.
|
||||
A factory is optional for normal local and remote LanceDB connections:
|
||||
if not set, ``__getstate__`` captures the table's own picklable reopen
|
||||
state. Use a factory when that default state is not enough, for example
|
||||
when credentials should be loaded from the worker environment instead
|
||||
of being embedded in the pickle.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -508,7 +549,7 @@ class Permutation:
|
||||
return new
|
||||
|
||||
@classmethod
|
||||
def identity(cls, table: LanceTable) -> "Permutation":
|
||||
def identity(cls, table: Table) -> "Permutation":
|
||||
"""
|
||||
Creates an identity permutation for the given table.
|
||||
"""
|
||||
@@ -517,8 +558,8 @@ class Permutation:
|
||||
@classmethod
|
||||
def from_tables(
|
||||
cls,
|
||||
base_table: LanceTable,
|
||||
permutation_table: Optional[LanceTable] = None,
|
||||
base_table: Table,
|
||||
permutation_table: Optional[Table] = None,
|
||||
split: Optional[Union[str, int]] = None,
|
||||
) -> "Permutation":
|
||||
"""
|
||||
@@ -594,11 +635,10 @@ class Permutation:
|
||||
|
||||
The base table is captured either via a user-supplied
|
||||
``connection_factory`` (see [with_connection_factory]) or, as a
|
||||
fallback, by introspecting ``(uri, storage_options, namespace_path)``
|
||||
on the connection. The permutation table — always an in-memory
|
||||
LanceDB table — is captured as a pyarrow Table (which pickles via
|
||||
Arrow IPC natively). The reader is dropped from the wire format;
|
||||
``__setstate__`` rebuilds it from the restored tables.
|
||||
fallback, by the table's own picklable reopen state. The permutation
|
||||
table is captured as a pyarrow Table (which pickles via Arrow IPC
|
||||
natively). The reader is dropped from the wire format and rebuilt
|
||||
lazily on first use.
|
||||
"""
|
||||
permutation_data: Optional[pa.Table] = None
|
||||
if self.permutation_table is not None:
|
||||
@@ -622,39 +662,9 @@ class Permutation:
|
||||
# namespace from the existing connection.
|
||||
return common
|
||||
|
||||
# URI-introspection fallback: only viable for native (OSS) connections
|
||||
# where (uri, storage_options) is enough to reopen. Remote / cloud
|
||||
# connections don't expose recoverable api_key / region — those users
|
||||
# must call with_connection_factory().
|
||||
try:
|
||||
base_uri = self.base_table._conn.uri
|
||||
storage_options = self.base_table._conn.storage_options
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
"Cannot pickle this Permutation: the base table's connection "
|
||||
"does not expose a uri/storage_options, which usually means it "
|
||||
"is a remote (LanceDB Cloud) connection. Call "
|
||||
"Permutation.with_connection_factory(...) first to provide a "
|
||||
"picklable callable that re-opens the base table from a worker "
|
||||
"process."
|
||||
) from e
|
||||
|
||||
if base_uri.startswith("memory://"):
|
||||
# In-memory base tables don't exist in any worker process by
|
||||
# default, so dump the entire base table into the pickle. This
|
||||
# can be expensive for large datasets — users with large
|
||||
# in-memory base tables should either persist them or set a
|
||||
# connection_factory.
|
||||
return {
|
||||
**common,
|
||||
"base_table_data": self.base_table.to_arrow(),
|
||||
}
|
||||
|
||||
return {
|
||||
**common,
|
||||
"base_table_uri": base_uri,
|
||||
"base_table_namespace": self.base_table._namespace_path,
|
||||
"base_table_storage_options": storage_options,
|
||||
"base_table_state": _table_to_pickle_state(self.base_table),
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
@@ -663,6 +673,8 @@ class Permutation:
|
||||
connection_factory = state["connection_factory"]
|
||||
if connection_factory is not None:
|
||||
base_table = connection_factory(state["base_table_name"])
|
||||
elif "base_table_state" in state:
|
||||
base_table = _table_from_pickle_state(state["base_table_state"])
|
||||
elif "base_table_data" in state:
|
||||
# In-memory base table inlined into the pickle; rebuild the same
|
||||
# way we rebuild the in-memory permutation table.
|
||||
@@ -680,7 +692,7 @@ class Permutation:
|
||||
namespace_path=state["base_table_namespace"] or None,
|
||||
)
|
||||
|
||||
permutation_table: Optional[LanceTable] = None
|
||||
permutation_table: Optional[Table] = None
|
||||
if state["permutation_data"] is not None:
|
||||
mem_db = connect("memory://")
|
||||
permutation_table = mem_db.create_table(
|
||||
@@ -696,10 +708,28 @@ class Permutation:
|
||||
self.offset = state["offset"]
|
||||
self.limit = state["limit"]
|
||||
self.connection_factory = connection_factory
|
||||
self.reader = None
|
||||
self._pid = None
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
pid = os.getpid()
|
||||
if self.reader is not None and getattr(self, "_pid", None) == pid:
|
||||
return
|
||||
# The reader owns Rust-side table handles. Rebuild it after unpickle or
|
||||
# fork even though the Python table wrappers reopen themselves.
|
||||
if hasattr(self.base_table, "_ensure_open"):
|
||||
self.base_table._ensure_open()
|
||||
if self.permutation_table is not None and hasattr(
|
||||
self.permutation_table, "_ensure_open"
|
||||
):
|
||||
self.permutation_table._ensure_open()
|
||||
self.reader = LOOP.run(self._build_reader())
|
||||
self._pid = pid
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
self._ensure_open()
|
||||
|
||||
async def do_output_schema():
|
||||
return await self.reader.output_schema(self.selection)
|
||||
|
||||
@@ -717,6 +747,7 @@ class Permutation:
|
||||
"""
|
||||
The number of rows in the permutation
|
||||
"""
|
||||
self._ensure_open()
|
||||
return self.reader.count_rows()
|
||||
|
||||
@property
|
||||
@@ -875,6 +906,7 @@ class Permutation:
|
||||
If skip_last_batch is True, the last batch will be skipped if it is not a
|
||||
multiple of batch_size.
|
||||
"""
|
||||
self._ensure_open()
|
||||
|
||||
async def get_iter():
|
||||
return await self.reader.read(self.selection, batch_size=batch_size)
|
||||
@@ -976,6 +1008,7 @@ class Permutation:
|
||||
so `with_format` and `with_transform` affect this method in the same way
|
||||
they affect iteration.
|
||||
"""
|
||||
self._ensure_open()
|
||||
|
||||
async def do_take_offsets():
|
||||
return await self.reader.take_offsets(offsets, selection=self.selection)
|
||||
@@ -1011,9 +1044,11 @@ class Permutation:
|
||||
"""
|
||||
Skip the first `skip` rows of the permutation
|
||||
"""
|
||||
self._ensure_open()
|
||||
new = copy.copy(self)
|
||||
new.offset = skip
|
||||
new.reader = LOOP.run(new._build_reader())
|
||||
new._pid = os.getpid()
|
||||
return new
|
||||
|
||||
@deprecated(details="Use with_take instead")
|
||||
@@ -1032,9 +1067,11 @@ class Permutation:
|
||||
"""
|
||||
Limit the permutation to `limit` rows (following any `skip`)
|
||||
"""
|
||||
self._ensure_open()
|
||||
new = copy.copy(self)
|
||||
new.limit = limit
|
||||
new.reader = LOOP.run(new._build_reader())
|
||||
new._pid = os.getpid()
|
||||
return new
|
||||
|
||||
@deprecated(details="Use with_repeat instead")
|
||||
|
||||
@@ -41,6 +41,14 @@ from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
|
||||
BlobMode = Literal["lazy", "bytes", "descriptions"]
|
||||
|
||||
_BLOB_MODE_TO_HANDLING = {
|
||||
"lazy": "blobs_descriptions",
|
||||
"bytes": "all_binary",
|
||||
"descriptions": "blobs_descriptions",
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sys
|
||||
|
||||
@@ -55,7 +63,7 @@ if TYPE_CHECKING:
|
||||
from ._lancedb import VectorQuery as LanceVectorQuery
|
||||
from .common import VEC
|
||||
from .pydantic import LanceModel
|
||||
from .table import Table
|
||||
from .table import AsyncTable, Table
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self
|
||||
@@ -65,6 +73,147 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T", bound="LanceModel")
|
||||
|
||||
|
||||
def _validate_blob_mode(blob_mode: BlobMode) -> None:
|
||||
if blob_mode not in _BLOB_MODE_TO_HANDLING:
|
||||
modes = ", ".join(repr(mode) for mode in _BLOB_MODE_TO_HANDLING)
|
||||
raise ValueError(f"blob_mode must be one of {modes}, got {blob_mode!r}")
|
||||
|
||||
|
||||
def _field_is_blob(field: pa.Field) -> bool:
|
||||
metadata = field.metadata or {}
|
||||
return metadata.get(b"lance-encoding:blob") == b"true" or (
|
||||
metadata.get("lance-encoding:blob") == "true"
|
||||
)
|
||||
|
||||
|
||||
def _schema_has_blob_field(schema: pa.Schema) -> bool:
|
||||
return any(_field_is_blob(field) for field in schema)
|
||||
|
||||
|
||||
def _blob_mode_requires_native_pandas(blob_mode: BlobMode, schema: pa.Schema) -> bool:
|
||||
return blob_mode in ("lazy", "bytes") and _schema_has_blob_field(schema)
|
||||
|
||||
|
||||
def _unsupported_blob_pandas_error(reason: str) -> RuntimeError:
|
||||
return RuntimeError(
|
||||
"blob_mode='lazy' and blob_mode='bytes' require Lance native pandas "
|
||||
f"conversion for queries that return blob columns, but {reason}. "
|
||||
"Use blob_mode='descriptions' or remove blob columns from the projection."
|
||||
)
|
||||
|
||||
|
||||
def _query_is_plain_scan(query: Query) -> bool:
|
||||
return (
|
||||
query.vector is None
|
||||
and query.full_text_query is None
|
||||
and not query.postfilter
|
||||
and not query.order_by
|
||||
)
|
||||
|
||||
|
||||
def _filter_to_sql(filter: Optional[Union[str, Expr]]) -> Optional[str]:
|
||||
if filter is None:
|
||||
return None
|
||||
if isinstance(filter, Expr):
|
||||
return filter.to_sql()
|
||||
return filter
|
||||
|
||||
|
||||
def _projection_to_scanner_kwargs(
|
||||
columns: Optional[
|
||||
Union[
|
||||
List[str], List[Tuple[str, Union[str, Expr]]], Dict[str, Union[str, Expr]]
|
||||
]
|
||||
],
|
||||
) -> Dict[str, Any]:
|
||||
if columns is None:
|
||||
return {}
|
||||
if isinstance(columns, list):
|
||||
if all(isinstance(column, str) for column in columns):
|
||||
return {"columns": columns}
|
||||
if all(isinstance(column, tuple) and len(column) == 2 for column in columns):
|
||||
return {
|
||||
"columns": {
|
||||
name: expr.to_sql() if isinstance(expr, Expr) else expr
|
||||
for name, expr in columns
|
||||
}
|
||||
}
|
||||
# Let Lance raise the detailed projection validation error.
|
||||
return {"columns": columns}
|
||||
|
||||
projection = {}
|
||||
for name, expr in columns.items():
|
||||
if isinstance(expr, Expr):
|
||||
expr = expr.to_sql()
|
||||
projection[name] = expr
|
||||
return {"columns": projection}
|
||||
|
||||
|
||||
def _scanner_kwargs_for_query(query: Query, blob_mode: BlobMode) -> Dict[str, Any]:
|
||||
kwargs = {
|
||||
**_projection_to_scanner_kwargs(query.columns),
|
||||
"filter": _filter_to_sql(query.filter),
|
||||
"limit": query.limit,
|
||||
"offset": query.offset,
|
||||
"with_row_id": query.with_row_id,
|
||||
"fast_search": query.fast_search,
|
||||
"blob_handling": _BLOB_MODE_TO_HANDLING[blob_mode],
|
||||
}
|
||||
return {key: value for key, value in kwargs.items() if value is not None}
|
||||
|
||||
|
||||
def _ensure_lazy_blob_frame(
|
||||
df: "pd.DataFrame", schema: pa.Schema, blob_mode: BlobMode
|
||||
) -> "pd.DataFrame":
|
||||
if blob_mode != "lazy" or not _schema_has_blob_field(schema) or len(df) == 0:
|
||||
return df
|
||||
|
||||
for field in schema:
|
||||
if not _field_is_blob(field) or field.name not in df.columns:
|
||||
continue
|
||||
value = df[field.name].iloc[0]
|
||||
if value is not None and not hasattr(value, "readall"):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"the Lance scanner did not return lazy blob files"
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
def _scanner_to_pandas(scanner: Any, blob_mode: BlobMode, **kwargs) -> "pd.DataFrame":
|
||||
schema = getattr(scanner, "projected_schema", None)
|
||||
if schema is None:
|
||||
schema = getattr(scanner, "schema", None)
|
||||
if schema is None:
|
||||
schema = getattr(scanner, "dataset_schema", None)
|
||||
if callable(schema):
|
||||
schema = schema()
|
||||
if hasattr(scanner, "to_pandas"):
|
||||
try:
|
||||
df = scanner.to_pandas(blob_mode=blob_mode, **kwargs)
|
||||
except TypeError as err:
|
||||
message = str(err)
|
||||
if "blob_mode" not in message and "unexpected keyword" not in message:
|
||||
raise
|
||||
df = scanner.to_pandas(**kwargs)
|
||||
if schema is not None:
|
||||
return _ensure_lazy_blob_frame(df, schema, blob_mode)
|
||||
return df
|
||||
|
||||
if hasattr(scanner, "to_pyarrow"):
|
||||
reader = scanner.to_pyarrow()
|
||||
tbl = reader.read_all()
|
||||
elif hasattr(scanner, "to_table"):
|
||||
tbl = scanner.to_table()
|
||||
else:
|
||||
reader = scanner.to_reader()
|
||||
tbl = reader.read_all()
|
||||
if blob_mode == "lazy" and _schema_has_blob_field(tbl.schema):
|
||||
raise _unsupported_blob_pandas_error(
|
||||
"the Lance scanner does not expose to_pandas"
|
||||
)
|
||||
return tbl.to_pandas(**kwargs)
|
||||
|
||||
|
||||
# Pydantic validation function for vector queries
|
||||
def ensure_vector_query(
|
||||
val: Any,
|
||||
@@ -718,6 +867,7 @@ class LanceQueryBuilder(ABC):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
*,
|
||||
blob_mode: BlobMode = "lazy",
|
||||
timeout: Optional[timedelta] = None,
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
@@ -737,11 +887,31 @@ class LanceQueryBuilder(ABC):
|
||||
timeout: Optional[timedelta]
|
||||
The maximum time to wait for the query to complete.
|
||||
If None, wait indefinitely.
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for plain scan queries.
|
||||
Vector, FTS, hybrid, and other non-native query shapes keep the
|
||||
existing Arrow conversion path and only support blob descriptions.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
_validate_blob_mode(blob_mode)
|
||||
native_error = None
|
||||
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
|
||||
if _blob_mode_requires_native_pandas(blob_mode, tbl.schema):
|
||||
if flatten is None and timeout is None:
|
||||
try:
|
||||
df = self._plain_scan_to_pandas(blob_mode, **kwargs)
|
||||
if df is not None:
|
||||
return df
|
||||
except Exception as err:
|
||||
native_error = err
|
||||
reason = (
|
||||
"this query shape cannot use Lance native pandas conversion"
|
||||
if native_error is None
|
||||
else str(native_error)
|
||||
)
|
||||
raise _unsupported_blob_pandas_error(reason) from native_error
|
||||
return tbl.to_pandas(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
@@ -1086,6 +1256,19 @@ class LanceQueryBuilder(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
query = self.to_query_object()
|
||||
if not _query_is_plain_scan(query):
|
||||
return None
|
||||
|
||||
dataset = self._table.to_lance()
|
||||
scanner = dataset.scanner(**_scanner_kwargs_for_query(query, blob_mode))
|
||||
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def to_query_object(self) -> Query:
|
||||
"""Return a serializable representation of the query
|
||||
@@ -2207,7 +2390,11 @@ class AsyncQueryBase(object):
|
||||
Base class for all async queries (take, scan, vector, fts, hybrid)
|
||||
"""
|
||||
|
||||
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery]):
|
||||
def __init__(
|
||||
self,
|
||||
inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery],
|
||||
table: Optional["AsyncTable"] = None,
|
||||
):
|
||||
"""
|
||||
Construct an AsyncQueryBase
|
||||
|
||||
@@ -2215,6 +2402,7 @@ class AsyncQueryBase(object):
|
||||
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
|
||||
"""
|
||||
self._inner = inner
|
||||
self._table = table
|
||||
|
||||
def to_query_object(self) -> Query:
|
||||
"""
|
||||
@@ -2357,6 +2545,8 @@ class AsyncQueryBase(object):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
blob_mode: BlobMode = "lazy",
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
@@ -2390,13 +2580,48 @@ class AsyncQueryBase(object):
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for plain scan queries.
|
||||
Vector, FTS, hybrid, and other non-native query shapes keep the
|
||||
existing Arrow conversion path and only support blob descriptions.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
return (
|
||||
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
|
||||
).to_pandas(**kwargs)
|
||||
_validate_blob_mode(blob_mode)
|
||||
native_error = None
|
||||
tbl = flatten_columns(await self.to_arrow(timeout=timeout), flatten)
|
||||
if _blob_mode_requires_native_pandas(blob_mode, tbl.schema):
|
||||
if flatten is None and timeout is None:
|
||||
try:
|
||||
df = await self._plain_scan_to_pandas(blob_mode, **kwargs)
|
||||
if df is not None:
|
||||
return df
|
||||
except Exception as err:
|
||||
native_error = err
|
||||
reason = (
|
||||
"this query shape cannot use Lance native pandas conversion"
|
||||
if native_error is None
|
||||
else str(native_error)
|
||||
)
|
||||
raise _unsupported_blob_pandas_error(reason) from native_error
|
||||
return tbl.to_pandas(**kwargs)
|
||||
|
||||
async def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
if self._table is None:
|
||||
return None
|
||||
|
||||
query = self.to_query_object()
|
||||
if not _query_is_plain_scan(query):
|
||||
return None
|
||||
|
||||
dataset = await self._table._to_lance()
|
||||
scanner = dataset.scanner(**_scanner_kwargs_for_query(query, blob_mode))
|
||||
return _scanner_to_pandas(scanner, blob_mode, **kwargs)
|
||||
|
||||
async def to_polars(
|
||||
self,
|
||||
@@ -2503,14 +2728,18 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
Base class for "standard" async queries (all but take currently)
|
||||
"""
|
||||
|
||||
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
|
||||
def __init__(
|
||||
self,
|
||||
inner: Union[LanceQuery, LanceVectorQuery],
|
||||
table: Optional["AsyncTable"] = None,
|
||||
):
|
||||
"""
|
||||
Construct an AsyncStandardQuery
|
||||
|
||||
This method is not intended to be called directly. Instead, use the
|
||||
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
|
||||
"""
|
||||
super().__init__(inner)
|
||||
super().__init__(inner, table)
|
||||
|
||||
def where(self, predicate: Union[str, Expr]) -> Self:
|
||||
"""
|
||||
@@ -2616,14 +2845,14 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
|
||||
|
||||
class AsyncQuery(AsyncStandardQuery):
|
||||
def __init__(self, inner: LanceQuery):
|
||||
def __init__(self, inner: LanceQuery, table: Optional["AsyncTable"] = None):
|
||||
"""
|
||||
Construct an AsyncQuery
|
||||
|
||||
This method is not intended to be called directly. Instead, use the
|
||||
[AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
|
||||
"""
|
||||
super().__init__(inner)
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
|
||||
@classmethod
|
||||
@@ -2707,10 +2936,11 @@ class AsyncQuery(AsyncStandardQuery):
|
||||
new_self = self._inner.nearest_to(query_vectors[0])
|
||||
for v in query_vectors[1:]:
|
||||
new_self.add_query_vector(v)
|
||||
return AsyncVectorQuery(new_self)
|
||||
return AsyncVectorQuery(new_self, self._table)
|
||||
else:
|
||||
return AsyncVectorQuery(
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)),
|
||||
self._table,
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
@@ -2743,17 +2973,18 @@ class AsyncQuery(AsyncStandardQuery):
|
||||
|
||||
if isinstance(query, str):
|
||||
return AsyncFTSQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns}),
|
||||
self._table,
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))
|
||||
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}), self._table)
|
||||
|
||||
|
||||
class AsyncFTSQuery(AsyncStandardQuery):
|
||||
"""A query for full text search for LanceDB."""
|
||||
|
||||
def __init__(self, inner: LanceFTSQuery):
|
||||
super().__init__(inner)
|
||||
def __init__(self, inner: LanceFTSQuery, table: Optional["AsyncTable"] = None):
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
self._reranker = None
|
||||
|
||||
@@ -2835,10 +3066,11 @@ class AsyncFTSQuery(AsyncStandardQuery):
|
||||
new_self = self._inner.nearest_to(query_vectors[0])
|
||||
for v in query_vectors[1:]:
|
||||
new_self.add_query_vector(v)
|
||||
return AsyncHybridQuery(new_self)
|
||||
return AsyncHybridQuery(new_self, self._table)
|
||||
else:
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)),
|
||||
self._table,
|
||||
)
|
||||
|
||||
async def to_batches(
|
||||
@@ -3029,7 +3261,7 @@ class AsyncVectorQueryBase:
|
||||
|
||||
|
||||
class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
def __init__(self, inner: LanceVectorQuery):
|
||||
def __init__(self, inner: LanceVectorQuery, table: Optional["AsyncTable"] = None):
|
||||
"""
|
||||
Construct an AsyncVectorQuery
|
||||
|
||||
@@ -3039,7 +3271,7 @@ class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
a vector query. Or you can use
|
||||
[AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search]
|
||||
"""
|
||||
super().__init__(inner)
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
self._reranker = None
|
||||
self._query_string = None
|
||||
@@ -3093,10 +3325,13 @@ class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
|
||||
if isinstance(query, str):
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns}),
|
||||
self._table,
|
||||
)
|
||||
# FullTextQuery object
|
||||
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query}))
|
||||
return AsyncHybridQuery(
|
||||
self._inner.nearest_to_text({"query": query}), self._table
|
||||
)
|
||||
|
||||
async def to_batches(
|
||||
self,
|
||||
@@ -3123,8 +3358,8 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
in the `rerank` method to convert the scores to ranks and then normalize them.
|
||||
"""
|
||||
|
||||
def __init__(self, inner: LanceHybridQuery):
|
||||
super().__init__(inner)
|
||||
def __init__(self, inner: LanceHybridQuery, table: Optional["AsyncTable"] = None):
|
||||
super().__init__(inner, table)
|
||||
self._inner = inner
|
||||
self._norm = "score"
|
||||
self._reranker = RRFReranker()
|
||||
@@ -3165,8 +3400,8 @@ class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
|
||||
max_batch_length: Optional[int] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
) -> AsyncRecordBatchReader:
|
||||
fts_query = AsyncFTSQuery(self._inner.to_fts_query())
|
||||
vec_query = AsyncVectorQuery(self._inner.to_vector_query())
|
||||
fts_query = AsyncFTSQuery(self._inner.to_fts_query(), self._table)
|
||||
vec_query = AsyncVectorQuery(self._inner.to_vector_query(), self._table)
|
||||
|
||||
# save the row ID choice that was made on the query builder and force it
|
||||
# to actually fetch the row ids because we need this for reranking
|
||||
@@ -3266,8 +3501,15 @@ class AsyncTakeQuery(AsyncQueryBase):
|
||||
Builder for parameterizing and executing take queries.
|
||||
"""
|
||||
|
||||
def __init__(self, inner: LanceTakeQuery):
|
||||
super().__init__(inner)
|
||||
def __init__(self, inner: LanceTakeQuery, table: Optional["AsyncTable"] = None):
|
||||
super().__init__(inner, table)
|
||||
|
||||
async def _plain_scan_to_pandas(
|
||||
self,
|
||||
blob_mode: BlobMode,
|
||||
**kwargs,
|
||||
) -> Optional["pd.DataFrame"]:
|
||||
return None
|
||||
|
||||
|
||||
class BaseQueryBuilder(object):
|
||||
@@ -3400,6 +3642,8 @@ class BaseQueryBuilder(object):
|
||||
self,
|
||||
flatten: Optional[Union[int, bool]] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
blob_mode: BlobMode = "lazy",
|
||||
**kwargs,
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
@@ -3433,11 +3677,15 @@ class BaseQueryBuilder(object):
|
||||
The maximum time to wait for the query to complete.
|
||||
If not specified, no timeout is applied. If the query does not
|
||||
complete within the specified time, an error will be raised.
|
||||
blob_mode: str, default "lazy"
|
||||
Controls how blob columns are returned for plain scan queries.
|
||||
**kwargs
|
||||
Forwarded to pyarrow.Table.to_pandas after query execution and
|
||||
optional flattening.
|
||||
"""
|
||||
return LOOP.run(self._inner.to_pandas(flatten, timeout, **kwargs))
|
||||
return LOOP.run(
|
||||
self._inner.to_pandas(flatten, timeout, blob_mode=blob_mode, **kwargs)
|
||||
)
|
||||
|
||||
def to_polars(
|
||||
self,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import sys
|
||||
@@ -17,7 +18,7 @@ else:
|
||||
|
||||
# Remove this import to fix circular dependency
|
||||
# from lancedb import connect_async
|
||||
from lancedb.remote import ClientConfig
|
||||
from lancedb.remote import ClientConfig, RetryConfig, TimeoutConfig, TlsConfig
|
||||
import pyarrow as pa
|
||||
|
||||
from ..common import DATA
|
||||
@@ -36,6 +37,64 @@ from ..table import Table
|
||||
from ..util import validate_table_name
|
||||
|
||||
|
||||
def _duration_seconds(value: Optional[timedelta]) -> Optional[float]:
|
||||
return value.total_seconds() if value is not None else None
|
||||
|
||||
|
||||
def _timeout_config_to_dict(
|
||||
config: Optional[TimeoutConfig],
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if config is None:
|
||||
return None
|
||||
return {
|
||||
"timeout": _duration_seconds(config.timeout),
|
||||
"connect_timeout": _duration_seconds(config.connect_timeout),
|
||||
"read_timeout": _duration_seconds(config.read_timeout),
|
||||
"pool_idle_timeout": _duration_seconds(config.pool_idle_timeout),
|
||||
}
|
||||
|
||||
|
||||
def _retry_config_to_dict(config: RetryConfig) -> dict[str, Any]:
|
||||
return {
|
||||
"retries": config.retries,
|
||||
"connect_retries": config.connect_retries,
|
||||
"read_retries": config.read_retries,
|
||||
"backoff_factor": config.backoff_factor,
|
||||
"backoff_jitter": config.backoff_jitter,
|
||||
"statuses": config.statuses,
|
||||
}
|
||||
|
||||
|
||||
def _tls_config_to_dict(config: Optional[TlsConfig]) -> Optional[dict[str, Any]]:
|
||||
if config is None:
|
||||
return None
|
||||
return {
|
||||
"cert_file": config.cert_file,
|
||||
"key_file": config.key_file,
|
||||
"ssl_ca_cert": config.ssl_ca_cert,
|
||||
"assert_hostname": config.assert_hostname,
|
||||
}
|
||||
|
||||
|
||||
def _client_config_to_dict(config: ClientConfig) -> dict[str, Any]:
|
||||
if config.header_provider is not None:
|
||||
raise ValueError(
|
||||
"Cannot serialize a remote connection with a header_provider. "
|
||||
"Use static api_key/extra_headers or provide a worker-side "
|
||||
"connection factory instead."
|
||||
)
|
||||
return {
|
||||
"user_agent": config.user_agent,
|
||||
"retry_config": _retry_config_to_dict(config.retry_config),
|
||||
"timeout_config": _timeout_config_to_dict(config.timeout_config),
|
||||
"extra_headers": config.extra_headers,
|
||||
"id_delimiter": config.id_delimiter,
|
||||
"tls_config": _tls_config_to_dict(config.tls_config),
|
||||
"header_provider": None,
|
||||
"user_id": config.user_id,
|
||||
}
|
||||
|
||||
|
||||
class RemoteDBConnection(DBConnection):
|
||||
"""A connection to a remote LanceDB database."""
|
||||
|
||||
@@ -89,6 +148,11 @@ class RemoteDBConnection(DBConnection):
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self.db_url = db_url
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.host_override = host_override
|
||||
self.storage_options = storage_options
|
||||
self.db_name = parsed.netloc
|
||||
|
||||
self.client_config = client_config
|
||||
@@ -111,6 +175,20 @@ class RemoteDBConnection(DBConnection):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
|
||||
@override
|
||||
def serialize(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"connection_type": "remote",
|
||||
"db_url": self.db_url,
|
||||
"api_key": self.api_key,
|
||||
"region": self.region,
|
||||
"host_override": self.host_override,
|
||||
"client_config": _client_config_to_dict(self.client_config),
|
||||
"storage_options": self.storage_options,
|
||||
}
|
||||
)
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
@@ -331,7 +409,12 @@ class RemoteDBConnection(DBConnection):
|
||||
)
|
||||
|
||||
table = LOOP.run(self._conn.open_table(name, namespace_path=namespace_path))
|
||||
return RemoteTable(table, self.db_name)
|
||||
return RemoteTable(
|
||||
table,
|
||||
self.db_name,
|
||||
connection_state=self.serialize,
|
||||
namespace_path=namespace_path,
|
||||
)
|
||||
|
||||
def clone_table(
|
||||
self,
|
||||
@@ -380,7 +463,12 @@ class RemoteDBConnection(DBConnection):
|
||||
is_shallow=is_shallow,
|
||||
)
|
||||
)
|
||||
return RemoteTable(table, self.db_name)
|
||||
return RemoteTable(
|
||||
table,
|
||||
self.db_name,
|
||||
connection_state=self.serialize,
|
||||
namespace_path=target_namespace_path,
|
||||
)
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
@@ -525,7 +613,12 @@ class RemoteDBConnection(DBConnection):
|
||||
fill_value=fill_value,
|
||||
)
|
||||
)
|
||||
return RemoteTable(table, self.db_name)
|
||||
return RemoteTable(
|
||||
table,
|
||||
self.db_name,
|
||||
connection_state=self.serialize,
|
||||
namespace_path=namespace_path,
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, namespace_path: Optional[List[str]] = None):
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import timedelta
|
||||
import deprecation
|
||||
import logging
|
||||
from functools import cached_property
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -63,14 +64,80 @@ class RemoteTable(Table):
|
||||
self,
|
||||
table: AsyncTable,
|
||||
db_name: str,
|
||||
*,
|
||||
connection_state: Optional[Union[str, Callable[[], str]]] = None,
|
||||
namespace_path: Optional[List[str]] = None,
|
||||
):
|
||||
self._table = table
|
||||
self._table_handle = table
|
||||
self._name = table.name
|
||||
self.db_name = db_name
|
||||
self._connection_state = connection_state
|
||||
self._namespace_path = list(namespace_path or [])
|
||||
self._checkout_version: Optional[int] = None
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _serialized_connection_state(self) -> str:
|
||||
if self._connection_state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot reopen this remote table because it does not carry "
|
||||
"serialized connection state"
|
||||
)
|
||||
if callable(self._connection_state):
|
||||
self._connection_state = self._connection_state()
|
||||
return self._connection_state
|
||||
|
||||
@property
|
||||
def _table(self) -> AsyncTable:
|
||||
self._ensure_open()
|
||||
assert self._table_handle is not None
|
||||
return self._table_handle
|
||||
|
||||
@_table.setter
|
||||
def _table(self, table: AsyncTable) -> None:
|
||||
self._table_handle = table
|
||||
self._name = table.name
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
pid = os.getpid()
|
||||
if self._table_handle is not None and self._pid == pid:
|
||||
return
|
||||
|
||||
# Pickle clears the handle; fork inherits a handle created in the
|
||||
# parent process. In both cases reopen before touching the Rust client.
|
||||
from lancedb import deserialize_conn
|
||||
|
||||
db = deserialize_conn(self._serialized_connection_state(), for_worker=True)
|
||||
table = db.open_table(self._name, namespace_path=self._namespace_path)
|
||||
if self._checkout_version is not None:
|
||||
table.checkout(self._checkout_version)
|
||||
|
||||
self._table_handle = table._table
|
||||
self.db_name = table.db_name
|
||||
self._pid = pid
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {
|
||||
"connection_state": self._serialized_connection_state(),
|
||||
"db_name": self.db_name,
|
||||
"name": self.name,
|
||||
"namespace_path": self._namespace_path,
|
||||
"checkout_version": self._checkout_version,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self._table_handle = None
|
||||
self._name = state["name"]
|
||||
self.db_name = state["db_name"]
|
||||
self._connection_state = state["connection_state"]
|
||||
self._namespace_path = state["namespace_path"]
|
||||
self._checkout_version = state["checkout_version"]
|
||||
self._pid = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the table"""
|
||||
return self._table.name
|
||||
return self._name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self.db_name}.{self.name})"
|
||||
@@ -120,13 +187,19 @@ class RemoteTable(Table):
|
||||
raise NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def checkout(self, version: Union[int, str]):
|
||||
return LOOP.run(self._table.checkout(version))
|
||||
result = LOOP.run(self._table.checkout(version))
|
||||
self._checkout_version = self.version
|
||||
return result
|
||||
|
||||
def checkout_latest(self):
|
||||
return LOOP.run(self._table.checkout_latest())
|
||||
result = LOOP.run(self._table.checkout_latest())
|
||||
self._checkout_version = None
|
||||
return result
|
||||
|
||||
def restore(self, version: Optional[Union[int, str]] = None):
|
||||
return LOOP.run(self._table.restore(version))
|
||||
result = LOOP.run(self._table.restore(version))
|
||||
self._checkout_version = None
|
||||
return result
|
||||
|
||||
def list_indices(self) -> Iterable[IndexConfig]:
|
||||
"""List all the indices on the table"""
|
||||
|
||||
@@ -89,6 +89,18 @@ from .index import lang_mapping
|
||||
|
||||
BlobMode = Literal["lazy", "bytes", "descriptions"]
|
||||
|
||||
|
||||
def _field_is_blob(field: pa.Field) -> bool:
|
||||
metadata = field.metadata or {}
|
||||
return metadata.get(b"lance-encoding:blob") == b"true" or (
|
||||
metadata.get("lance-encoding:blob") == "true"
|
||||
)
|
||||
|
||||
|
||||
def _schema_has_blob_field(schema: pa.Schema) -> bool:
|
||||
return any(_field_is_blob(field) for field in schema)
|
||||
|
||||
|
||||
_MODEL_BACKED_TOKENIZER_PREFIXES = ("jieba", "lindera")
|
||||
_MODEL_BACKED_TOKENIZER_ERRORS = (
|
||||
"unknown base tokenizer",
|
||||
@@ -2270,9 +2282,13 @@ class LanceTable(Table):
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
if blob_mode == "lazy" and (
|
||||
self._namespace_client is not None
|
||||
or get_uri_scheme(self._dataset_path) == "memory"
|
||||
if blob_mode == "descriptions" or not _schema_has_blob_field(self.schema):
|
||||
return self.to_arrow().to_pandas(**kwargs)
|
||||
|
||||
if (
|
||||
blob_mode == "lazy"
|
||||
and self._namespace_client is None
|
||||
and get_uri_scheme(self._dataset_path) == "memory"
|
||||
):
|
||||
return self.to_arrow().to_pandas(**kwargs)
|
||||
|
||||
@@ -4280,7 +4296,7 @@ class AsyncTable:
|
||||
can be executed with methods like [to_arrow][lancedb.query.AsyncQuery.to_arrow],
|
||||
[to_pandas][lancedb.query.AsyncQuery.to_pandas] and more.
|
||||
"""
|
||||
return AsyncQuery(self._inner.query())
|
||||
return AsyncQuery(self._inner.query(), self)
|
||||
|
||||
async def _to_lance(self, **kwargs) -> lance.LanceDataset:
|
||||
try:
|
||||
@@ -4312,7 +4328,12 @@ class AsyncTable:
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
if blob_mode == "lazy":
|
||||
if blob_mode == "descriptions" or not _schema_has_blob_field(
|
||||
await self.schema()
|
||||
):
|
||||
return (await self.to_arrow()).to_pandas(**kwargs)
|
||||
|
||||
if blob_mode == "lazy" and get_uri_scheme(await self.uri()) == "memory":
|
||||
return (await self.to_arrow()).to_pandas(**kwargs)
|
||||
return (await self._to_lance()).to_pandas(blob_mode=blob_mode, **kwargs)
|
||||
|
||||
@@ -5349,7 +5370,7 @@ class AsyncTable:
|
||||
pa.RecordBatch
|
||||
A record batch containing the rows at the given offsets.
|
||||
"""
|
||||
return AsyncTakeQuery(self._inner.take_offsets(offsets))
|
||||
return AsyncTakeQuery(self._inner.take_offsets(offsets), self)
|
||||
|
||||
def take_row_ids(self, row_ids: list[int]) -> AsyncTakeQuery:
|
||||
"""
|
||||
@@ -5378,7 +5399,7 @@ class AsyncTable:
|
||||
AsyncTakeQuery
|
||||
A query object that can be executed to get the rows.
|
||||
"""
|
||||
return AsyncTakeQuery(self._inner.take_row_ids(row_ids))
|
||||
return AsyncTakeQuery(self._inner.take_row_ids(row_ids), self)
|
||||
|
||||
@property
|
||||
def tags(self) -> AsyncTags:
|
||||
|
||||
@@ -76,6 +76,35 @@ class TestNamespaceConnection:
|
||||
assert len(result) == 0
|
||||
assert list(result.columns) == ["id", "vector", "text"]
|
||||
|
||||
def test_table_to_pandas_blob_lazy_through_namespace(self):
|
||||
"""Namespace-backed tables should use Lance blob-aware pandas conversion."""
|
||||
pytest.importorskip("lance")
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
db.create_namespace(["test_ns"])
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob",
|
||||
pa.large_binary(),
|
||||
metadata={"lance-encoding:blob": "true"},
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
table = db.create_table("blob_table", data, namespace_path=["test_ns"])
|
||||
df = table.to_pandas(blob_mode="lazy").sort_values("id")
|
||||
|
||||
blob = df["blob"].iloc[0]
|
||||
assert hasattr(blob, "readall")
|
||||
assert blob.readall() == b"hello"
|
||||
|
||||
def test_open_table_through_namespace(self):
|
||||
"""Test opening an existing table through namespace."""
|
||||
db = lancedb.connect_namespace("dir", {"root": self.temp_dir})
|
||||
|
||||
@@ -39,6 +39,35 @@ from utils import exception_output
|
||||
from importlib.util import find_spec
|
||||
|
||||
|
||||
def _blob_query_data():
|
||||
return pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2, 3, 4], pa.int64()),
|
||||
"tag": pa.array(["drop", "keep", "keep", "keep"], pa.utf8()),
|
||||
"vector": pa.array(
|
||||
[[1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]],
|
||||
type=pa.list_(pa.float32(), list_size=2),
|
||||
),
|
||||
"blob": pa.array([b"one", b"two", b"three", b"four"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field("tag", pa.utf8()),
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_lazy_blob(value, expected: bytes):
|
||||
assert hasattr(value, "readall")
|
||||
assert value.readall() == expected
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def table(tmpdir_factory) -> lancedb.table.Table:
|
||||
tmp_path = str(tmpdir_factory.mktemp("data"))
|
||||
@@ -181,6 +210,97 @@ async def test_query_to_pandas_kwargs(table, table_async):
|
||||
assert async_df["id"].tolist() == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
|
||||
def test_plain_scan_query_to_pandas_blob_modes(tmp_db, blob_mode):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
f"test_query_to_pandas_blob_{blob_mode}", _blob_query_data()
|
||||
)
|
||||
|
||||
df = (
|
||||
table.search()
|
||||
.select(["id", "blob"])
|
||||
.where("id = 1")
|
||||
.to_pandas(blob_mode=blob_mode)
|
||||
)
|
||||
|
||||
assert df["id"].tolist() == [1]
|
||||
if blob_mode == "lazy":
|
||||
_assert_lazy_blob(df["blob"].iloc[0], b"one")
|
||||
elif blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"one"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_plain_scan_query_to_pandas_blob_projection(tmp_db):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
"test_query_to_pandas_blob_projection", _blob_query_data()
|
||||
)
|
||||
|
||||
df = (
|
||||
table.search()
|
||||
.where("id >= 2")
|
||||
.select({"id_alias": "id", "payload": "blob", "double_id": "id * 2"})
|
||||
.limit(2)
|
||||
.offset(1)
|
||||
.to_pandas(blob_mode="bytes")
|
||||
)
|
||||
|
||||
assert df["id_alias"].tolist() == [3, 4]
|
||||
assert df["payload"].tolist() == [b"three", b"four"]
|
||||
assert df["double_id"].tolist() == [6, 8]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_plain_scan_query_to_pandas_blob_projection(tmp_db_async):
|
||||
pytest.importorskip("lance")
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_query_to_pandas_blob_projection", _blob_query_data()
|
||||
)
|
||||
|
||||
lazy_df = await (
|
||||
table.query().where("id = 1").select(["id", "blob"]).to_pandas(blob_mode="lazy")
|
||||
)
|
||||
assert lazy_df["id"].tolist() == [1]
|
||||
_assert_lazy_blob(lazy_df["blob"].iloc[0], b"one")
|
||||
|
||||
bytes_df = await (
|
||||
table.query()
|
||||
.where("id >= 2")
|
||||
.select({"id_alias": "id", "payload": "blob", "double_id": "id * 2"})
|
||||
.limit(2)
|
||||
.offset(1)
|
||||
.to_pandas(blob_mode="bytes")
|
||||
)
|
||||
assert bytes_df["id_alias"].tolist() == [3, 4]
|
||||
assert bytes_df["payload"].tolist() == [b"three", b"four"]
|
||||
assert bytes_df["double_id"].tolist() == [6, 8]
|
||||
|
||||
desc_df = await (
|
||||
table.query()
|
||||
.where("id = 1")
|
||||
.select(["blob"])
|
||||
.to_pandas(blob_mode="descriptions")
|
||||
)
|
||||
first = desc_df["blob"].iloc[0]
|
||||
assert first != b"one"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_vector_query_to_pandas_blob_mode_requires_native_path(tmp_db):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table("test_vector_query_blob_mode", _blob_query_data())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Lance native pandas conversion"):
|
||||
table.search([1.0, 0.0]).select(["blob", "vector"]).limit(1).to_pandas(
|
||||
blob_mode="lazy"
|
||||
)
|
||||
|
||||
|
||||
def test_order_by_plain_query(mem_db):
|
||||
table = mem_db.create_table(
|
||||
"test_order_by",
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -171,6 +172,155 @@ def test_table_len_sync():
|
||||
assert len(table) == 1
|
||||
|
||||
|
||||
def test_remote_connection_serializes():
|
||||
def handler(request):
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
serialized = json.loads(db.serialize())
|
||||
assert isinstance(serialized["client_config"], dict)
|
||||
restored = lancedb.deserialize_conn(db.serialize())
|
||||
assert restored.table_names() == []
|
||||
|
||||
|
||||
def test_remote_table_is_picklable():
|
||||
def handler(request):
|
||||
request.close_connection = True
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = json.dumps(
|
||||
{
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "id", "type": {"type": "int64"}, "nullable": False}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"3")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.open_table("test")
|
||||
restored = pickle.loads(pickle.dumps(table))
|
||||
assert restored.count_rows() == 3
|
||||
|
||||
|
||||
def test_remote_table_open_does_not_require_picklable_client_config():
|
||||
from lancedb.remote import HeaderProvider
|
||||
|
||||
class LocalHeaderProvider(HeaderProvider):
|
||||
def get_headers(self):
|
||||
return {"X-Test-Header": "present"}
|
||||
|
||||
def handler(request):
|
||||
request.close_connection = True
|
||||
assert request.headers.get("X-Test-Header") == "present"
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"3")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 0), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
"header_provider": LocalHeaderProvider(),
|
||||
},
|
||||
)
|
||||
table = db.open_table("test")
|
||||
assert table.count_rows() == 3
|
||||
with pytest.raises(ValueError, match="header_provider"):
|
||||
pickle.dumps(table)
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
def test_remote_permutation_is_picklable():
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
rows = list(range(10))
|
||||
|
||||
def handler(request):
|
||||
request.close_connection = True
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = json.dumps(
|
||||
{
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "a", "type": {"type": "int64"}, "nullable": False}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(str(len(rows)).encode())
|
||||
elif request.path == "/v1/table/test/query/":
|
||||
content_len = int(request.headers.get("Content-Length"))
|
||||
body = json.loads(request.rfile.read(content_len))
|
||||
if "filter" in body:
|
||||
match = re.search(r"_rowoffset in \((.*?)\)", body["filter"])
|
||||
offsets = [int(offset.strip()) for offset in match.group(1).split(",")]
|
||||
else:
|
||||
offsets = rows
|
||||
table = pa.table({"a": [rows[offset] for offset in offsets]})
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||
request.end_headers()
|
||||
with pa.ipc.new_file(request.wfile, schema=table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
permutation = Permutation.identity(db.open_table("test"))
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
assert restored.__getitems__([0, 2, 4]) == [{"a": 0}, {"a": 2}, {"a": 4}]
|
||||
|
||||
|
||||
def test_create_table_exist_ok():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create/?mode=exist_ok":
|
||||
@@ -1400,6 +1550,10 @@ def _remote_fork_child(port: int, queue) -> None:
|
||||
queue.put(db.table_names())
|
||||
|
||||
|
||||
def _remote_table_fork_child(table, queue) -> None:
|
||||
queue.put(table.count_rows())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
@@ -1462,3 +1616,65 @@ def test_remote_connection_after_fork():
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_inherited_remote_table_reopens_after_fork():
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"7")
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
|
||||
port = server.server_address[1]
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.start()
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
table = db.open_table("test")
|
||||
assert table.count_rows() == 7
|
||||
|
||||
ctx = mp.get_context("fork")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_remote_table_fork_child, args=(table, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=15)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Remote table hung after fork")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no result"
|
||||
assert queue.get() == 7
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
@@ -26,6 +26,28 @@ from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _blob_test_data():
|
||||
return pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_lazy_blob(value, expected: bytes):
|
||||
assert hasattr(value, "readall")
|
||||
assert value.readall() == expected
|
||||
|
||||
|
||||
def test_basic(mem_db: DBConnection):
|
||||
data = [
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
@@ -57,27 +79,22 @@ def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
|
||||
pd.testing.assert_frame_equal(table.to_pandas(), expected)
|
||||
|
||||
|
||||
def test_table_to_pandas_blob_bytes(tmp_db: DBConnection):
|
||||
@pytest.mark.parametrize("blob_mode", ["lazy", "bytes", "descriptions"])
|
||||
def test_table_to_pandas_blob_modes(tmp_db: DBConnection, blob_mode):
|
||||
pytest.importorskip("lance")
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
table = tmp_db.create_table("test_to_pandas_blob_bytes", data=data)
|
||||
table = tmp_db.create_table(f"test_to_pandas_blob_{blob_mode}", _blob_test_data())
|
||||
|
||||
df = table.to_pandas(blob_mode="bytes")
|
||||
df = table.to_pandas(blob_mode=blob_mode)
|
||||
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
if blob_mode == "lazy":
|
||||
_assert_lazy_blob(df["blob"].iloc[0], b"hello")
|
||||
_assert_lazy_blob(df["blob"].iloc[1], b"world")
|
||||
elif blob_mode == "bytes":
|
||||
assert df["blob"].tolist() == [b"hello", b"world"]
|
||||
else:
|
||||
first = df["blob"].iloc[0]
|
||||
assert first != b"hello"
|
||||
assert not hasattr(first, "readall")
|
||||
|
||||
|
||||
def test_table_to_pandas_kwargs(tmp_db: DBConnection):
|
||||
@@ -93,22 +110,8 @@ def test_table_to_pandas_kwargs(tmp_db: DBConnection):
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_table_to_pandas_blob_bytes(tmp_db_async: AsyncConnection):
|
||||
pytest.importorskip("lance")
|
||||
data = pa.table(
|
||||
{
|
||||
"id": pa.array([1, 2], pa.int64()),
|
||||
"blob": pa.array([b"hello", b"world"], pa.large_binary()),
|
||||
},
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field(
|
||||
"blob", pa.large_binary(), metadata={"lance-encoding:blob": "true"}
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
table = await tmp_db_async.create_table(
|
||||
"test_async_to_pandas_blob_bytes", data=data
|
||||
"test_async_to_pandas_blob_bytes", data=_blob_test_data()
|
||||
)
|
||||
|
||||
df = await table.to_pandas(blob_mode="bytes")
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
@@ -15,6 +20,107 @@ from lancedb.util import tbl_to_tensor
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
REMOTE_ROWS = list(range(100))
|
||||
|
||||
|
||||
def _make_mock_http_handler(handler):
|
||||
class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
handler(self)
|
||||
|
||||
def do_POST(self):
|
||||
handler(self)
|
||||
|
||||
return MockLanceDBHandler
|
||||
|
||||
|
||||
def _remote_schema_payload():
|
||||
return {
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "a", "type": {"type": "int64"}, "nullable": False},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _offsets_from_filter(filter_sql: str | None) -> list[int]:
|
||||
if filter_sql is None:
|
||||
return REMOTE_ROWS
|
||||
match = re.search(r"_rowoffset in \((.*?)\)", filter_sql)
|
||||
if match is None:
|
||||
return REMOTE_ROWS
|
||||
raw_offsets = match.group(1).strip()
|
||||
if raw_offsets == "":
|
||||
return []
|
||||
return [int(offset.strip()) for offset in raw_offsets.split(",")]
|
||||
|
||||
|
||||
def _remote_dataset_handler(request):
|
||||
request.close_connection = True
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(json.dumps(_remote_schema_payload()).encode())
|
||||
elif request.path == "/v1/table/test/count_rows/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(str(len(REMOTE_ROWS)).encode())
|
||||
elif request.path == "/v1/table/test/query/":
|
||||
content_len = int(request.headers.get("Content-Length"))
|
||||
body = json.loads(request.rfile.read(content_len))
|
||||
offsets = _offsets_from_filter(body.get("filter"))
|
||||
requested_columns = body.get("columns") or ["a"]
|
||||
if isinstance(requested_columns, dict):
|
||||
requested_columns = list(requested_columns)
|
||||
|
||||
data = {}
|
||||
for column in requested_columns:
|
||||
if column == "a":
|
||||
data[column] = [REMOTE_ROWS[offset] for offset in offsets]
|
||||
elif column == "_rowoffset":
|
||||
data[column] = offsets
|
||||
elif column == "_rowid":
|
||||
data[column] = offsets
|
||||
|
||||
table = pa.table(data)
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||
request.end_headers()
|
||||
with pa.ipc.new_file(request.wfile, schema=table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _remote_dataset_table():
|
||||
with http.server.ThreadingHTTPServer(
|
||||
("localhost", 0), _make_mock_http_handler(_remote_dataset_handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
yield db.open_table("test")
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
def _open_native_table(uri: str, table_name: str):
|
||||
"""Top-level connection factory used by the explicit-factory pickle test.
|
||||
|
||||
@@ -107,6 +213,39 @@ def test_permutation_dataloader_multiprocessing(tmp_db):
|
||||
assert seen == 1000
|
||||
|
||||
|
||||
def test_remote_table_dataloader_multiprocessing():
|
||||
with _remote_dataset_table() as table:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
table,
|
||||
collate_fn=tbl_to_tensor,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="spawn",
|
||||
)
|
||||
seen = 0
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
seen += batch.size(1)
|
||||
assert seen == len(REMOTE_ROWS)
|
||||
|
||||
|
||||
def test_remote_permutation_dataloader_multiprocessing():
|
||||
with _remote_dataset_table() as table:
|
||||
permutation = Permutation.identity(table)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="spawn",
|
||||
)
|
||||
seen = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
seen += batch["a"].size(0)
|
||||
assert seen == len(REMOTE_ROWS)
|
||||
|
||||
|
||||
def test_permutation_pickle_with_connection_factory(tmp_path):
|
||||
"""When the user provides a connection_factory, pickling should round-trip
|
||||
through that factory rather than introspecting the connection URI. Useful
|
||||
@@ -171,6 +310,35 @@ def _multiworker_dataloader_target(db_uri: str, result_queue):
|
||||
result_queue.put(count)
|
||||
|
||||
|
||||
def _remote_multiworker_dataloader_target(port: int, result_queue):
|
||||
import lancedb
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
table = db.open_table("test")
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="fork",
|
||||
)
|
||||
count = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
count += 1
|
||||
result_queue.put(count)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
@@ -208,3 +376,46 @@ def test_permutation_dataloader_fork_workers(tmp_path):
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no batches"
|
||||
assert queue.get() == 100
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_remote_permutation_dataloader_fork_workers():
|
||||
with http.server.ThreadingHTTPServer(
|
||||
("localhost", 0), _make_mock_http_handler(_remote_dataset_handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
try:
|
||||
ctx = mp.get_context("spawn")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(
|
||||
target=_remote_multiworker_dataloader_target,
|
||||
args=(port, queue),
|
||||
)
|
||||
proc.start()
|
||||
proc.join(timeout=30)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail(
|
||||
"Remote permutation hung when iterated in a fork-based "
|
||||
"DataLoader worker"
|
||||
)
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no batches"
|
||||
assert queue.get() == 10
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.30.0-beta.1"
|
||||
version = "0.30.1-beta.0"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
Reference in New Issue
Block a user