mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-28 11:30:39 +00:00
Compare commits
15 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
972c682857 | ||
|
|
4f8ee82730 | ||
|
|
131024839f | ||
|
|
3c7ddf4d0c | ||
|
|
461176f9f2 | ||
|
|
3b8996bb69 | ||
|
|
3755064e93 | ||
|
|
8773b865a9 | ||
|
|
1ee29675b3 | ||
|
|
9be28448f5 | ||
|
|
357197bacc | ||
|
|
ad51e2dd1f | ||
|
|
e9e904783c | ||
|
|
8500b16eca | ||
|
|
57e7282342 |
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.24.0"
|
current_version = "0.24.1"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ name: build-linux-wheel
|
|||||||
description: "Build a manylinux wheel for lance"
|
description: "Build a manylinux wheel for lance"
|
||||||
inputs:
|
inputs:
|
||||||
python-minor-version:
|
python-minor-version:
|
||||||
description: "8, 9, 10, 11, 12"
|
description: "10, 11, 12, 13"
|
||||||
required: true
|
required: true
|
||||||
args:
|
args:
|
||||||
description: "--release"
|
description: "--release"
|
||||||
|
|||||||
2
.github/workflows/build_mac_wheel/action.yml
vendored
2
.github/workflows/build_mac_wheel/action.yml
vendored
@@ -3,7 +3,7 @@ name: build_wheel
|
|||||||
description: "Build a lance wheel"
|
description: "Build a lance wheel"
|
||||||
inputs:
|
inputs:
|
||||||
python-minor-version:
|
python-minor-version:
|
||||||
description: "8, 9, 10, 11"
|
description: "10, 11, 12, 13"
|
||||||
required: true
|
required: true
|
||||||
args:
|
args:
|
||||||
description: "--release"
|
description: "--release"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ name: build_wheel
|
|||||||
description: "Build a lance wheel"
|
description: "Build a lance wheel"
|
||||||
inputs:
|
inputs:
|
||||||
python-minor-version:
|
python-minor-version:
|
||||||
description: "8, 9, 10, 11"
|
description: "10, 11, 12, 13, 14"
|
||||||
required: true
|
required: true
|
||||||
args:
|
args:
|
||||||
description: "--release"
|
description: "--release"
|
||||||
|
|||||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
sudo apt install -y protobuf-compiler libssl-dev
|
sudo apt install -y protobuf-compiler libssl-dev
|
||||||
rustup update && rustup default
|
rustup update && rustup default
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: "pip"
|
cache: "pip"
|
||||||
|
|||||||
18
.github/workflows/pypi-publish.yml
vendored
18
.github/workflows/pypi-publish.yml
vendored
@@ -44,12 +44,12 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
python-version: "3.10"
|
||||||
- uses: ./.github/workflows/build_linux_wheel
|
- uses: ./.github/workflows/build_linux_wheel
|
||||||
with:
|
with:
|
||||||
python-minor-version: 8
|
python-minor-version: 10
|
||||||
args: "--release --strip ${{ matrix.config.extra_args }}"
|
args: "--release --strip ${{ matrix.config.extra_args }}"
|
||||||
arm-build: ${{ matrix.config.platform == 'aarch64' }}
|
arm-build: ${{ matrix.config.platform == 'aarch64' }}
|
||||||
manylinux: ${{ matrix.config.manylinux }}
|
manylinux: ${{ matrix.config.manylinux }}
|
||||||
@@ -74,12 +74,12 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: 3.12
|
python-version: "3.13"
|
||||||
- uses: ./.github/workflows/build_mac_wheel
|
- uses: ./.github/workflows/build_mac_wheel
|
||||||
with:
|
with:
|
||||||
python-minor-version: 8
|
python-minor-version: 10
|
||||||
args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels"
|
args: "--release --strip --target ${{ matrix.config.target }} --features fp16kernels"
|
||||||
- uses: ./.github/workflows/upload_wheel
|
- uses: ./.github/workflows/upload_wheel
|
||||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
if: startsWith(github.ref, 'refs/tags/python-v')
|
||||||
@@ -95,12 +95,12 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: 3.12
|
python-version: "3.13"
|
||||||
- uses: ./.github/workflows/build_windows_wheel
|
- uses: ./.github/workflows/build_windows_wheel
|
||||||
with:
|
with:
|
||||||
python-minor-version: 8
|
python-minor-version: 10
|
||||||
args: "--release --strip"
|
args: "--release --strip"
|
||||||
vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
|
vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
|
||||||
- uses: ./.github/workflows/upload_wheel
|
- uses: ./.github/workflows/upload_wheel
|
||||||
|
|||||||
28
.github/workflows/python.yml
vendored
28
.github/workflows/python.yml
vendored
@@ -36,9 +36,9 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.13"
|
||||||
- name: Install ruff
|
- name: Install ruff
|
||||||
run: |
|
run: |
|
||||||
pip install ruff==0.9.9
|
pip install ruff==0.9.9
|
||||||
@@ -61,9 +61,9 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.13"
|
||||||
- name: Install protobuf compiler
|
- name: Install protobuf compiler
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt update
|
||||||
@@ -90,9 +90,9 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.13"
|
||||||
cache: "pip"
|
cache: "pip"
|
||||||
- name: Install protobuf
|
- name: Install protobuf
|
||||||
run: |
|
run: |
|
||||||
@@ -110,7 +110,7 @@ jobs:
|
|||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-minor-version: ["9", "12"]
|
python-minor-version: ["10", "13"]
|
||||||
runs-on: "ubuntu-24.04"
|
runs-on: "ubuntu-24.04"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install -y protobuf-compiler
|
sudo apt install -y protobuf-compiler
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: 3.${{ matrix.python-minor-version }}
|
python-version: 3.${{ matrix.python-minor-version }}
|
||||||
- uses: ./.github/workflows/build_linux_wheel
|
- uses: ./.github/workflows/build_linux_wheel
|
||||||
@@ -156,9 +156,9 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.13"
|
||||||
- uses: ./.github/workflows/build_mac_wheel
|
- uses: ./.github/workflows/build_mac_wheel
|
||||||
with:
|
with:
|
||||||
args: --profile ci
|
args: --profile ci
|
||||||
@@ -185,9 +185,9 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.13"
|
||||||
- uses: ./.github/workflows/build_windows_wheel
|
- uses: ./.github/workflows/build_windows_wheel
|
||||||
with:
|
with:
|
||||||
args: --profile ci
|
args: --profile ci
|
||||||
@@ -212,9 +212,9 @@ jobs:
|
|||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install -y protobuf-compiler
|
sudo apt install -y protobuf-compiler
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: 3.9
|
python-version: "3.10"
|
||||||
- name: Install lancedb
|
- name: Install lancedb
|
||||||
run: |
|
run: |
|
||||||
pip install "pydantic<2"
|
pip install "pydantic<2"
|
||||||
|
|||||||
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -4971,7 +4971,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.24.0"
|
version = "0.24.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
@@ -5050,7 +5050,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
version = "0.24.0"
|
version = "0.24.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow-array",
|
"arrow-array",
|
||||||
"arrow-ipc",
|
"arrow-ipc",
|
||||||
@@ -5070,7 +5070,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.27.0"
|
version = "0.27.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ Follow the [Quickstart](https://lancedb.com/docs/quickstart/) doc to set up Lanc
|
|||||||
| Python SDK | https://lancedb.github.io/lancedb/python/python/ |
|
| Python SDK | https://lancedb.github.io/lancedb/python/python/ |
|
||||||
| Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
|
| Typescript SDK | https://lancedb.github.io/lancedb/js/globals/ |
|
||||||
| Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
|
| Rust SDK | https://docs.rs/lancedb/latest/lancedb/index.html |
|
||||||
| REST API | https://docs.lancedb.com/api-reference/introduction |
|
| REST API | https://docs.lancedb.com/api-reference/rest |
|
||||||
|
|
||||||
## **Join Us and Contribute**
|
## **Join Us and Contribute**
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
# VoyageAI Embeddings
|
||||||
|
|
||||||
|
Voyage AI provides cutting-edge embedding and rerankers.
|
||||||
|
|
||||||
|
|
||||||
|
Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
|
||||||
|
You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API.
|
||||||
|
|
||||||
|
Supported models are:
|
||||||
|
|
||||||
|
**Voyage-4 Series (Latest)**
|
||||||
|
|
||||||
|
- voyage-4 (1024 dims, general-purpose and multilingual retrieval, 320K batch tokens)
|
||||||
|
- voyage-4-lite (1024 dims, optimized for latency and cost, 1M batch tokens)
|
||||||
|
- voyage-4-large (1024 dims, best retrieval quality, 120K batch tokens)
|
||||||
|
|
||||||
|
**Voyage-3 Series**
|
||||||
|
|
||||||
|
- voyage-3
|
||||||
|
- voyage-3-lite
|
||||||
|
|
||||||
|
**Domain-Specific Models**
|
||||||
|
|
||||||
|
- voyage-finance-2
|
||||||
|
- voyage-multilingual-2
|
||||||
|
- voyage-law-2
|
||||||
|
- voyage-code-2
|
||||||
|
|
||||||
|
|
||||||
|
Supported parameters (to be passed in `create` method) are:
|
||||||
|
|
||||||
|
| Parameter | Type | Default Value | Description |
|
||||||
|
|---|---|--------|---------|
|
||||||
|
| `name` | `str` | `None` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-4, voyage-4-lite, voyage-4-large, voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 |
|
||||||
|
| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. |
|
||||||
|
| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. |
|
||||||
|
|
||||||
|
|
||||||
|
Usage Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import lancedb
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
|
voyageai = EmbeddingFunctionRegistry
|
||||||
|
.get_instance()
|
||||||
|
.get("voyageai")
|
||||||
|
.create(name="voyage-3")
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = voyageai.SourceField()
|
||||||
|
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||||
|
|
||||||
|
data = [ { "text": "hello world" },
|
||||||
|
{ "text": "goodbye world" }]
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(data)
|
||||||
|
```
|
||||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-core</artifactId>
|
<artifactId>lancedb-core</artifactId>
|
||||||
<version>0.24.0</version>
|
<version>0.24.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
<parent>
|
<parent>
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.24.0-final.0</version>
|
<version>0.24.1-final.0</version>
|
||||||
<relativePath>../pom.xml</relativePath>
|
<relativePath>../pom.xml</relativePath>
|
||||||
</parent>
|
</parent>
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
<groupId>com.lancedb</groupId>
|
<groupId>com.lancedb</groupId>
|
||||||
<artifactId>lancedb-parent</artifactId>
|
<artifactId>lancedb-parent</artifactId>
|
||||||
<version>0.24.0-final.0</version>
|
<version>0.24.1-final.0</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
<name>${project.artifactId}</name>
|
<name>${project.artifactId}</name>
|
||||||
<description>LanceDB Java SDK Parent POM</description>
|
<description>LanceDB Java SDK Parent POM</description>
|
||||||
@@ -28,7 +28,7 @@
|
|||||||
<properties>
|
<properties>
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<arrow.version>15.0.0</arrow.version>
|
<arrow.version>15.0.0</arrow.version>
|
||||||
<lance-core.version>1.0.0-rc.2</lance-core.version>
|
<lance-core.version>1.0.4</lance-core.version>
|
||||||
<spotless.skip>false</spotless.skip>
|
<spotless.skip>false</spotless.skip>
|
||||||
<spotless.version>2.30.0</spotless.version>
|
<spotless.version>2.30.0</spotless.version>
|
||||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
version = "0.24.0"
|
version = "0.24.1"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
description.workspace = true
|
description.workspace = true
|
||||||
repository.workspace = true
|
repository.workspace = true
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-arm64",
|
"name": "@lancedb/lancedb-darwin-arm64",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-x64",
|
"name": "@lancedb/lancedb-darwin-x64",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.darwin-x64.node",
|
"main": "lancedb.darwin-x64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-musl.node",
|
"main": "lancedb.linux-arm64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-musl.node",
|
"main": "lancedb.linux-x64-musl.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": [
|
"os": [
|
||||||
"win32"
|
"win32"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"os": ["win32"],
|
"os": ["win32"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.win32-x64-msvc.node",
|
"main": "lancedb.win32-x64-msvc.node",
|
||||||
|
|||||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "@lancedb/lancedb",
|
"name": "@lancedb/lancedb",
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
"ann"
|
"ann"
|
||||||
],
|
],
|
||||||
"private": false,
|
"private": false,
|
||||||
"version": "0.24.0",
|
"version": "0.24.1",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.27.1"
|
current_version = "0.28.0-beta.0"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ The Python package is a wrapper around the Rust library, `lancedb`. We use
|
|||||||
|
|
||||||
To set up your development environment, you will need to install the following:
|
To set up your development environment, you will need to install the following:
|
||||||
|
|
||||||
1. Python 3.9 or later
|
1. Python 3.10 or later
|
||||||
2. Cargo (Rust's package manager). Use [rustup](https://rustup.rs/) to install.
|
2. Cargo (Rust's package manager). Use [rustup](https://rustup.rs/) to install.
|
||||||
3. [protoc](https://grpc.io/docs/protoc-installation/) (Protocol Buffers compiler)
|
3. [protoc](https://grpc.io/docs/protoc-installation/) (Protocol Buffers compiler)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.27.1"
|
version = "0.28.0-beta.0"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "Python bindings for LanceDB"
|
description = "Python bindings for LanceDB"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
@@ -21,7 +21,7 @@ lance-core.workspace = true
|
|||||||
lance-namespace.workspace = true
|
lance-namespace.workspace = true
|
||||||
lance-io.workspace = true
|
lance-io.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39"] }
|
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py310"] }
|
||||||
pyo3-async-runtimes = { version = "0.25", features = [
|
pyo3-async-runtimes = { version = "0.25", features = [
|
||||||
"attributes",
|
"attributes",
|
||||||
"tokio-runtime",
|
"tokio-runtime",
|
||||||
@@ -34,7 +34,7 @@ tokio = { version = "1.40", features = ["sync"] }
|
|||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
pyo3-build-config = { version = "0.25", features = [
|
pyo3-build-config = { version = "0.25", features = [
|
||||||
"extension-module",
|
"extension-module",
|
||||||
"abi3-py39",
|
"abi3-py310",
|
||||||
] }
|
] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ description = "lancedb"
|
|||||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.10"
|
||||||
keywords = [
|
keywords = [
|
||||||
"data-format",
|
"data-format",
|
||||||
"data-science",
|
"data-science",
|
||||||
@@ -33,10 +33,10 @@ classifiers = [
|
|||||||
"Programming Language :: Python",
|
"Programming Language :: Python",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3 :: Only",
|
"Programming Language :: Python :: 3 :: Only",
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Programming Language :: Python :: 3.13",
|
||||||
"Topic :: Scientific/Engineering",
|
"Topic :: Scientific/Engineering",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -137,4 +137,4 @@ include = [
|
|||||||
"python/lancedb/_lancedb.pyi",
|
"python/lancedb/_lancedb.pyi",
|
||||||
]
|
]
|
||||||
exclude = ["python/tests/"]
|
exclude = ["python/tests/"]
|
||||||
pythonVersion = "3.12"
|
pythonVersion = "3.13"
|
||||||
|
|||||||
@@ -22,7 +22,12 @@ class BackgroundEventLoop:
|
|||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|
||||||
def run(self, future):
|
def run(self, future):
|
||||||
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
|
concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop)
|
||||||
|
try:
|
||||||
|
return concurrent_future.result()
|
||||||
|
except BaseException:
|
||||||
|
concurrent_future.cancel()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
LOOP = BackgroundEventLoop()
|
LOOP = BackgroundEventLoop()
|
||||||
|
|||||||
@@ -275,7 +275,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
|||||||
"""
|
"""
|
||||||
Convert image inputs to PIL Images.
|
Convert image inputs to PIL Images.
|
||||||
"""
|
"""
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
requests = attempt_import_or_raise("requests", "requests")
|
requests = attempt_import_or_raise("requests", "requests")
|
||||||
images = self.sanitize_input(images)
|
images = self.sanitize_input(images)
|
||||||
pil_images = []
|
pil_images = []
|
||||||
@@ -285,12 +285,12 @@ class ColPaliEmbeddings(EmbeddingFunction):
|
|||||||
if image.startswith(("http://", "https://")):
|
if image.startswith(("http://", "https://")):
|
||||||
response = requests.get(image, timeout=10)
|
response = requests.get(image, timeout=10)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
|
pil_images.append(PIL_Image.open(io.BytesIO(response.content)))
|
||||||
else:
|
else:
|
||||||
with PIL.Image.open(image) as im:
|
with PIL_Image.open(image) as im:
|
||||||
pil_images.append(im.copy())
|
pil_images.append(im.copy())
|
||||||
elif isinstance(image, bytes):
|
elif isinstance(image, bytes):
|
||||||
pil_images.append(PIL.Image.open(io.BytesIO(image)))
|
pil_images.append(PIL_Image.open(io.BytesIO(image)))
|
||||||
else:
|
else:
|
||||||
# Assume it's a PIL Image; will raise if invalid
|
# Assume it's a PIL Image; will raise if invalid
|
||||||
pil_images.append(image)
|
pil_images.append(image)
|
||||||
|
|||||||
@@ -77,8 +77,8 @@ class JinaEmbeddings(EmbeddingFunction):
|
|||||||
if isinstance(inputs, list):
|
if isinstance(inputs, list):
|
||||||
inputs = inputs
|
inputs = inputs
|
||||||
else:
|
else:
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(inputs, PIL.Image.Image):
|
if isinstance(inputs, PIL_Image.Image):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@@ -89,13 +89,13 @@ class JinaEmbeddings(EmbeddingFunction):
|
|||||||
elif isinstance(image, (str, Path)):
|
elif isinstance(image, (str, Path)):
|
||||||
parsed = urlparse.urlparse(image)
|
parsed = urlparse.urlparse(image)
|
||||||
# TODO handle drive letter on windows.
|
# TODO handle drive letter on windows.
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if parsed.scheme == "file":
|
if parsed.scheme == "file":
|
||||||
pil_image = PIL.Image.open(parsed.path)
|
pil_image = PIL_Image.open(parsed.path)
|
||||||
elif parsed.scheme == "":
|
elif parsed.scheme == "":
|
||||||
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
|
pil_image = PIL_Image.open(image if os.name == "nt" else parsed.path)
|
||||||
elif parsed.scheme.startswith("http"):
|
elif parsed.scheme.startswith("http"):
|
||||||
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
pil_image = PIL_Image.open(io.BytesIO(url_retrieve(image)))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
@@ -103,9 +103,9 @@ class JinaEmbeddings(EmbeddingFunction):
|
|||||||
image_bytes = buffered.getvalue()
|
image_bytes = buffered.getvalue()
|
||||||
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||||
else:
|
else:
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL_Image.Image):
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="PNG")
|
image.save(buffered, format="PNG")
|
||||||
image_bytes = buffered.getvalue()
|
image_bytes = buffered.getvalue()
|
||||||
@@ -136,9 +136,9 @@ class JinaEmbeddings(EmbeddingFunction):
|
|||||||
elif isinstance(query, (Path, bytes)):
|
elif isinstance(query, (Path, bytes)):
|
||||||
return [self.generate_image_embedding(query)]
|
return [self.generate_image_embedding(query)]
|
||||||
else:
|
else:
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
|
|
||||||
if isinstance(query, PIL.Image.Image):
|
if isinstance(query, PIL_Image.Image):
|
||||||
return [self.generate_image_embedding(query)]
|
return [self.generate_image_embedding(query)]
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
return [self.generate_text_embeddings(query)]
|
return [self.generate_text_embeddings(query)]
|
||||||
else:
|
else:
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(query, PIL.Image.Image):
|
if isinstance(query, PIL_Image.Image):
|
||||||
return [self.generate_image_embedding(query)]
|
return [self.generate_image_embedding(query)]
|
||||||
else:
|
else:
|
||||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
@@ -145,20 +145,20 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
return self._encode_and_normalize_image(image)
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
def _to_pil(self, image: Union[str, bytes]):
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(image, bytes):
|
if isinstance(image, bytes):
|
||||||
return PIL.Image.open(io.BytesIO(image))
|
return PIL_Image.open(io.BytesIO(image))
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL_Image.Image):
|
||||||
return image
|
return image
|
||||||
elif isinstance(image, str):
|
elif isinstance(image, str):
|
||||||
parsed = urlparse.urlparse(image)
|
parsed = urlparse.urlparse(image)
|
||||||
# TODO handle drive letter on windows.
|
# TODO handle drive letter on windows.
|
||||||
if parsed.scheme == "file":
|
if parsed.scheme == "file":
|
||||||
return PIL.Image.open(parsed.path)
|
return PIL_Image.open(parsed.path)
|
||||||
elif parsed.scheme == "":
|
elif parsed.scheme == "":
|
||||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
return PIL_Image.open(image if os.name == "nt" else parsed.path)
|
||||||
elif parsed.scheme.startswith("http"):
|
elif parsed.scheme.startswith("http"):
|
||||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
return PIL_Image.open(io.BytesIO(url_retrieve(image)))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
|
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ class SigLipEmbeddings(EmbeddingFunction):
|
|||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
return [self.generate_text_embeddings(query)]
|
return [self.generate_text_embeddings(query)]
|
||||||
else:
|
else:
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(query, PIL.Image.Image):
|
if isinstance(query, PIL_Image.Image):
|
||||||
return [self.generate_image_embedding(query)]
|
return [self.generate_image_embedding(query)]
|
||||||
else:
|
else:
|
||||||
raise TypeError("SigLIP supports str or PIL Image as query")
|
raise TypeError("SigLIP supports str or PIL Image as query")
|
||||||
@@ -127,21 +127,21 @@ class SigLipEmbeddings(EmbeddingFunction):
|
|||||||
return image_features.cpu().detach().numpy().squeeze()
|
return image_features.cpu().detach().numpy().squeeze()
|
||||||
|
|
||||||
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
|
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL_Image.Image):
|
||||||
return image.convert("RGB") if image.mode != "RGB" else image
|
return image.convert("RGB") if image.mode != "RGB" else image
|
||||||
elif isinstance(image, bytes):
|
elif isinstance(image, bytes):
|
||||||
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
|
return PIL_Image.open(io.BytesIO(image)).convert("RGB")
|
||||||
elif isinstance(image, str):
|
elif isinstance(image, str):
|
||||||
parsed = urlparse.urlparse(image)
|
parsed = urlparse.urlparse(image)
|
||||||
if parsed.scheme == "file":
|
if parsed.scheme == "file":
|
||||||
return PIL.Image.open(parsed.path).convert("RGB")
|
return PIL_Image.open(parsed.path).convert("RGB")
|
||||||
elif parsed.scheme == "":
|
elif parsed.scheme == "":
|
||||||
path = image if os.name == "nt" else parsed.path
|
path = image if os.name == "nt" else parsed.path
|
||||||
return PIL.Image.open(path).convert("RGB")
|
return PIL_Image.open(path).convert("RGB")
|
||||||
elif parsed.scheme.startswith("http"):
|
elif parsed.scheme.startswith("http"):
|
||||||
image_bytes = url_retrieve(image)
|
image_bytes = url_retrieve(image)
|
||||||
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
return PIL_Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
# Token limits for different VoyageAI models
|
# Token limits for different VoyageAI models
|
||||||
VOYAGE_TOTAL_TOKEN_LIMITS = {
|
VOYAGE_TOTAL_TOKEN_LIMITS = {
|
||||||
|
"voyage-4": 320_000,
|
||||||
|
"voyage-4-lite": 1_000_000,
|
||||||
|
"voyage-4-large": 120_000,
|
||||||
"voyage-context-3": 32_000,
|
"voyage-context-3": 32_000,
|
||||||
"voyage-3.5-lite": 1_000_000,
|
"voyage-3.5-lite": 1_000_000,
|
||||||
"voyage-3.5": 320_000,
|
"voyage-3.5": 320_000,
|
||||||
@@ -61,7 +64,7 @@ def is_video_path(path: Path) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def transform_input(input_data: Union[str, bytes, Path]):
|
def transform_input(input_data: Union[str, bytes, Path]):
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(input_data, str):
|
if isinstance(input_data, str):
|
||||||
if is_valid_url(input_data):
|
if is_valid_url(input_data):
|
||||||
if is_video_url(input_data):
|
if is_video_url(input_data):
|
||||||
@@ -70,7 +73,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
|||||||
content = {"type": "image_url", "image_url": input_data}
|
content = {"type": "image_url", "image_url": input_data}
|
||||||
else:
|
else:
|
||||||
content = {"type": "text", "text": input_data}
|
content = {"type": "text", "text": input_data}
|
||||||
elif isinstance(input_data, PIL.Image.Image):
|
elif isinstance(input_data, PIL_Image.Image):
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
input_data.save(buffered, format="JPEG")
|
input_data.save(buffered, format="JPEG")
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
@@ -79,7 +82,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
|||||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||||
}
|
}
|
||||||
elif isinstance(input_data, bytes):
|
elif isinstance(input_data, bytes):
|
||||||
img = PIL.Image.open(BytesIO(input_data))
|
img = PIL_Image.open(BytesIO(input_data))
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
img.save(buffered, format="JPEG")
|
img.save(buffered, format="JPEG")
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
@@ -98,7 +101,7 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
|||||||
"video_base64": video_str,
|
"video_base64": video_str,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
img = PIL.Image.open(input_data)
|
img = PIL_Image.open(input_data)
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
img.save(buffered, format="JPEG")
|
img.save(buffered, format="JPEG")
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
@@ -116,8 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
|
|||||||
"""
|
"""
|
||||||
Sanitize the input to the embedding function.
|
Sanitize the input to the embedding function.
|
||||||
"""
|
"""
|
||||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
PIL_Image = attempt_import_or_raise("PIL.Image", "pillow")
|
||||||
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
|
if isinstance(inputs, (str, bytes, Path, PIL_Image.Image)):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
elif isinstance(inputs, list):
|
elif isinstance(inputs, list):
|
||||||
pass # Already a list, use as-is
|
pass # Already a list, use as-is
|
||||||
@@ -130,7 +133,7 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
|
|||||||
f"Input type {type(inputs)} not allowed with multimodal model."
|
f"Input type {type(inputs)} not allowed with multimodal model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
|
if not all(isinstance(x, (str, bytes, Path, PIL_Image.Image)) for x in inputs):
|
||||||
raise ValueError("Each input should be either str, bytes, Path or Image.")
|
raise ValueError("Each input should be either str, bytes, Path or Image.")
|
||||||
|
|
||||||
return [transform_input(i) for i in inputs]
|
return [transform_input(i) for i in inputs]
|
||||||
@@ -167,6 +170,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
|||||||
name: str
|
name: str
|
||||||
The name of the model to use. List of acceptable models:
|
The name of the model to use. List of acceptable models:
|
||||||
|
|
||||||
|
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
|
||||||
|
* voyage-4-lite (1024 dims, optimized for latency and cost)
|
||||||
|
* voyage-4-large (1024 dims, best retrieval quality)
|
||||||
* voyage-context-3
|
* voyage-context-3
|
||||||
* voyage-3.5
|
* voyage-3.5
|
||||||
* voyage-3.5-lite
|
* voyage-3.5-lite
|
||||||
@@ -215,6 +221,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
|||||||
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
|
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
|
||||||
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
|
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
|
||||||
text_embedding_models: list = [
|
text_embedding_models: list = [
|
||||||
|
"voyage-4",
|
||||||
|
"voyage-4-lite",
|
||||||
|
"voyage-4-large",
|
||||||
"voyage-3.5",
|
"voyage-3.5",
|
||||||
"voyage-3.5-lite",
|
"voyage-3.5-lite",
|
||||||
"voyage-3",
|
"voyage-3",
|
||||||
@@ -252,6 +261,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
|||||||
elif self.name == "voyage-code-2":
|
elif self.name == "voyage-code-2":
|
||||||
return 1536
|
return 1536
|
||||||
elif self.name in [
|
elif self.name in [
|
||||||
|
"voyage-4",
|
||||||
|
"voyage-4-lite",
|
||||||
|
"voyage-4-large",
|
||||||
"voyage-context-3",
|
"voyage-context-3",
|
||||||
"voyage-3.5",
|
"voyage-3.5",
|
||||||
"voyage-3.5-lite",
|
"voyage-3.5-lite",
|
||||||
|
|||||||
@@ -275,7 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
|||||||
return pa.timestamp("us", tz=tz)
|
return pa.timestamp("us", tz=tz)
|
||||||
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||||
child = py_type.__args__[0]
|
child = py_type.__args__[0]
|
||||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
return _pydantic_list_child_to_arrow(child, field)
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||||
)
|
)
|
||||||
@@ -298,12 +298,18 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
||||||
|
def _safe_issubclass(candidate: Any, base: type) -> bool:
|
||||||
|
try:
|
||||||
|
return issubclass(candidate, base)
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
if inspect.isclass(tp):
|
if inspect.isclass(tp):
|
||||||
if issubclass(tp, pydantic.BaseModel):
|
if _safe_issubclass(tp, pydantic.BaseModel):
|
||||||
# Struct
|
# Struct
|
||||||
fields = _pydantic_model_to_fields(tp)
|
fields = _pydantic_model_to_fields(tp)
|
||||||
return pa.struct(fields)
|
return pa.struct(fields)
|
||||||
if issubclass(tp, FixedSizeListMixin):
|
if _safe_issubclass(tp, FixedSizeListMixin):
|
||||||
if getattr(tp, "is_multi_vector", lambda: False)():
|
if getattr(tp, "is_multi_vector", lambda: False)():
|
||||||
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
|
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
|
||||||
# For regular Vector
|
# For regular Vector
|
||||||
@@ -311,45 +317,67 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
|||||||
return _py_type_to_arrow_type(tp, field)
|
return _py_type_to_arrow_type(tp, field)
|
||||||
|
|
||||||
|
|
||||||
|
def _pydantic_list_child_to_arrow(child: Any, field: FieldInfo) -> pa.DataType:
|
||||||
|
unwrapped = _unwrap_optional_annotation(child)
|
||||||
|
if unwrapped is not None:
|
||||||
|
return pa.list_(
|
||||||
|
pa.field("item", _pydantic_type_to_arrow_type(unwrapped, field), True)
|
||||||
|
)
|
||||||
|
return pa.list_(_pydantic_type_to_arrow_type(child, field))
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_optional_annotation(annotation: Any) -> Any | None:
|
||||||
|
if isinstance(annotation, (_GenericAlias, GenericAlias)):
|
||||||
|
origin = annotation.__origin__
|
||||||
|
args = annotation.__args__
|
||||||
|
if origin == Union:
|
||||||
|
non_none = [arg for arg in args if arg is not type(None)]
|
||||||
|
if len(non_none) == 1 and len(non_none) != len(args):
|
||||||
|
return non_none[0]
|
||||||
|
elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
|
||||||
|
args = annotation.__args__
|
||||||
|
non_none = [arg for arg in args if arg is not type(None)]
|
||||||
|
if len(non_none) == 1 and len(non_none) != len(args):
|
||||||
|
return non_none[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||||
"""Convert a Pydantic FieldInfo to Arrow DataType"""
|
"""Convert a Pydantic FieldInfo to Arrow DataType"""
|
||||||
|
unwrapped = _unwrap_optional_annotation(field.annotation)
|
||||||
|
if unwrapped is not None:
|
||||||
|
return _pydantic_type_to_arrow_type(unwrapped, field)
|
||||||
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
|
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
|
||||||
origin = field.annotation.__origin__
|
origin = field.annotation.__origin__
|
||||||
args = field.annotation.__args__
|
args = field.annotation.__args__
|
||||||
|
|
||||||
if origin is list:
|
if origin is list:
|
||||||
child = args[0]
|
child = args[0]
|
||||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
return _pydantic_list_child_to_arrow(child, field)
|
||||||
elif origin == Union:
|
|
||||||
if len(args) == 2 and args[1] is type(None):
|
|
||||||
return _pydantic_type_to_arrow_type(args[0], field)
|
|
||||||
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
|
|
||||||
args = field.annotation.__args__
|
|
||||||
if len(args) == 2:
|
|
||||||
for typ in args:
|
|
||||||
if typ is type(None):
|
|
||||||
continue
|
|
||||||
return _py_type_to_arrow_type(typ, field)
|
|
||||||
return _pydantic_type_to_arrow_type(field.annotation, field)
|
return _pydantic_type_to_arrow_type(field.annotation, field)
|
||||||
|
|
||||||
|
|
||||||
def is_nullable(field: FieldInfo) -> bool:
|
def is_nullable(field: FieldInfo) -> bool:
|
||||||
"""Check if a Pydantic FieldInfo is nullable."""
|
"""Check if a Pydantic FieldInfo is nullable."""
|
||||||
|
if _unwrap_optional_annotation(field.annotation) is not None:
|
||||||
|
return True
|
||||||
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
|
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
|
||||||
origin = field.annotation.__origin__
|
origin = field.annotation.__origin__
|
||||||
args = field.annotation.__args__
|
args = field.annotation.__args__
|
||||||
if origin == Union:
|
if origin == Union:
|
||||||
if len(args) == 2 and args[1] is type(None):
|
if any(typ is type(None) for typ in args):
|
||||||
return True
|
return True
|
||||||
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
|
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
|
||||||
args = field.annotation.__args__
|
args = field.annotation.__args__
|
||||||
for typ in args:
|
for typ in args:
|
||||||
if typ is type(None):
|
if typ is type(None):
|
||||||
return True
|
return True
|
||||||
elif inspect.isclass(field.annotation) and issubclass(
|
elif inspect.isclass(field.annotation):
|
||||||
field.annotation, FixedSizeListMixin
|
try:
|
||||||
):
|
if issubclass(field.annotation, FixedSizeListMixin):
|
||||||
return field.annotation.nullable()
|
return field.annotation.nullable()
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -517,19 +517,36 @@ def test_ollama_embedding(tmp_path):
|
|||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||||
)
|
)
|
||||||
def test_voyageai_embedding_function():
|
@pytest.mark.parametrize(
|
||||||
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
|
"model_name,expected_dims",
|
||||||
|
[
|
||||||
|
("voyage-3", 1024),
|
||||||
|
("voyage-4", 1024),
|
||||||
|
("voyage-4-lite", 1024),
|
||||||
|
("voyage-4-large", 1024),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
|
||||||
|
"""Integration test for VoyageAI text embedding models with real API calls."""
|
||||||
|
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
|
||||||
|
|
||||||
class TextModel(LanceModel):
|
class TextModel(LanceModel):
|
||||||
text: str = voyageai.SourceField()
|
text: str = voyageai.SourceField()
|
||||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||||
|
|
||||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
db = lancedb.connect("~/lancedb")
|
db = lancedb.connect(tmp_path)
|
||||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
tbl.add(df)
|
tbl.add(df)
|
||||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||||
|
assert voyageai.ndims() == expected_dims, (
|
||||||
|
f"{model_name} should have {expected_dims} dimensions"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test search functionality
|
||||||
|
result = tbl.search("hello").limit(1).to_pandas()
|
||||||
|
assert result["text"][0] == "hello world"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
|||||||
@@ -438,11 +438,15 @@ def test_filter_with_splits(mem_db):
|
|||||||
row_count = permutation_tbl.count_rows()
|
row_count = permutation_tbl.count_rows()
|
||||||
assert row_count == 67
|
assert row_count == 67
|
||||||
|
|
||||||
data = permutation_tbl.search(None).to_arrow().to_pydict()
|
# Verify the permutation table only contains row_id and split_id
|
||||||
|
assert set(permutation_tbl.schema.names) == {"row_id", "split_id"}
|
||||||
|
|
||||||
|
row_ids = permutation_tbl.search(None).to_arrow().to_pydict()["row_id"]
|
||||||
|
data = tbl.take_row_ids(row_ids).to_arrow().to_pydict()
|
||||||
categories = data["category"]
|
categories = data["category"]
|
||||||
|
|
||||||
# All categories should be A or B
|
# All categories should be A or B
|
||||||
assert all(cat in ["A", "B"] for cat in categories)
|
assert all(cat in ("A", "B") for cat in categories)
|
||||||
|
|
||||||
|
|
||||||
def test_filter_with_shuffle(mem_db):
|
def test_filter_with_shuffle(mem_db):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -20,10 +19,6 @@ from pydantic import BaseModel
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9),
|
|
||||||
reason="using native type alias requires python3.9 or higher",
|
|
||||||
)
|
|
||||||
def test_pydantic_to_arrow():
|
def test_pydantic_to_arrow():
|
||||||
class StructModel(pydantic.BaseModel):
|
class StructModel(pydantic.BaseModel):
|
||||||
a: str
|
a: str
|
||||||
@@ -83,10 +78,6 @@ def test_pydantic_to_arrow():
|
|||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 10),
|
|
||||||
reason="using | type syntax requires python3.10 or higher",
|
|
||||||
)
|
|
||||||
def test_optional_types_py310():
|
def test_optional_types_py310():
|
||||||
class TestModel(pydantic.BaseModel):
|
class TestModel(pydantic.BaseModel):
|
||||||
a: str | None
|
a: str | None
|
||||||
@@ -105,10 +96,233 @@ def test_optional_types_py310():
|
|||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
def test_optional_structs():
|
||||||
sys.version_info > (3, 8),
|
class SplitInfo(pydantic.BaseModel):
|
||||||
reason="using native type alias requires python3.9 or higher",
|
start_frame: int
|
||||||
)
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
split: SplitInfo | None = None
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"split",
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_struct_list_py310():
|
||||||
|
class SplitInfo(pydantic.BaseModel):
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
splits: list[SplitInfo] | None = None
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"splits",
|
||||||
|
pa.list_(
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_struct_list():
|
||||||
|
class SplitInfo(pydantic.BaseModel):
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
splits: list[SplitInfo]
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"splits",
|
||||||
|
pa.list_(
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_struct_list_optional():
|
||||||
|
class SplitInfo(pydantic.BaseModel):
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
splits: Optional[list[SplitInfo]] = None
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"splits",
|
||||||
|
pa.list_(
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_struct_list_optional_items():
|
||||||
|
class SplitInfo(pydantic.BaseModel):
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
splits: list[Optional[SplitInfo]]
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"splits",
|
||||||
|
pa.list_(
|
||||||
|
pa.field(
|
||||||
|
"item",
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_struct_list_optional_container_and_items():
|
||||||
|
class SplitInfo(pydantic.BaseModel):
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
splits: Optional[list[Optional[SplitInfo]]] = None
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"splits",
|
||||||
|
pa.list_(
|
||||||
|
pa.field(
|
||||||
|
"item",
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_struct_list_optional_items_pep604():
|
||||||
|
class SplitInfo(pydantic.BaseModel):
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
|
||||||
|
class TestModel(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
splits: list[SplitInfo | None]
|
||||||
|
|
||||||
|
schema = pydantic_to_schema(TestModel)
|
||||||
|
|
||||||
|
expect_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.utf8(), False),
|
||||||
|
pa.field(
|
||||||
|
"splits",
|
||||||
|
pa.list_(
|
||||||
|
pa.field(
|
||||||
|
"item",
|
||||||
|
pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("start_frame", pa.int64(), False),
|
||||||
|
pa.field("end_frame", pa.int64(), False),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
def test_pydantic_to_arrow_py38():
|
def test_pydantic_to_arrow_py38():
|
||||||
class StructModel(pydantic.BaseModel):
|
class StructModel(pydantic.BaseModel):
|
||||||
a: str
|
a: str
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import http.server
|
|||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
import uuid
|
import uuid
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
@@ -1203,3 +1203,22 @@ async def test_header_provider_overrides_static_headers():
|
|||||||
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
|
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
|
||||||
) as db:
|
) as db:
|
||||||
await db.table_names()
|
await db.table_names()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
|
||||||
|
def test_background_loop_cancellation(exception):
|
||||||
|
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
|
||||||
|
from lancedb.background_loop import BackgroundEventLoop
|
||||||
|
|
||||||
|
mock_future = MagicMock()
|
||||||
|
mock_future.result.side_effect = exception()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(BackgroundEventLoop, "__init__", return_value=None),
|
||||||
|
patch("asyncio.run_coroutine_threadsafe", return_value=mock_future),
|
||||||
|
):
|
||||||
|
loop = BackgroundEventLoop()
|
||||||
|
loop.loop = MagicMock()
|
||||||
|
with pytest.raises(exception):
|
||||||
|
loop.run(None)
|
||||||
|
mock_future.cancel.assert_called_once()
|
||||||
|
|||||||
108
python/python/tests/test_voyageai_embeddings.py
Normal file
108
python/python/tests/test_voyageai_embeddings.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
|
"""Unit tests for VoyageAI embedding function.
|
||||||
|
|
||||||
|
These tests verify model registration and configuration without requiring API calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_voyageai_client():
|
||||||
|
"""Reset VoyageAI client before and after each test to avoid state pollution."""
|
||||||
|
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
|
||||||
|
|
||||||
|
VoyageAIEmbeddingFunction.client = None
|
||||||
|
yield
|
||||||
|
VoyageAIEmbeddingFunction.client = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoyageAIModelRegistration:
|
||||||
|
"""Tests for VoyageAI model registration and configuration."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_voyageai_client(self):
|
||||||
|
"""Mock VoyageAI client to avoid API calls."""
|
||||||
|
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
|
||||||
|
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_voyageai = MagicMock()
|
||||||
|
mock_voyageai.Client.return_value = mock_client
|
||||||
|
mock.return_value = mock_voyageai
|
||||||
|
yield mock_client
|
||||||
|
|
||||||
|
def test_voyageai_registered(self):
|
||||||
|
"""Test that VoyageAI is registered in the embedding function registry."""
|
||||||
|
registry = get_registry()
|
||||||
|
assert registry.get("voyageai") is not None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name,expected_dims",
|
||||||
|
[
|
||||||
|
# Voyage-4 series (all 1024 dims)
|
||||||
|
("voyage-4", 1024),
|
||||||
|
("voyage-4-lite", 1024),
|
||||||
|
("voyage-4-large", 1024),
|
||||||
|
# Voyage-3 series
|
||||||
|
("voyage-3", 1024),
|
||||||
|
("voyage-3-lite", 512),
|
||||||
|
# Domain-specific models
|
||||||
|
("voyage-finance-2", 1024),
|
||||||
|
("voyage-multilingual-2", 1024),
|
||||||
|
("voyage-law-2", 1024),
|
||||||
|
("voyage-code-2", 1536),
|
||||||
|
# Multimodal
|
||||||
|
("voyage-multimodal-3", 1024),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
|
||||||
|
"""Test that each model returns the correct dimensions."""
|
||||||
|
registry = get_registry()
|
||||||
|
func = registry.get("voyageai").create(name=model_name)
|
||||||
|
assert func.ndims() == expected_dims, (
|
||||||
|
f"Model {model_name} should have {expected_dims} dimensions"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_unsupported_model_raises_error(self, mock_voyageai_client):
|
||||||
|
"""Test that unsupported models raise ValueError."""
|
||||||
|
registry = get_registry()
|
||||||
|
func = registry.get("voyageai").create(name="unsupported-model")
|
||||||
|
with pytest.raises(ValueError, match="not supported"):
|
||||||
|
func.ndims()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name",
|
||||||
|
[
|
||||||
|
"voyage-4",
|
||||||
|
"voyage-4-lite",
|
||||||
|
"voyage-4-large",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
|
||||||
|
"""Test that voyage-4 models are classified as text models (not multimodal)."""
|
||||||
|
registry = get_registry()
|
||||||
|
func = registry.get("voyageai").create(name=model_name)
|
||||||
|
assert not func._is_multimodal_model(model_name), (
|
||||||
|
f"{model_name} should be a text model, not multimodal"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
|
||||||
|
"""Test that voyage-4 models are in the text_embedding_models list."""
|
||||||
|
registry = get_registry()
|
||||||
|
func = registry.get("voyageai").create(name="voyage-4")
|
||||||
|
assert "voyage-4" in func.text_embedding_models
|
||||||
|
assert "voyage-4-lite" in func.text_embedding_models
|
||||||
|
assert "voyage-4-large" in func.text_embedding_models
|
||||||
|
|
||||||
|
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
|
||||||
|
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
|
||||||
|
registry = get_registry()
|
||||||
|
func = registry.get("voyageai").create(name="voyage-4")
|
||||||
|
assert "voyage-4" not in func.multimodal_embedding_models
|
||||||
|
assert "voyage-4-lite" not in func.multimodal_embedding_models
|
||||||
|
assert "voyage-4-large" not in func.multimodal_embedding_models
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.24.0"
|
version = "0.24.1"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|||||||
@@ -251,8 +251,36 @@ impl CreateTableBuilder<false> {
|
|||||||
/// Execute the create table operation
|
/// Execute the create table operation
|
||||||
pub async fn execute(self) -> Result<Table> {
|
pub async fn execute(self) -> Result<Table> {
|
||||||
let parent = self.parent.clone();
|
let parent = self.parent.clone();
|
||||||
let table = parent.create_table(self.request).await?;
|
let embedding_registry = self.embedding_registry.clone();
|
||||||
Ok(Table::new(table, parent))
|
let request = self.into_request()?;
|
||||||
|
Ok(Table::new_with_embedding_registry(
|
||||||
|
parent.create_table(request).await?,
|
||||||
|
parent,
|
||||||
|
embedding_registry,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_request(self) -> Result<CreateTableRequest> {
|
||||||
|
if self.embeddings.is_empty() {
|
||||||
|
return Ok(self.request);
|
||||||
|
}
|
||||||
|
|
||||||
|
let CreateTableData::Empty(table_def) = self.request.data else {
|
||||||
|
unreachable!("CreateTableBuilder<false> should always have Empty data")
|
||||||
|
};
|
||||||
|
|
||||||
|
let schema = table_def.schema.clone();
|
||||||
|
let empty_batch = arrow_array::RecordBatch::new_empty(schema.clone());
|
||||||
|
|
||||||
|
let reader = Box::new(std::iter::once(Ok(empty_batch)).collect::<Vec<_>>());
|
||||||
|
let reader = arrow_array::RecordBatchIterator::new(reader.into_iter(), schema);
|
||||||
|
let with_embeddings = WithEmbeddings::new(reader, self.embeddings);
|
||||||
|
let table_definition = with_embeddings.table_definition()?;
|
||||||
|
|
||||||
|
Ok(CreateTableRequest {
|
||||||
|
data: CreateTableData::Empty(table_definition),
|
||||||
|
..self.request
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1692,4 +1720,128 @@ mod tests {
|
|||||||
let cloned_count = cloned_table.count_rows(None).await.unwrap();
|
let cloned_count = cloned_table.count_rows(None).await.unwrap();
|
||||||
assert_eq!(source_count, cloned_count);
|
assert_eq!(source_count, cloned_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_empty_table_with_embeddings() {
|
||||||
|
use crate::embeddings::{EmbeddingDefinition, EmbeddingFunction};
|
||||||
|
use arrow_array::{
|
||||||
|
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
|
||||||
|
};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct MockEmbedding {
|
||||||
|
dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbeddingFunction for MockEmbedding {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"test_embedding"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn source_type(&self) -> Result<Cow<'_, DataType>> {
|
||||||
|
Ok(Cow::Owned(DataType::Utf8))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
|
||||||
|
Ok(Cow::Owned(DataType::new_fixed_size_list(
|
||||||
|
DataType::Float32,
|
||||||
|
self.dim as i32,
|
||||||
|
true,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||||
|
let len = source.len();
|
||||||
|
let values = vec![1.0f32; len * self.dim];
|
||||||
|
let values = Arc::new(Float32Array::from(values));
|
||||||
|
let field = Arc::new(Field::new("item", DataType::Float32, true));
|
||||||
|
Ok(Arc::new(FixedSizeListArray::new(
|
||||||
|
field,
|
||||||
|
self.dim as i32,
|
||||||
|
values,
|
||||||
|
None,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let tmp_dir = tempdir().unwrap();
|
||||||
|
let uri = tmp_dir.path().to_str().unwrap();
|
||||||
|
let db = connect(uri).execute().await.unwrap();
|
||||||
|
|
||||||
|
let embed_func = Arc::new(MockEmbedding { dim: 128 });
|
||||||
|
db.embedding_registry()
|
||||||
|
.register("test_embedding", embed_func.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
|
||||||
|
let ed = EmbeddingDefinition {
|
||||||
|
source_column: "name".to_owned(),
|
||||||
|
dest_column: Some("name_embedding".to_owned()),
|
||||||
|
embedding_name: "test_embedding".to_owned(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let table = db
|
||||||
|
.create_empty_table("test", schema)
|
||||||
|
.mode(CreateTableMode::Overwrite)
|
||||||
|
.add_embedding(ed)
|
||||||
|
.unwrap()
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let table_schema = table.schema().await.unwrap();
|
||||||
|
assert!(table_schema.column_with_name("name").is_some());
|
||||||
|
assert!(table_schema.column_with_name("name_embedding").is_some());
|
||||||
|
|
||||||
|
let embedding_field = table_schema.field_with_name("name_embedding").unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
embedding_field.data_type(),
|
||||||
|
&DataType::new_fixed_size_list(DataType::Float32, 128, true)
|
||||||
|
);
|
||||||
|
|
||||||
|
let input_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
|
||||||
|
let input_batch = RecordBatch::try_new(
|
||||||
|
input_schema.clone(),
|
||||||
|
vec![Arc::new(StringArray::from(vec![
|
||||||
|
Some("Alice"),
|
||||||
|
Some("Bob"),
|
||||||
|
Some("Charlie"),
|
||||||
|
]))],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let input_reader = Box::new(RecordBatchIterator::new(
|
||||||
|
vec![Ok(input_batch)].into_iter(),
|
||||||
|
input_schema,
|
||||||
|
));
|
||||||
|
|
||||||
|
table.add(input_reader).execute().await.unwrap();
|
||||||
|
|
||||||
|
let results = table
|
||||||
|
.query()
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.try_collect::<Vec<_>>()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
let batch = &results[0];
|
||||||
|
assert_eq!(batch.num_rows(), 3);
|
||||||
|
assert!(batch.column_by_name("name_embedding").is_some());
|
||||||
|
|
||||||
|
let embedding_col = batch
|
||||||
|
.column_by_name("name_embedding")
|
||||||
|
.unwrap()
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<FixedSizeListArray>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(embedding_col.len(), 3);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use crate::{
|
|||||||
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
|
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
|
||||||
util::{rename_column, TemporaryDirectory},
|
util::{rename_column, TemporaryDirectory},
|
||||||
},
|
},
|
||||||
query::{ExecutableQuery, QueryBase},
|
query::{ExecutableQuery, QueryBase, Select},
|
||||||
Error, Result, Table,
|
Error, Result, Table,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -27,6 +27,8 @@ pub const SRC_ROW_ID_COL: &str = "row_id";
|
|||||||
|
|
||||||
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
|
pub const SPLIT_NAMES_CONFIG_KEY: &str = "split_names";
|
||||||
|
|
||||||
|
pub const DEFAULT_MEMORY_LIMIT: usize = 100 * 1024 * 1024;
|
||||||
|
|
||||||
/// Where to store the permutation table
|
/// Where to store the permutation table
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
enum PermutationDestination {
|
enum PermutationDestination {
|
||||||
@@ -167,10 +169,20 @@ impl PermutationBuilder {
|
|||||||
&self,
|
&self,
|
||||||
data: SendableRecordBatchStream,
|
data: SendableRecordBatchStream,
|
||||||
) -> Result<SendableRecordBatchStream> {
|
) -> Result<SendableRecordBatchStream> {
|
||||||
|
let memory_limit = std::env::var("LANCEDB_PERM_BUILDER_MEMORY_LIMIT")
|
||||||
|
.unwrap_or_else(|_| DEFAULT_MEMORY_LIMIT.to_string())
|
||||||
|
.parse::<usize>()
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
log::error!(
|
||||||
|
"Failed to parse LANCEDB_PERM_BUILDER_MEMORY_LIMIT, using default: {}",
|
||||||
|
DEFAULT_MEMORY_LIMIT
|
||||||
|
);
|
||||||
|
DEFAULT_MEMORY_LIMIT
|
||||||
|
});
|
||||||
let ctx = SessionContext::new_with_config_rt(
|
let ctx = SessionContext::new_with_config_rt(
|
||||||
SessionConfig::default(),
|
SessionConfig::default(),
|
||||||
RuntimeEnvBuilder::new()
|
RuntimeEnvBuilder::new()
|
||||||
.with_memory_limit(100 * 1024 * 1024, 1.0)
|
.with_memory_limit(memory_limit, 1.0)
|
||||||
.with_disk_manager_builder(
|
.with_disk_manager_builder(
|
||||||
DiskManagerBuilder::default()
|
DiskManagerBuilder::default()
|
||||||
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
|
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
|
||||||
@@ -232,7 +244,7 @@ impl PermutationBuilder {
|
|||||||
/// Builds the permutation table and stores it in the given database.
|
/// Builds the permutation table and stores it in the given database.
|
||||||
pub async fn build(self) -> Result<Table> {
|
pub async fn build(self) -> Result<Table> {
|
||||||
// First pass, apply filter and load row ids
|
// First pass, apply filter and load row ids
|
||||||
let mut rows = self.base_table.query().with_row_id();
|
let mut rows = self.base_table.query().select(Select::columns(&[ROW_ID]));
|
||||||
|
|
||||||
if let Some(filter) = &self.config.filter {
|
if let Some(filter) = &self.config.filter {
|
||||||
rows = rows.only_if(filter);
|
rows = rows.only_if(filter);
|
||||||
@@ -321,6 +333,47 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_permutation_table_only_stores_row_id_and_split_id() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
|
||||||
|
let db = connect(temp_dir.path().to_str().unwrap())
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let initial_data = lance_datagen::gen_batch()
|
||||||
|
.col("col_a", lance_datagen::array::step::<Int32Type>())
|
||||||
|
.col("col_b", lance_datagen::array::step::<Int32Type>())
|
||||||
|
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
|
||||||
|
let data_table = db
|
||||||
|
.create_table_streaming("base_tbl", initial_data)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let permutation_table = PermutationBuilder::new(data_table.clone())
|
||||||
|
.with_split_strategy(
|
||||||
|
SplitStrategy::Sequential {
|
||||||
|
sizes: SplitSizes::Percentages(vec![0.5, 0.5]),
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.with_filter("col_a > 57".to_string())
|
||||||
|
.build()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let schema = permutation_table.schema().await.unwrap();
|
||||||
|
let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
|
||||||
|
assert_eq!(
|
||||||
|
field_names,
|
||||||
|
vec!["row_id", "split_id"],
|
||||||
|
"Permutation table should only contain row_id and split_id columns, but found: {:?}",
|
||||||
|
field_names,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_permutation_builder() {
|
async fn test_permutation_builder() {
|
||||||
let temp_dir = tempfile::tempdir().unwrap();
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
@@ -352,8 +405,6 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
println!("permutation_table: {:?}", permutation_table);
|
|
||||||
|
|
||||||
// Potentially brittle seed-dependent values below
|
// Potentially brittle seed-dependent values below
|
||||||
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
|
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ use datafusion_common::hash_utils::create_hashes;
|
|||||||
use futures::{StreamExt, TryStreamExt};
|
use futures::{StreamExt, TryStreamExt};
|
||||||
use lance_arrow::SchemaExt;
|
use lance_arrow::SchemaExt;
|
||||||
|
|
||||||
|
use lance_core::ROW_ID;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||||
dataloader::{
|
dataloader::{
|
||||||
@@ -360,11 +362,15 @@ impl Splitter {
|
|||||||
|
|
||||||
pub fn project(&self, query: Query) -> Query {
|
pub fn project(&self, query: Query) -> Query {
|
||||||
match &self.strategy {
|
match &self.strategy {
|
||||||
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
|
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![
|
||||||
SPLIT_ID_COLUMN.to_string(),
|
(SPLIT_ID_COLUMN.to_string(), calculation.clone()),
|
||||||
calculation.clone(),
|
(ROW_ID.to_string(), ROW_ID.to_string()),
|
||||||
)])),
|
])),
|
||||||
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
|
SplitStrategy::Hash { columns, .. } => {
|
||||||
|
let mut cols = columns.clone();
|
||||||
|
cols.push(ROW_ID.to_string());
|
||||||
|
query.select(Select::Columns(cols))
|
||||||
|
}
|
||||||
_ => query,
|
_ => query,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,10 +79,11 @@ use self::merge::MergeInsertBuilder;
|
|||||||
|
|
||||||
pub mod datafusion;
|
pub mod datafusion;
|
||||||
pub(crate) mod dataset;
|
pub(crate) mod dataset;
|
||||||
|
pub mod delete;
|
||||||
pub mod merge;
|
pub mod merge;
|
||||||
|
|
||||||
use crate::index::waiter::wait_for_index;
|
use crate::index::waiter::wait_for_index;
|
||||||
pub use chrono::Duration;
|
pub use chrono::Duration;
|
||||||
|
pub use delete::DeleteResult;
|
||||||
use futures::future::{join_all, Either};
|
use futures::future::{join_all, Either};
|
||||||
pub use lance::dataset::optimize::CompactionOptions;
|
pub use lance::dataset::optimize::CompactionOptions;
|
||||||
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
|
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
|
||||||
@@ -446,15 +447,6 @@ pub struct AddResult {
|
|||||||
pub version: u64,
|
pub version: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
|
||||||
pub struct DeleteResult {
|
|
||||||
// The commit version associated with the operation.
|
|
||||||
// A version of `0` indicates compatibility with legacy servers that do not return
|
|
||||||
/// a commit version.
|
|
||||||
#[serde(default)]
|
|
||||||
pub version: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
pub struct MergeResult {
|
pub struct MergeResult {
|
||||||
// The commit version associated with the operation.
|
// The commit version associated with the operation.
|
||||||
@@ -3078,11 +3070,8 @@ impl BaseTable for NativeTable {
|
|||||||
|
|
||||||
/// Delete rows from the table
|
/// Delete rows from the table
|
||||||
async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
|
async fn delete(&self, predicate: &str) -> Result<DeleteResult> {
|
||||||
let mut dataset = self.dataset.get_mut().await?;
|
// Delegate to the submodule implementation
|
||||||
dataset.delete(predicate).await?;
|
delete::execute_delete(self, predicate).await
|
||||||
Ok(DeleteResult {
|
|
||||||
version: dataset.version().version,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
|
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
|
||||||
|
|||||||
161
rust/lancedb/src/table/delete.rs
Normal file
161
rust/lancedb/src/table/delete.rs
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::NativeTable;
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||||
|
pub struct DeleteResult {
|
||||||
|
// The commit version associated with the operation.
|
||||||
|
// A version of `0` indicates compatibility with legacy servers that do not return
|
||||||
|
/// a commit version.
|
||||||
|
#[serde(default)]
|
||||||
|
pub version: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal implementation of the delete logic
|
||||||
|
///
|
||||||
|
/// This logic was moved from NativeTable::delete to keep table.rs clean.
|
||||||
|
pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result<DeleteResult> {
|
||||||
|
// We access the dataset from the table. Since this is in the same module hierarchy (super),
|
||||||
|
// and 'dataset' is pub(crate), we can access it.
|
||||||
|
let mut dataset = table.dataset.get_mut().await?;
|
||||||
|
|
||||||
|
// Perform the actual delete on the Lance dataset
|
||||||
|
dataset.delete(predicate).await?;
|
||||||
|
|
||||||
|
// Return the result with the new version
|
||||||
|
Ok(DeleteResult {
|
||||||
|
version: dataset.version().version,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::connect;
|
||||||
|
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
|
||||||
|
use arrow_schema::{DataType, Field, Schema};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::query::ExecutableQuery;
|
||||||
|
use futures::TryStreamExt;
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_simple() {
|
||||||
|
let conn = connect("memory://").execute().await.unwrap();
|
||||||
|
|
||||||
|
// 1. Create a table with values 0 to 9
|
||||||
|
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
|
||||||
|
let batch = RecordBatch::try_new(
|
||||||
|
schema.clone(),
|
||||||
|
vec![Arc::new(Int32Array::from_iter_values(0..10))],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let table = conn
|
||||||
|
.create_table(
|
||||||
|
"test_delete",
|
||||||
|
RecordBatchIterator::new(vec![Ok(batch)], schema),
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// 2. Verify initial state
|
||||||
|
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||||
|
|
||||||
|
// 3. Execute Delete (removes values > 5)
|
||||||
|
table.delete("i > 5").await.unwrap();
|
||||||
|
|
||||||
|
// 4. Verify results
|
||||||
|
assert_eq!(table.count_rows(None).await.unwrap(), 6); // 0, 1, 2, 3, 4, 5 remain
|
||||||
|
|
||||||
|
// 5. Verify specific data consistency
|
||||||
|
let batches = table
|
||||||
|
.query()
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.try_collect::<Vec<_>>()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let batch = &batches[0];
|
||||||
|
let array = batch
|
||||||
|
.column(0)
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<Int32Array>()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Ensure no value > 5 exists
|
||||||
|
for val in array.iter() {
|
||||||
|
assert!(val.unwrap() <= 5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rows_removed_schema_same() {
|
||||||
|
let conn = connect("memory://").execute().await.unwrap();
|
||||||
|
let batch = record_batch!(
|
||||||
|
("id", Int32, [1, 2, 3, 4, 5]),
|
||||||
|
("name", Utf8, ["a", "b", "c", "d", "e"])
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let original_schema = batch.schema();
|
||||||
|
|
||||||
|
let table = conn
|
||||||
|
.create_table(
|
||||||
|
"test_delete_all",
|
||||||
|
RecordBatchIterator::new(vec![Ok(batch)], original_schema.clone()),
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
table.delete("true").await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(table.count_rows(None).await.unwrap(), 0);
|
||||||
|
|
||||||
|
let current_schema = table.schema().await.unwrap();
|
||||||
|
//check if the original schema is the same as current
|
||||||
|
assert_eq!(current_schema, original_schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_false_increments_version() {
|
||||||
|
let conn = connect("memory://").execute().await.unwrap();
|
||||||
|
|
||||||
|
// Create a table with 5 rows
|
||||||
|
let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5])).unwrap();
|
||||||
|
|
||||||
|
let schema = batch.schema();
|
||||||
|
|
||||||
|
let table = conn
|
||||||
|
.create_table(
|
||||||
|
"test_delete_noop",
|
||||||
|
RecordBatchIterator::new(vec![Ok(batch)], schema),
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Capture the initial state (Rows = 5, Version = 1)
|
||||||
|
let initial_rows = table.count_rows(None).await.unwrap();
|
||||||
|
let initial_version = table.version().await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(initial_rows, 5);
|
||||||
|
table.delete("false").await.unwrap();
|
||||||
|
|
||||||
|
// Rows should still be 5
|
||||||
|
let current_rows = table.count_rows(None).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
current_rows, initial_rows,
|
||||||
|
"Data should not change when predicate is false"
|
||||||
|
);
|
||||||
|
|
||||||
|
// version check
|
||||||
|
let current_version = table.version().await.unwrap();
|
||||||
|
assert!(
|
||||||
|
current_version > initial_version,
|
||||||
|
"Table version must increment after delete operation"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user