Compare commits

..

5 Commits

Author SHA1 Message Date
Wyatt Alt
bc76671b62 add nodejs test 2026-02-28 22:18:37 -08:00
Wyatt Alt
bc814e1f66 lint js docs 2026-02-28 21:50:40 -08:00
Wyatt Alt
82a6e5a196 doctest update 2026-02-28 20:25:29 -08:00
Wyatt Alt
b3100f0e7c add test 2026-02-28 19:26:00 -08:00
Wyatt Alt
8ec60e3c9d feat: add num_deleted to delete result 2026-02-28 19:26:00 -08:00
95 changed files with 660 additions and 1136 deletions

View File

@@ -29,7 +29,6 @@ runs:
if: ${{ inputs.arm-build == 'false' }}
uses: PyO3/maturin-action@v1
with:
maturin-version: "1.12.4"
command: build
working-directory: python
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
@@ -45,7 +44,6 @@ runs:
if: ${{ inputs.arm-build == 'true' }}
uses: PyO3/maturin-action@v1
with:
maturin-version: "1.12.4"
command: build
working-directory: python
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"

View File

@@ -20,7 +20,6 @@ runs:
uses: PyO3/maturin-action@v1
with:
command: build
maturin-version: "1.12.4"
# TODO: pass through interpreter
args: ${{ inputs.args }}
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"

View File

@@ -25,7 +25,6 @@ runs:
uses: PyO3/maturin-action@v1
with:
command: build
maturin-version: "1.12.4"
args: ${{ inputs.args }}
docker-options: "-e PIP_EXTRA_INDEX_URL='https://pypi.fury.io/lance-format/ https://pypi.fury.io/lancedb/'"
working-directory: python

View File

@@ -356,8 +356,7 @@ jobs:
if [[ $DRY_RUN == "true" ]]; then
ARGS="$ARGS --dry-run"
fi
VERSION=$(node -p "require('./package.json').version")
if [[ $VERSION == *-* ]]; then
if [[ $GITHUB_REF =~ refs/tags/v(.*)-beta.* ]]; then
ARGS="$ARGS --tag preview"
fi
npm publish $ARGS

View File

@@ -10,10 +10,6 @@ on:
- python/**
- rust/**
- .github/workflows/python.yml
- .github/workflows/build_linux_wheel/**
- .github/workflows/build_mac_wheel/**
- .github/workflows/build_windows_wheel/**
- .github/workflows/run_tests/**
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

112
Cargo.lock generated
View File

@@ -3088,8 +3088,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-array",
"rand 0.9.2",
@@ -4260,8 +4260,8 @@ dependencies = [
[[package]]
name = "lance"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-arith",
@@ -4315,7 +4315,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"snafu 0.9.0",
"snafu",
"tantivy",
"tokio",
"tokio-stream",
@@ -4327,8 +4327,8 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4338,7 +4338,6 @@ dependencies = [
"arrow-schema",
"arrow-select",
"bytes",
"futures",
"getrandom 0.2.16",
"half",
"jsonb",
@@ -4348,8 +4347,8 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrayref",
"paste",
@@ -4358,8 +4357,8 @@ dependencies = [
[[package]]
name = "lance-core"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4385,7 +4384,7 @@ dependencies = [
"rand 0.9.2",
"roaring",
"serde_json",
"snafu 0.9.0",
"snafu",
"tempfile",
"tokio",
"tokio-stream",
@@ -4396,8 +4395,8 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-array",
@@ -4420,15 +4419,15 @@ dependencies = [
"pin-project",
"prost",
"prost-build",
"snafu 0.9.0",
"snafu",
"tokio",
"tracing",
]
[[package]]
name = "lance-datagen"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-array",
@@ -4446,8 +4445,8 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4474,7 +4473,7 @@ dependencies = [
"prost-build",
"prost-types",
"rand 0.9.2",
"snafu 0.9.0",
"snafu",
"strum",
"tokio",
"tracing",
@@ -4484,8 +4483,8 @@ dependencies = [
[[package]]
name = "lance-file"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4510,15 +4509,15 @@ dependencies = [
"prost",
"prost-build",
"prost-types",
"snafu 0.9.0",
"snafu",
"tokio",
"tracing",
]
[[package]]
name = "lance-index"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-arith",
@@ -4570,7 +4569,7 @@ dependencies = [
"serde",
"serde_json",
"smallvec",
"snafu 0.9.0",
"snafu",
"tantivy",
"tempfile",
"tokio",
@@ -4581,8 +4580,8 @@ dependencies = [
[[package]]
name = "lance-io"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-arith",
@@ -4614,7 +4613,7 @@ dependencies = [
"prost",
"rand 0.9.2",
"serde",
"snafu 0.9.0",
"snafu",
"tempfile",
"tokio",
"tracing",
@@ -4623,8 +4622,8 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4640,21 +4639,21 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"async-trait",
"bytes",
"lance-core",
"lance-namespace-reqwest-client",
"snafu 0.9.0",
"snafu",
]
[[package]]
name = "lance-namespace-impls"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-ipc",
@@ -4676,7 +4675,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"snafu 0.9.0",
"snafu",
"tokio",
"tower",
"tower-http 0.5.2",
@@ -4698,8 +4697,8 @@ dependencies = [
[[package]]
name = "lance-table"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow",
"arrow-array",
@@ -4729,7 +4728,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"snafu 0.9.0",
"snafu",
"tokio",
"tracing",
"url",
@@ -4738,8 +4737,8 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "3.0.0-rc.3"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.3#de393a26a068dd297929ca7d798e43dc31c57337"
version = "3.0.0-rc.2"
source = "git+https://github.com/lance-format/lance.git?tag=v3.0.0-rc.2#3fb3e705b8a25ab1bb0fc9e1e0158e8a13356181"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4820,7 +4819,7 @@ dependencies = [
"serde",
"serde_json",
"serde_with",
"snafu 0.8.9",
"snafu",
"tempfile",
"test-log",
"tokenizers",
@@ -4866,7 +4865,7 @@ dependencies = [
"pyo3",
"pyo3-async-runtimes",
"pyo3-build-config",
"snafu 0.8.9",
"snafu",
"tokio",
]
@@ -7778,16 +7777,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e84b3f4eacbf3a1ce05eac6763b4d629d60cbc94d632e4092c54ade71f1e1a2"
dependencies = [
"snafu-derive 0.8.9",
]
[[package]]
name = "snafu"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1d4bced6a69f90b2056c03dcff2c4737f98d6fb9e0853493996e1d253ca29c6"
dependencies = [
"snafu-derive 0.9.0",
"snafu-derive",
]
[[package]]
@@ -7802,18 +7792,6 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "snafu-derive"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "socket2"
version = "0.5.10"

View File

@@ -5,7 +5,7 @@ exclude = ["python"]
resolver = "2"
[workspace.package]
edition = "2024"
edition = "2021"
authors = ["LanceDB Devs <dev@lancedb.com>"]
license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=3.0.0-rc.3", default-features = false, "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=3.0.0-rc.3", default-features = false, "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=3.0.0-rc.3", default-features = false, "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=3.0.0-rc.3", "tag" = "v3.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=3.0.0-rc.2", default-features = false, "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=3.0.0-rc.2", default-features = false, "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=3.0.0-rc.2", default-features = false, "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=3.0.0-rc.2", "tag" = "v3.0.0-rc.2", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "57.2", optional = false }

View File

@@ -450,6 +450,31 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
},
);
describe("delete", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
table = await conn.createTable("delete_test", [
{ id: 1, value: "a" },
{ id: 2, value: "b" },
{ id: 3, value: "c" },
{ id: 4, value: "d" },
{ id: 5, value: "e" },
]);
});
afterEach(() => tmpDir.removeCallback());
test("returns num_deleted_rows", async () => {
const result = await table.delete("id > 3");
expect(result.numDeletedRows).toBe(2);
expect(result.version).toBe(2);
expect(await table.countRows()).toBe(3);
});
});
describe("merge insert", () => {
let tmpDir: tmp.DirResult;
let table: Table;
@@ -1697,65 +1722,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
expect(results2[0].text).toBe(data[1].text);
});
test("full text search fast search", async () => {
const db = await connect(tmpDir.name);
const data = [{ text: "hello world", vector: [0.1, 0.2, 0.3], id: 1 }];
const table = await db.createTable("test", data);
await table.createIndex("text", {
config: Index.fts(),
});
// Insert unindexed data after creating the index.
await table.add([{ text: "xyz", vector: [0.4, 0.5, 0.6], id: 2 }]);
const withFlatSearch = await table
.search("xyz", "fts")
.limit(10)
.toArray();
expect(withFlatSearch.length).toBeGreaterThan(0);
const fastSearchResults = await table
.search("xyz", "fts")
.fastSearch()
.limit(10)
.toArray();
expect(fastSearchResults.length).toBe(0);
const nearestToTextFastSearch = await table
.query()
.nearestToText("xyz")
.fastSearch()
.limit(10)
.toArray();
expect(nearestToTextFastSearch.length).toBe(0);
// fastSearch should be chainable with other methods.
const chainedFastSearch = await table
.search("xyz", "fts")
.fastSearch()
.select(["text"])
.limit(5)
.toArray();
expect(chainedFastSearch.length).toBe(0);
await table.optimize();
const indexedFastSearch = await table
.search("xyz", "fts")
.fastSearch()
.limit(10)
.toArray();
expect(indexedFastSearch.length).toBeGreaterThan(0);
const indexedNearestToTextFastSearch = await table
.query()
.nearestToText("xyz")
.fastSearch()
.limit(10)
.toArray();
expect(indexedNearestToTextFastSearch.length).toBeGreaterThan(0);
});
test("prewarm full text search index", async () => {
const db = await connect(tmpDir.name);
const data = [

View File

@@ -8,10 +8,10 @@ use lancedb::database::{CreateTableMode, Database};
use napi::bindgen_prelude::*;
use napi_derive::*;
use crate::ConnectionOptions;
use crate::error::NapiErrorExt;
use crate::header::JsHeaderProvider;
use crate::table::Table;
use crate::ConnectionOptions;
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema};

View File

@@ -3,12 +3,12 @@
use std::sync::Mutex;
use lancedb::index::Index as LanceDbIndex;
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
use lancedb::index::vector::{
IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder,
IvfRqIndexBuilder,
};
use lancedb::index::Index as LanceDbIndex;
use napi_derive::napi;
use crate::util::parse_distance_type;

View File

@@ -17,8 +17,8 @@ use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use crate::error::NapiErrorExt;
use crate::error::convert_error;
use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::rerankers::RerankHybridCallbackArgs;
use crate::rerankers::Reranker;
@@ -551,12 +551,15 @@ fn parse_fts_query(query: Object) -> napi::Result<FullTextSearchQuery> {
}
};
let mut query = FullTextSearchQuery::new_query(query);
if let Some(cols) = columns
&& !cols.is_empty()
{
query = query.with_columns(&cols).map_err(|e| {
napi::Error::from_reason(format!("Failed to set full text search columns: {}", e))
})?;
if let Some(cols) = columns {
if !cols.is_empty() {
query = query.with_columns(&cols).map_err(|e| {
napi::Error::from_reason(format!(
"Failed to set full text search columns: {}",
e
))
})?;
}
}
Ok(query)
} else {

View File

@@ -145,7 +145,6 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
id_delimiter: config.id_delimiter,
tls_config: config.tls_config.map(Into::into),
header_provider: None, // the header provider is set separately later
mem_wal_enabled: None, // mem_wal is set per-operation in merge_insert
}
}
}

View File

@@ -95,7 +95,7 @@ impl napi::bindgen_prelude::FromNapiValue for Session {
napi_val: napi::sys::napi_value,
) -> napi::Result<Self> {
let object: napi::bindgen_prelude::ClassInstance<Self> =
unsafe { napi::bindgen_prelude::ClassInstance::from_napi_value(env, napi_val)? };
napi::bindgen_prelude::ClassInstance::from_napi_value(env, napi_val)?;
Ok((*object).clone())
}
}

View File

@@ -34,7 +34,6 @@ class LanceMergeInsertBuilder(object):
self._when_not_matched_by_source_condition = None
self._timeout = None
self._use_index = True
self._mem_wal = False
def when_matched_update_all(
self, *, where: Optional[str] = None
@@ -97,47 +96,6 @@ class LanceMergeInsertBuilder(object):
self._use_index = use_index
return self
def mem_wal(self, enabled: bool = True) -> LanceMergeInsertBuilder:
"""
Enable MemWAL (Memory Write-Ahead Log) mode for this merge insert operation.
When enabled, the merge insert will route data through a memory node service
that buffers writes before flushing to storage. This is only supported for
remote (LanceDB Cloud) tables.
**Important:** MemWAL only supports the upsert pattern. You must use:
- `when_matched_update_all()` (without a filter condition)
- `when_not_matched_insert_all()`
MemWAL does NOT support:
- `when_matched_update_all(where=...)` with a filter condition
- `when_not_matched_by_source_delete()`
Parameters
----------
enabled: bool
Whether to enable MemWAL mode. Defaults to `True`.
Raises
------
NotImplementedError
If used on a native (local) table, as MemWAL is only supported for
remote tables.
ValueError
If the merge insert pattern is not supported by MemWAL.
Examples
--------
>>> # Correct usage with MemWAL
>>> table.merge_insert(["id"]) \\
... .when_matched_update_all() \\
... .when_not_matched_insert_all() \\
... .mem_wal() \\
... .execute(new_data)
"""
self._mem_wal = enabled
return self
def execute(
self,
new_data: DATA,

View File

@@ -606,7 +606,6 @@ class LanceQueryBuilder(ABC):
query,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
fast_search=fast_search,
)
if isinstance(query, list):
@@ -1457,14 +1456,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
query: str | FullTextQuery,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
fast_search: bool = None,
):
super().__init__(table)
self._query = query
self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
self._fast_search = fast_search
self._fast_search = None
if isinstance(fts_columns, str):
fts_columns = [fts_columns]
self._fts_columns = fts_columns

View File

@@ -218,6 +218,8 @@ class RemoteTable(Table):
train: bool = True,
):
"""Create an index on the table.
Currently, the only parameters that matter are
the metric and the vector column name.
Parameters
----------
@@ -248,6 +250,11 @@ class RemoteTable(Table):
>>> table.create_index("l2", "vector") # doctest: +SKIP
"""
if num_sub_vectors is not None:
logging.warning(
"num_sub_vectors is not supported on LanceDB cloud."
"This parameter will be tuned automatically."
)
if accelerator is not None:
logging.warning(
"GPU accelerator is not yet supported on LanceDB cloud."

View File

@@ -4181,7 +4181,6 @@ class AsyncTable:
when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
timeout=merge._timeout,
use_index=merge._use_index,
mem_wal=merge._mem_wal,
),
)

View File

@@ -324,16 +324,6 @@ def _(value: list):
return "[" + ", ".join(map(value_to_sql, value)) + "]"
@value_to_sql.register(dict)
def _(value: dict):
# https://datafusion.apache.org/user-guide/sql/scalar_functions.html#named-struct
return (
"named_struct("
+ ", ".join(f"'{k}', {value_to_sql(v)}" for k, v in value.items())
+ ")"
)
@value_to_sql.register(np.ndarray)
def _(value: np.ndarray):
return value_to_sql(value.tolist())

View File

@@ -27,7 +27,6 @@ from lancedb.query import (
PhraseQuery,
BooleanQuery,
Occur,
LanceFtsQueryBuilder,
)
import numpy as np
import pyarrow as pa
@@ -921,10 +920,6 @@ def test_fts_fast_search(table):
assert query.limit == 5
assert query.columns == ["text"]
# fast_search should be enabled by keyword argument too
query = LanceFtsQueryBuilder(table, "xyz", fast_search=True).to_query_object()
assert query.fast_search is True
# Verify it executes without error and skips unindexed data
results = table.search("xyz", query_type="fts").fast_search().limit(5).to_list()
assert len(results) == 0

View File

@@ -121,32 +121,6 @@ def test_value_to_sql_string(tmp_path):
assert table.to_pandas().query("search == @value")["replace"].item() == value
def test_value_to_sql_dict():
# Simple flat struct
assert value_to_sql({"a": 1, "b": "hello"}) == "named_struct('a', 1, 'b', 'hello')"
# Nested struct
assert (
value_to_sql({"outer": {"inner": 1}})
== "named_struct('outer', named_struct('inner', 1))"
)
# List inside struct
assert value_to_sql({"a": [1, 2]}) == "named_struct('a', [1, 2])"
# Mixed types
assert (
value_to_sql({"name": "test", "count": 42, "rate": 3.14, "active": True})
== "named_struct('name', 'test', 'count', 42, 'rate', 3.14, 'active', TRUE)"
)
# Null value inside struct
assert value_to_sql({"a": None}) == "named_struct('a', NULL)"
# Empty dict
assert value_to_sql({}) == "named_struct()"
def test_append_vector_columns():
registry = EmbeddingFunctionRegistry.get_instance()
registry.register("test")(MockTextEmbeddingFunction)

View File

@@ -10,7 +10,7 @@ use arrow::{
use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{
Bound, Py, PyAny, PyRef, PyResult, Python, exceptions::PyStopAsyncIteration, pyclass, pymethods,
exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;

View File

@@ -9,10 +9,10 @@ use lancedb::{
database::{CreateTableMode, Database, ReadConsistency},
};
use pyo3::{
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods,
types::{PyDict, PyDictMethods},
Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -506,7 +506,6 @@ pub struct PyClientConfig {
id_delimiter: Option<String>,
tls_config: Option<PyClientTlsConfig>,
header_provider: Option<Py<PyAny>>,
mem_wal_enabled: Option<bool>,
}
#[derive(FromPyObject)]
@@ -591,7 +590,6 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
id_delimiter: value.id_delimiter,
tls_config: value.tls_config.map(Into::into),
header_provider,
mem_wal_enabled: value.mem_wal_enabled,
}
}
}

View File

@@ -2,10 +2,10 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use pyo3::{
PyErr, PyResult, Python,
exceptions::{PyIOError, PyNotImplementedError, PyOSError, PyRuntimeError, PyValueError},
intern,
types::{PyAnyMethods, PyNone},
PyErr, PyResult, Python,
};
use lancedb::error::Error as LanceError;

View File

@@ -3,17 +3,17 @@
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder};
use lancedb::index::{
Index as LanceDbIndex,
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
Index as LanceDbIndex,
};
use pyo3::IntoPyObject;
use pyo3::types::PyStringMethods;
use pyo3::IntoPyObject;
use pyo3::{
Bound, FromPyObject, PyAny, PyResult, Python,
exceptions::{PyKeyError, PyValueError},
intern, pyclass, pymethods,
types::PyAnyMethods,
Bound, FromPyObject, PyAny, PyResult, Python,
};
use crate::util::parse_distance_type;
@@ -41,12 +41,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
let inner_opts = FtsIndexBuilder::default()
.base_tokenizer(params.base_tokenizer)
.language(&params.language)
.map_err(|_| {
PyValueError::new_err(format!(
"LanceDB does not support the requested language: '{}'",
params.language
))
})?
.map_err(|_| PyValueError::new_err(format!("LanceDB does not support the requested language: '{}'", params.language)))?
.with_position(params.with_position)
.lower_case(params.lower_case)
.max_token_length(params.max_token_length)
@@ -57,7 +52,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
.ngram_max_length(params.ngram_max_length)
.ngram_prefix_only(params.prefix_only);
Ok(LanceDbIndex::FTS(inner_opts))
}
},
"IvfFlat" => {
let params = source.extract::<IvfFlatParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -69,11 +64,10 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
}
if let Some(target_partition_size) = params.target_partition_size {
ivf_flat_builder =
ivf_flat_builder.target_partition_size(target_partition_size);
ivf_flat_builder = ivf_flat_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
}
},
"IvfPq" => {
let params = source.extract::<IvfPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -92,7 +86,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
}
},
"IvfSq" => {
let params = source.extract::<IvfSqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -107,7 +101,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
ivf_sq_builder = ivf_sq_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfSq(ivf_sq_builder))
}
},
"IvfRq" => {
let params = source.extract::<IvfRqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -123,7 +117,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
ivf_rq_builder = ivf_rq_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfRq(ivf_rq_builder))
}
},
"HnswPq" => {
let params = source.extract::<IvfHnswPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -144,7 +138,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
Ok(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))
}
},
"HnswSq" => {
let params = source.extract::<IvfHnswSqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -161,7 +155,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
hnsw_sq_builder = hnsw_sq_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
}
},
not_supported => Err(PyValueError::new_err(format!(
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfSq, IvfHnswPq, or IvfHnswSq",
not_supported

View File

@@ -2,14 +2,14 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow::RecordBatchStream;
use connection::{Connection, connect};
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
use pyo3::{
Bound, PyResult, Python, pymodule,
pymodule,
types::{PyModule, PyModuleMethods},
wrap_pyfunction,
wrap_pyfunction, Bound, PyResult, Python,
};
use query::{FTSQuery, HybridQuery, Query, VectorQuery};
use session::Session;

View File

@@ -16,10 +16,10 @@ use lancedb::{
query::Select,
};
use pyo3::{
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
exceptions::PyRuntimeError,
pyclass, pymethods,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
Bound, PyAny, PyRef, PyRefMut, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;

View File

@@ -4,9 +4,9 @@
use std::sync::Arc;
use std::time::Duration;
use arrow::array::make_array;
use arrow::array::Array;
use arrow::array::ArrayData;
use arrow::array::make_array;
use arrow::pyarrow::FromPyArrow;
use arrow::pyarrow::IntoPyArrow;
use arrow::pyarrow::ToPyArrow;
@@ -22,23 +22,23 @@ use lancedb::query::{
VectorQuery as LanceDbVectorQuery,
};
use lancedb::table::AnyQuery;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
use pyo3::pyfunction;
use pyo3::pymethods;
use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString};
use pyo3::Bound;
use pyo3::IntoPyObject;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
use pyo3::pyfunction;
use pyo3::pymethods;
use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString};
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
use pyo3::{PyErr, pyclass};
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
use pyo3::{
exceptions::{PyNotImplementedError, PyValueError},
intern,
};
use pyo3::{pyclass, PyErr};
use pyo3_async_runtimes::tokio::future_into_py;
use crate::util::parse_distance_type;

View File

@@ -4,7 +4,7 @@
use std::sync::Arc;
use lancedb::{ObjectStoreRegistry, Session as LanceSession};
use pyo3::{PyResult, pyclass, pymethods};
use pyo3::{pyclass, pymethods, PyResult};
/// A session for managing caches and object stores across LanceDB operations.
///

View File

@@ -66,10 +66,13 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
.inner
.bind(py)
.call_method0("fetch_storage_options")
.map_err(|e| lance_core::Error::io_source(Box::new(std::io::Error::other(format!(
"Failed to call fetch_storage_options: {}",
e
)))))?;
.map_err(|e| lance_core::Error::IO {
source: Box::new(std::io::Error::other(format!(
"Failed to call fetch_storage_options: {}",
e
))),
location: snafu::location!(),
})?;
// If result is None, return None
if result.is_none() {
@@ -78,19 +81,26 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
// Extract the result dict - should be a flat Map<String, String>
let result_dict = result.downcast::<PyDict>().map_err(|_| {
lance_core::Error::invalid_input(
"fetch_storage_options() must return a dict of string key-value pairs or None",
)
lance_core::Error::InvalidInput {
source: "fetch_storage_options() must return None or a dict of string key-value pairs".into(),
location: snafu::location!(),
}
})?;
// Convert all entries to HashMap<String, String>
let mut storage_options = HashMap::new();
for (key, value) in result_dict.iter() {
let key_str: String = key.extract().map_err(|e| {
lance_core::Error::invalid_input(format!("Storage option key must be a string: {}", e))
lance_core::Error::InvalidInput {
source: format!("Storage option key must be a string: {}", e).into(),
location: snafu::location!(),
}
})?;
let value_str: String = value.extract().map_err(|e| {
lance_core::Error::invalid_input(format!("Storage option value must be a string: {}", e))
lance_core::Error::InvalidInput {
source: format!("Storage option value must be a string: {}", e).into(),
location: snafu::location!(),
}
})?;
storage_options.insert(key_str, value_str);
}
@@ -99,10 +109,13 @@ impl StorageOptionsProvider for PyStorageOptionsProviderWrapper {
})
})
.await
.map_err(|e| lance_core::Error::io_source(Box::new(std::io::Error::other(format!(
"Task join error: {}",
e
)))))?
.map_err(|e| lance_core::Error::IO {
source: Box::new(std::io::Error::other(format!(
"Task join error: {}",
e
))),
location: snafu::location!(),
})?
}
fn provider_id(&self) -> String {

View File

@@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
use crate::{
connection::Connection,
error::PythonErrorExt,
index::{IndexConfig, extract_index_params},
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
table::scannable::PyScannable,
};
@@ -19,10 +19,10 @@ use lancedb::table::{
Table as LanceDbTable,
};
use pyo3::{
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -542,7 +542,7 @@ impl Table {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let versions = inner.list_versions().await.infer_error()?;
Python::attach(|py| {
let versions_as_dict = Python::attach(|py| {
versions
.iter()
.map(|v| {
@@ -559,7 +559,9 @@ impl Table {
Ok(dict.unbind())
})
.collect::<PyResult<Vec<_>>>()
})
});
versions_as_dict
})
}
@@ -710,9 +712,6 @@ impl Table {
if let Some(use_index) = parameters.use_index {
builder.use_index(use_index);
}
if let Some(mem_wal) = parameters.mem_wal {
builder.mem_wal(mem_wal);
}
future_into_py(self_.py(), async move {
let res = builder.execute(Box::new(batches)).await.infer_error()?;
@@ -873,7 +872,6 @@ pub struct MergeInsertParams {
when_not_matched_by_source_condition: Option<String>,
timeout: Option<std::time::Duration>,
use_index: Option<bool>,
mem_wal: Option<bool>,
}
#[pyclass]

View File

@@ -10,11 +10,11 @@ use arrow::{
};
use futures::StreamExt;
use lancedb::{
Error,
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
data::scannable::Scannable,
Error,
};
use pyo3::{FromPyObject, Py, PyAny, Python, types::PyAnyMethods};
use pyo3::{types::PyAnyMethods, FromPyObject, Py, PyAny, Python};
/// Adapter that implements Scannable for a Python reader factory callable.
///
@@ -99,15 +99,15 @@ impl Scannable for PyScannable {
// Channel closed. Check if the task panicked — a panic
// drops the sender without sending an error, so without
// this check we'd silently return a truncated stream.
if let Some(handle) = join_handle
&& let Err(join_err) = handle.await
{
return Some((
Err(Error::Runtime {
message: format!("Reader task panicked: {}", join_err),
}),
(rx, None),
));
if let Some(handle) = join_handle {
if let Err(join_err) = handle.await {
return Some((
Err(Error::Runtime {
message: format!("Reader task panicked: {}", join_err),
}),
(rx, None),
));
}
}
None
}

View File

@@ -5,9 +5,8 @@ use std::sync::Mutex;
use lancedb::DistanceType;
use pyo3::{
PyResult,
exceptions::{PyRuntimeError, PyValueError},
pyfunction,
pyfunction, PyResult,
};
/// A wrapper around a rust builder

View File

@@ -9,9 +9,10 @@ use aws_config::Region;
use aws_sdk_bedrockruntime::Client;
use futures::StreamExt;
use lancedb::{
Result, connect,
embeddings::{EmbeddingDefinition, EmbeddingFunction, bedrock::BedrockEmbeddingFunction},
connect,
embeddings::{bedrock::BedrockEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]

View File

@@ -10,10 +10,10 @@ use futures::TryStreamExt;
use lance_index::scalar::FullTextSearchQuery;
use lancedb::connection::Connection;
use lancedb::index::Index;
use lancedb::index::scalar::FtsIndexBuilder;
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{Result, Table, connect};
use lancedb::{connect, Result, Table};
use rand::random;
#[tokio::main]
@@ -46,21 +46,19 @@ fn create_some_records() -> Result<Box<dyn arrow_array::RecordBatchReader + Send
.collect::<Vec<_>>();
let n_terms = 3;
let batches = RecordBatchIterator::new(
vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter_values((0..TOTAL).map(|_| {
(0..n_terms)
.map(|_| words[random::<u32>() as usize % words.len()])
.collect::<Vec<_>>()
.join(" ")
}))),
],
)
.unwrap(),
]
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter_values((0..TOTAL).map(|_| {
(0..n_terms)
.map(|_| words[random::<u32>() as usize % words.len()])
.collect::<Vec<_>>()
.join(" ")
}))),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),

View File

@@ -5,15 +5,16 @@ use arrow_array::{RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance_index::scalar::FullTextSearchQuery;
use lancedb::index::Index;
use lancedb::index::scalar::FtsIndexBuilder;
use lancedb::index::Index;
use lancedb::{
Result, Table, connect,
connect,
embeddings::{
EmbeddingDefinition, EmbeddingFunction,
sentence_transformers::SentenceTransformersEmbeddings,
sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition,
EmbeddingFunction,
},
query::{QueryBase, QueryExecutionOptions},
Result, Table,
};
use std::{iter::once, sync::Arc};

View File

@@ -14,10 +14,10 @@ use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::connection::Connection;
use lancedb::index::Index;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{DistanceType, Result, Table, connect};
use lancedb::{connect, DistanceType, Result, Table};
#[tokio::main]
async fn main() -> Result<()> {
@@ -51,21 +51,19 @@ fn create_some_records() -> Result<Box<dyn arrow_array::RecordBatchReader + Send
// Create a RecordBatch stream.
let batches = RecordBatchIterator::new(
vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
DIM as i32,
),
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
DIM as i32,
),
],
)
.unwrap(),
]
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),

View File

@@ -8,9 +8,10 @@ use std::{iter::once, sync::Arc};
use arrow_array::{RecordBatch, StringArray};
use futures::StreamExt;
use lancedb::{
Result, connect,
embeddings::{EmbeddingDefinition, EmbeddingFunction, openai::OpenAIEmbeddingFunction},
connect,
embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
query::{ExecutableQuery, QueryBase},
Result,
};
// --8<-- [end:imports]

View File

@@ -7,12 +7,13 @@ use arrow_array::{RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
Result, connect,
connect,
embeddings::{
EmbeddingDefinition, EmbeddingFunction,
sentence_transformers::SentenceTransformersEmbeddings,
sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition,
EmbeddingFunction,
},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]

View File

@@ -14,7 +14,7 @@ use futures::TryStreamExt;
use lancedb::connection::Connection;
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{Result, Table as LanceDbTable, connect};
use lancedb::{connect, Result, Table as LanceDbTable};
#[tokio::main]
async fn main() -> Result<()> {

View File

@@ -12,7 +12,7 @@ use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
use crate::{Error, error::Result};
use crate::{error::Result, Error};
/// An iterator of batches that also has a schema
pub trait RecordBatchReader: Iterator<Item = Result<arrow_array::RecordBatch>> {

View File

@@ -17,7 +17,6 @@ use lance_namespace::models::{
#[cfg(feature = "aws")]
use object_store::aws::AwsCredential;
use crate::Table;
use crate::connection::create_table::CreateTableBuilder;
use crate::data::scannable::Scannable;
use crate::database::listing::ListingDatabase;
@@ -32,6 +31,7 @@ use crate::remote::{
client::ClientConfig,
db::{OPT_REMOTE_API_KEY, OPT_REMOTE_HOST_OVERRIDE, OPT_REMOTE_REGION},
};
use crate::Table;
use lance::io::ObjectStoreParams;
pub use lance_encoding::version::LanceFileVersion;
#[cfg(feature = "remote")]
@@ -566,11 +566,8 @@ pub struct ConnectBuilder {
}
#[cfg(feature = "remote")]
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 3] = [
("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name"),
("AZURE_CLIENT_ID", "azure_client_id"),
("AZURE_TENANT_ID", "azure_tenant_id"),
];
const ENV_VARS_TO_STORAGE_OPTS: [(&str, &str); 1] =
[("AZURE_STORAGE_ACCOUNT_NAME", "azure_storage_account_name")];
impl ConnectBuilder {
/// Create a new [`ConnectOptions`] with the given database URI.
@@ -761,10 +758,10 @@ impl ConnectBuilder {
options: &mut HashMap<String, String>,
) {
for (env_key, opt_key) in env_var_to_remote_storage_option {
if let Ok(env_value) = std::env::var(env_key)
&& !options.contains_key(*opt_key)
{
options.insert((*opt_key).to_string(), env_value);
if let Ok(env_value) = std::env::var(env_key) {
if !options.contains_key(*opt_key) {
options.insert((*opt_key).to_string(), env_value);
}
}
}
}
@@ -784,19 +781,13 @@ impl ConnectBuilder {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
// Propagate mem_wal_enabled from options to client_config
let mut client_config = self.request.client_config;
if options.mem_wal_enabled.is_some() {
client_config.mem_wal_enabled = options.mem_wal_enabled;
}
let storage_options = StorageOptions(options.storage_options.clone());
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
&self.request.uri,
&api_key,
&region,
options.host_override,
client_config,
self.request.client_config,
storage_options.into(),
)?);
Ok(Connection {
@@ -1020,13 +1011,14 @@ mod tests {
#[cfg(feature = "remote")]
#[test]
fn test_apply_env_defaults() {
let env_key = "PATH";
let env_val = std::env::var(env_key).expect("PATH should be set in test environment");
let env_key = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_KEY";
let env_val = "TEST_APPLY_ENV_DEFAULTS_ENVIRONMENT_VARIABLE_ENV_VAL";
let opts_key = "test_apply_env_defaults_environment_variable_opts_key";
std::env::set_var(env_key, env_val);
let mut options = HashMap::new();
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);
assert_eq!(Some(&env_val), options.get(opts_key));
assert_eq!(Some(&env_val.to_string()), options.get(opts_key));
options.insert(opts_key.to_string(), "EXPLICIT-VALUE".to_string());
ConnectBuilder::apply_env_defaults(&[(env_key, opts_key)], &mut options);

View File

@@ -6,12 +6,12 @@ use std::sync::Arc;
use lance_io::object_store::StorageOptionsProvider;
use crate::{
Error, Result, Table,
connection::{merge_storage_options, set_storage_options_provider},
data::scannable::{Scannable, WithEmbeddingsScannable},
database::{CreateTableMode, CreateTableRequest, Database},
embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
table::WriteOptions,
Error, Result, Table,
};
pub struct CreateTableBuilder {
@@ -167,7 +167,7 @@ impl CreateTableBuilder {
#[cfg(test)]
mod tests {
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, record_batch,
record_batch, Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator,
};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use futures::TryStreamExt;
@@ -380,12 +380,11 @@ mod tests {
.await
.unwrap();
let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
assert!(
db.create_empty_table("test", other_schema.clone())
.execute()
.await
.is_err()
); // TODO: assert what this error is
assert!(db
.create_empty_table("test", other_schema.clone())
.execute()
.await
.is_err()); // TODO: assert what this error is
let overwritten = db
.create_empty_table("test", other_schema.clone())
.mode(CreateTableMode::Overwrite)

View File

@@ -5,9 +5,9 @@ use std::collections::HashMap;
use arrow::compute::kernels::{aggregate::bool_and, length::length};
use arrow_array::{
Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
cast::AsArray,
types::{ArrowPrimitiveType, Int32Type, Int64Type},
Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
};
use arrow_ord::cmp::eq;
use arrow_schema::DataType;
@@ -78,7 +78,7 @@ pub fn infer_vector_columns(
_ => {
return Err(Error::Schema {
message: format!("Column {} is not a list", col_name),
});
})
}
} {
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
@@ -102,8 +102,8 @@ mod tests {
use super::*;
use arrow_array::{
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
types::{Float32Type, Float64Type},
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use std::{sync::Arc, vec};

View File

@@ -4,10 +4,10 @@
use std::{iter::repeat_with, sync::Arc};
use arrow_array::{
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
RecordBatchReader,
cast::AsArray,
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
RecordBatchReader,
};
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{ArrowError, DataType, Field, Schema};
@@ -184,7 +184,7 @@ mod tests {
use std::sync::Arc;
use arrow_array::{
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int8Array, Int32Array,
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::Field;

View File

@@ -13,16 +13,16 @@ use crate::arrow::{
SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream,
};
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, compute_embeddings_for_batch,
compute_output_schema,
compute_embeddings_for_batch, compute_output_schema, EmbeddingDefinition, EmbeddingFunction,
EmbeddingRegistry,
};
use crate::table::{ColumnDefinition, ColumnKind, TableDefinition};
use crate::{Error, Result};
use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, SchemaRef};
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::once;
use futures::StreamExt;
use lance_datafusion::utils::StreamingWriteSource;
pub trait Scannable: Send {

View File

@@ -19,12 +19,12 @@ use std::sync::Arc;
use std::time::Duration;
use lance::dataset::ReadParams;
use lance_namespace::LanceNamespace;
use lance_namespace::models::{
CreateNamespaceRequest, CreateNamespaceResponse, DescribeNamespaceRequest,
DescribeNamespaceResponse, DropNamespaceRequest, DropNamespaceResponse, ListNamespacesRequest,
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
};
use lance_namespace::LanceNamespace;
use crate::data::scannable::Scannable;
use crate::error::Result;

View File

@@ -8,7 +8,7 @@ use std::path::Path;
use std::{collections::HashMap, sync::Arc};
use lance::dataset::refs::Ref;
use lance::dataset::{ReadParams, WriteMode, builder::DatasetBuilder};
use lance::dataset::{builder::DatasetBuilder, ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use lance_datafusion::utils::StreamingWriteSource;
use lance_encoding::version::LanceFileVersion;
@@ -1097,11 +1097,11 @@ impl Database for ListingDatabase {
#[cfg(test)]
mod tests {
use super::*;
use crate::Table;
use crate::connection::ConnectRequest;
use crate::data::scannable::Scannable;
use crate::database::{CreateTableMode, CreateTableRequest};
use crate::table::WriteOptions;
use crate::Table;
use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::path::PathBuf;

View File

@@ -9,15 +9,16 @@ use std::sync::Arc;
use async_trait::async_trait;
use lance_io::object_store::{ObjectStoreParams, StorageOptionsAccessor};
use lance_namespace::{
LanceNamespace,
models::{
CreateNamespaceRequest, CreateNamespaceResponse, DeclareTableRequest,
DescribeNamespaceRequest, DescribeNamespaceResponse, DescribeTableRequest,
DropNamespaceRequest, DropNamespaceResponse, DropTableRequest, ListNamespacesRequest,
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
CreateEmptyTableRequest, CreateNamespaceRequest, CreateNamespaceResponse,
DeclareTableRequest, DescribeNamespaceRequest, DescribeNamespaceResponse,
DescribeTableRequest, DropNamespaceRequest, DropNamespaceResponse, DropTableRequest,
ListNamespacesRequest, ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
},
LanceNamespace,
};
use lance_namespace_impls::ConnectBuilder;
use log::warn;
use crate::database::ReadConsistency;
use crate::error::{Error, Result};
@@ -212,18 +213,63 @@ impl Database for LanceNamespaceDatabase {
..Default::default()
};
let (location, initial_storage_options) = {
let response = self.namespace.declare_table(declare_request).await?;
let loc = response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from declare_table response".to_string(),
})?;
// Use storage options from response, fall back to self.storage_options
let opts = response
.storage_options
.or_else(|| Some(self.storage_options.clone()))
.filter(|o| !o.is_empty());
(loc, opts)
};
let (location, initial_storage_options) =
match self.namespace.declare_table(declare_request).await {
Ok(response) => {
let loc = response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from declare_table response"
.to_string(),
})?;
// Use storage options from response, fall back to self.storage_options
let opts = response
.storage_options
.or_else(|| Some(self.storage_options.clone()))
.filter(|o| !o.is_empty());
(loc, opts)
}
Err(e) => {
// Check if the error is "not supported" and try create_empty_table as fallback
let err_str = e.to_string().to_lowercase();
if err_str.contains("not supported") || err_str.contains("not implemented") {
warn!(
"declare_table is not supported by the namespace client, \
falling back to deprecated create_empty_table. \
create_empty_table is deprecated and will be removed in Lance 3.0.0. \
Please upgrade your namespace client to support declare_table."
);
#[allow(deprecated)]
let create_empty_request = CreateEmptyTableRequest {
id: Some(table_id.clone()),
..Default::default()
};
#[allow(deprecated)]
let create_response = self
.namespace
.create_empty_table(create_empty_request)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to create empty table: {}", e),
})?;
let loc = create_response.location.ok_or_else(|| Error::Runtime {
message: "Table location is missing from create_empty_table response"
.to_string(),
})?;
// For deprecated path, use self.storage_options
let opts = if self.storage_options.is_empty() {
None
} else {
Some(self.storage_options.clone())
};
(loc, opts)
} else {
return Err(Error::Runtime {
message: format!("Failed to declare table: {}", e),
});
}
}
};
let write_params = if let Some(storage_opts) = initial_storage_options {
let mut params = request.write_options.lance_write_params.unwrap_or_default();

View File

@@ -11,16 +11,16 @@ use lance_core::ROW_ID;
use lance_datafusion::exec::SessionContextExt;
use crate::{
Error, Result, Table,
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
connect,
database::{CreateTableRequest, Database},
dataloader::permutation::{
shuffle::{Shuffler, ShufflerConfig},
split::{SPLIT_ID_COLUMN, SplitStrategy, Splitter},
util::{TemporaryDirectory, rename_column},
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase, Select},
Error, Result, Table,
};
pub const SRC_ROW_ID_COL: &str = "row_id";

View File

@@ -25,8 +25,8 @@ use futures::{StreamExt, TryStreamExt};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::io::RecordBatchStream;
use lance_arrow::RecordBatchExt;
use lance_core::ROW_ID;
use lance_core::error::LanceOptionExt;
use lance_core::ROW_ID;
use std::collections::HashMap;
use std::sync::Arc;
@@ -500,10 +500,10 @@ mod tests {
use rand::seq::SliceRandom;
use crate::{
Table,
arrow::SendableRecordBatchStream,
query::{ExecutableQuery, QueryBase},
test_utils::datagen::{LanceDbDatagenExt, virtual_table},
test_utils::datagen::{virtual_table, LanceDbDatagenExt},
Table,
};
use super::*;

View File

@@ -18,12 +18,12 @@ use lance_io::{
scheduler::{ScanScheduler, SchedulerConfig},
utils::CachedFileSize,
};
use rand::{Rng, RngCore, seq::SliceRandom};
use rand::{seq::SliceRandom, Rng, RngCore};
use crate::{
Error, Result,
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::permutation::util::{TemporaryDirectory, non_crypto_rng},
dataloader::permutation::util::{non_crypto_rng, TemporaryDirectory},
Error, Result,
};
#[derive(Debug, Clone)]
@@ -281,7 +281,7 @@ mod tests {
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
use rand::{SeedableRng, rngs::SmallRng};
use rand::{rngs::SmallRng, SeedableRng};
fn test_gen() -> BatchGeneratorBuilder {
lance_datagen::gen_batch()

View File

@@ -2,8 +2,8 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{
Arc,
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc,
};
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
@@ -15,13 +15,13 @@ use lance_arrow::SchemaExt;
use lance_core::ROW_ID;
use crate::{
Error, Result,
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
permutation::shuffle::{Shuffler, ShufflerConfig},
permutation::util::TemporaryDirectory,
},
query::{Query, QueryBase, Select},
Error, Result,
};
pub const SPLIT_ID_COLUMN: &str = "split_id";

View File

@@ -7,12 +7,12 @@ use arrow_array::RecordBatch;
use arrow_schema::{Fields, Schema};
use datafusion_execution::disk_manager::DiskManagerMode;
use futures::TryStreamExt;
use rand::{RngCore, SeedableRng, rngs::SmallRng};
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use tempfile::TempDir;
use crate::{
Error, Result,
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
Error, Result,
};
/// Directory to use for temporary files

View File

@@ -23,9 +23,9 @@ use arrow_schema::{DataType, Field, SchemaBuilder, SchemaRef};
use serde::{Deserialize, Serialize};
use crate::{
Error,
error::Result,
table::{ColumnDefinition, ColumnKind, TableDefinition},
Error,
};
/// Trait for embedding functions

View File

@@ -8,7 +8,7 @@ use arrow::array::{AsArray, Float32Builder};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use serde_json::{Value, json};
use serde_json::{json, Value};
use super::EmbeddingFunction;
use crate::{Error, Result};

View File

@@ -8,9 +8,9 @@ use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_openai::{
Client,
config::OpenAIConfig,
types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
Client,
};
use tokio::{runtime::Handle, task};

View File

@@ -7,7 +7,7 @@ use super::EmbeddingFunction;
use arrow::{
array::{AsArray, PrimitiveBuilder},
datatypes::{
ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, Int64Type, UInt8Type, UInt32Type,
ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, Int64Type, UInt32Type, UInt8Type,
},
};
use arrow_array::{Array, FixedSizeListArray, PrimitiveArray};
@@ -16,8 +16,8 @@ use arrow_schema::DataType;
use candle_core::{CpuStorage, Device, Layout, Storage, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, DTYPE};
use hf_hub::{Repo, RepoType, api::sync::Api};
use tokenizers::{PaddingParams, tokenizer::Tokenizer};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{tokenizer::Tokenizer, PaddingParams};
/// Compute embeddings using huggingface sentence-transformers.
pub struct SentenceTransformersEmbeddingsBuilder {
@@ -230,7 +230,7 @@ impl SentenceTransformersEmbeddings {
Storage::Cpu(CpuStorage::BF16(_)) => {
return Err(crate::Error::Runtime {
message: "unsupported data type".to_string(),
});
})
}
_ => unreachable!("we already moved the tensor to the CPU device"),
};
@@ -298,12 +298,12 @@ impl SentenceTransformersEmbeddings {
DataType::Utf8View => {
return Err(crate::Error::Runtime {
message: "Utf8View not yet implemented".to_string(),
});
})
}
_ => {
return Err(crate::Error::Runtime {
message: "invalid type".to_string(),
});
})
}
};

View File

@@ -24,7 +24,7 @@ pub use sql::expr_to_sql_string;
use std::sync::Arc;
use arrow_schema::DataType;
use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast};
use datafusion_expr::{expr_fn::cast, Expr, ScalarUDF};
use datafusion_functions::string::expr_fn as string_expr_fn;
pub use datafusion_expr::{col, lit};

View File

@@ -9,7 +9,7 @@ use std::time::Duration;
use vector::IvfFlatIndexBuilder;
use crate::index::vector::IvfRqIndexBuilder;
use crate::{DistanceType, Error, Result, table::BaseTable};
use crate::{table::BaseTable, DistanceType, Error, Result};
use self::{
scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},

View File

@@ -27,7 +27,7 @@
///
/// The btree index does not currently have any parameters though parameters such as the
/// block size may be added in the future.
#[derive(Default, Debug, Clone, serde::Serialize)]
#[derive(Default, Debug, Clone)]
pub struct BTreeIndexBuilder {}
impl BTreeIndexBuilder {}
@@ -39,7 +39,7 @@ impl BTreeIndexBuilder {}
/// This index works best for low-cardinality (i.e., less than 1000 unique values) columns,
/// where the number of unique values is small.
/// The bitmap stores a list of row ids where the value is present.
#[derive(Debug, Clone, Default, serde::Serialize)]
#[derive(Debug, Clone, Default)]
pub struct BitmapIndexBuilder {}
/// Builder for LabelList index.
@@ -48,10 +48,10 @@ pub struct BitmapIndexBuilder {}
/// support queries with `array_contains_all` and `array_contains_any`
/// using an underlying bitmap index.
///
#[derive(Debug, Clone, Default, serde::Serialize)]
#[derive(Debug, Clone, Default)]
pub struct LabelListIndexBuilder {}
pub use lance_index::scalar::inverted::query::*;
pub use lance_index::scalar::FullTextSearchQuery;
pub use lance_index::scalar::InvertedIndexParams as FtsIndexBuilder;
pub use lance_index::scalar::InvertedIndexParams;
pub use lance_index::scalar::inverted::query::*;

View File

@@ -7,7 +7,6 @@
//! Vector indices are only supported on fixed-size-list (tensor) columns of floating point
//! values
use lance::table::format::{IndexMetadata, Manifest};
use serde::Serialize;
use crate::DistanceType;
@@ -182,17 +181,14 @@ macro_rules! impl_hnsw_params_setter {
/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create.
///
/// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub struct IvfFlatIndexBuilder {
#[serde(rename = "metric_type")]
pub(crate) distance_type: DistanceType,
// IVF
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) target_partition_size: Option<u32>,
}
@@ -217,17 +213,14 @@ impl IvfFlatIndexBuilder {
///
/// This index compresses vectors using scalar quantization and groups them into IVF partitions.
/// It offers a balance between search performance and storage footprint.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub struct IvfSqIndexBuilder {
#[serde(rename = "metric_type")]
pub(crate) distance_type: DistanceType,
// IVF
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) target_partition_size: Option<u32>,
}
@@ -268,23 +261,18 @@ impl IvfSqIndexBuilder {
///
/// Note that training an IVF PQ index on a large dataset is a slow operation and
/// currently is also a memory intensive operation.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub struct IvfPqIndexBuilder {
#[serde(rename = "metric_type")]
pub(crate) distance_type: DistanceType,
// IVF
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) target_partition_size: Option<u32>,
// PQ
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_sub_vectors: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_bits: Option<u32>,
}
@@ -335,18 +323,14 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
///
/// Note that training an IVF RQ index on a large dataset is a slow operation and
/// currently is also a memory intensive operation.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub struct IvfRqIndexBuilder {
// IVF
#[serde(rename = "metric_type")]
pub(crate) distance_type: DistanceType,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_partitions: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_bits: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) target_partition_size: Option<u32>,
}
@@ -381,16 +365,13 @@ impl IvfRqIndexBuilder {
/// quickly find the closest vectors to a query vector.
///
/// The PQ (product quantizer) is used to compress the vectors as the same as IVF PQ.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub struct IvfHnswPqIndexBuilder {
// IVF
#[serde(rename = "metric_type")]
pub(crate) distance_type: DistanceType,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) target_partition_size: Option<u32>,
// HNSW
@@ -398,9 +379,7 @@ pub struct IvfHnswPqIndexBuilder {
pub(crate) ef_construction: u32,
// PQ
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_sub_vectors: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_bits: Option<u32>,
}
@@ -436,16 +415,13 @@ impl IvfHnswPqIndexBuilder {
///
/// The SQ (scalar quantizer) is used to compress the vectors,
/// each vector is mapped to a 8-bit integer vector, 4x compression ratio for float32 vector.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub struct IvfHnswSqIndexBuilder {
// IVF
#[serde(rename = "metric_type")]
pub(crate) distance_type: DistanceType,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) target_partition_size: Option<u32>,
// HNSW

View File

@@ -1,9 +1,9 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::Error;
use crate::error::Result;
use crate::table::BaseTable;
use crate::Error;
use log::debug;
use std::time::{Duration, Instant};
use tokio::time::sleep;

View File

@@ -5,11 +5,11 @@
use std::{fmt::Formatter, sync::Arc};
use futures::{TryFutureExt, stream::BoxStream};
use futures::{stream::BoxStream, TryFutureExt};
use lance::io::WrappingObjectStore;
use object_store::{
Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, UploadPart, path::Path,
path::Path, Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, UploadPart,
};
use async_trait::async_trait;

View File

@@ -10,9 +10,8 @@ use bytes::Bytes;
use futures::stream::BoxStream;
use lance::io::WrappingObjectStore;
use object_store::{
GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
path::Path,
};
#[derive(Debug, Default)]

View File

@@ -5,26 +5,26 @@ use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::{DataType, SchemaRef};
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use futures::{FutureExt, TryFutureExt, TryStreamExt, stream, try_join};
use futures::{stream, try_join, FutureExt, TryFutureExt, TryStreamExt};
use half::f16;
use lance::dataset::{ROW_ID, scanner::DatasetRecordBatchStream};
use lance::dataset::{scanner::DatasetRecordBatchStream, ROW_ID};
use lance_arrow::RecordBatchExt;
use lance_datafusion::exec::execute_plan;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::DistanceType;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::DistanceType;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
mod hybrid;
@@ -161,11 +161,10 @@ impl IntoQueryVector for &dyn Array {
if data_type != self.data_type() {
Err(Error::InvalidInput {
message: format!(
"failed to create query vector, the input data type was {:?} but the expected data type was {:?}",
self.data_type(),
data_type
),
})
"failed to create query vector, the input data type was {:?} but the expected data type was {:?}",
self.data_type(),
data_type
)})
} else {
let data = self.to_data();
Ok(make_array(data))
@@ -187,7 +186,7 @@ impl IntoQueryVector for &[f16] {
DataType::Float32 => {
let arr: Vec<f32> = self.iter().map(|x| f32::from(*x)).collect();
Ok(Arc::new(Float32Array::from(arr)))
}
},
DataType::Float64 => {
let arr: Vec<f64> = self.iter().map(|x| f64::from(*x)).collect();
Ok(Arc::new(Float64Array::from(arr)))
@@ -195,7 +194,8 @@ impl IntoQueryVector for &[f16] {
_ => Err(Error::InvalidInput {
message: format!(
"failed to create query vector, the input data type was &[f16] but the embedding model \"{}\" expected data type {:?}",
embedding_model_label, data_type
embedding_model_label,
data_type
),
}),
}
@@ -216,7 +216,7 @@ impl IntoQueryVector for &[f32] {
DataType::Float32 => {
let arr: Vec<f32> = self.to_vec();
Ok(Arc::new(Float32Array::from(arr)))
}
},
DataType::Float64 => {
let arr: Vec<f64> = self.iter().map(|x| *x as f64).collect();
Ok(Arc::new(Float64Array::from(arr)))
@@ -224,7 +224,8 @@ impl IntoQueryVector for &[f32] {
_ => Err(Error::InvalidInput {
message: format!(
"failed to create query vector, the input data type was &[f32] but the embedding model \"{}\" expected data type {:?}",
embedding_model_label, data_type
embedding_model_label,
data_type
),
}),
}
@@ -238,25 +239,26 @@ impl IntoQueryVector for &[f64] {
embedding_model_label: &str,
) -> Result<Arc<dyn Array>> {
match data_type {
DataType::Float16 => {
let arr: Vec<f16> = self.iter().map(|x| f16::from_f64(*x)).collect();
Ok(Arc::new(Float16Array::from(arr)))
DataType::Float16 => {
let arr: Vec<f16> = self.iter().map(|x| f16::from_f64(*x)).collect();
Ok(Arc::new(Float16Array::from(arr)))
}
DataType::Float32 => {
let arr: Vec<f32> = self.iter().map(|x| *x as f32).collect();
Ok(Arc::new(Float32Array::from(arr)))
},
DataType::Float64 => {
let arr: Vec<f64> = self.to_vec();
Ok(Arc::new(Float64Array::from(arr)))
}
_ => Err(Error::InvalidInput {
message: format!(
"failed to create query vector, the input data type was &[f64] but the embedding model \"{}\" expected data type {:?}",
embedding_model_label,
data_type
),
}),
}
DataType::Float32 => {
let arr: Vec<f32> = self.iter().map(|x| *x as f32).collect();
Ok(Arc::new(Float32Array::from(arr)))
}
DataType::Float64 => {
let arr: Vec<f64> = self.to_vec();
Ok(Arc::new(Float64Array::from(arr)))
}
_ => Err(Error::InvalidInput {
message: format!(
"failed to create query vector, the input data type was &[f64] but the embedding model \"{}\" expected data type {:?}",
embedding_model_label, data_type
),
}),
}
}
}
@@ -1009,13 +1011,13 @@ impl VectorQuery {
message: "minimum_nprobes must be greater than 0".to_string(),
});
}
if let Some(maximum_nprobes) = self.request.maximum_nprobes
&& minimum_nprobes > maximum_nprobes
{
return Err(Error::InvalidInput {
message: "minimum_nprobes must be less than or equal to maximum_nprobes"
.to_string(),
});
if let Some(maximum_nprobes) = self.request.maximum_nprobes {
if minimum_nprobes > maximum_nprobes {
return Err(Error::InvalidInput {
message: "minimum_nprobes must be less than or equal to maximum_nprobes"
.to_string(),
});
}
}
self.request.minimum_nprobes = minimum_nprobes;
Ok(self)
@@ -1405,8 +1407,8 @@ mod tests {
use super::*;
use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
use arrow_array::{
FixedSizeListArray, Float32Array, Int32Array, RecordBatch, StringArray, cast::AsArray,
types::Float32Type,
cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::{StreamExt, TryStreamExt};
@@ -1414,7 +1416,7 @@ mod tests {
use rand::seq::IndexedRandom;
use tempfile::tempdir;
use crate::{Table, connect, database::CreateTableMode, index::Index};
use crate::{connect, database::CreateTableMode, index::Index, Table};
#[tokio::test]
async fn test_setters_getters() {
@@ -1752,13 +1754,11 @@ mod tests {
.limit(1)
.execute()
.await;
assert!(
error_result
.err()
.unwrap()
.to_string()
.contains("No vector column found to match with the query vector dimension: 3")
);
assert!(error_result
.err()
.unwrap()
.to_string()
.contains("No vector column found to match with the query vector dimension: 3"));
}
#[tokio::test]
@@ -2010,7 +2010,7 @@ mod tests {
// Sample 1 - 3 tokens for each string value
let tokens = ["a", "b", "c", "d", "e"];
use rand::{Rng, rng};
use rand::{rng, Rng};
let mut rng = rng();
let text: StringArray = (0..nrows)

View File

@@ -5,7 +5,7 @@ use arrow::compute::{
kernels::numeric::{div, sub},
max, min,
};
use arrow_array::{Float32Array, RecordBatch, cast::downcast_array};
use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use lance::dataset::ROW_ID;
use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
@@ -253,10 +253,7 @@ mod test {
let result = rank(batch.clone(), "bad_col", None);
match result {
Err(Error::InvalidInput { message }) => {
assert_eq!(
"expected column bad_col not found in rank. found columns [\"name\", \"score\"]",
message
);
assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
}
_ => {
panic!("expected invalid input error, received {:?}", result)

View File

@@ -4,8 +4,8 @@
use http::HeaderName;
use log::debug;
use reqwest::{
Body, Request, RequestBuilder, Response,
header::{HeaderMap, HeaderValue},
Body, Request, RequestBuilder, Response,
};
use std::{collections::HashMap, future::Future, str::FromStr, sync::Arc, time::Duration};
@@ -14,7 +14,6 @@ use crate::remote::db::RemoteOptions;
use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
const MEM_WAL_ENABLED_HEADER: HeaderName = HeaderName::from_static("x-lancedb-mem-wal-enabled");
/// Configuration for TLS/mTLS settings.
#[derive(Clone, Debug, Default)]
@@ -53,10 +52,6 @@ pub struct ClientConfig {
pub tls_config: Option<TlsConfig>,
/// Provider for custom headers to be added to each request
pub header_provider: Option<Arc<dyn HeaderProvider>>,
/// Enable MemWAL write path for streaming writes.
/// When true, write operations will use the MemWAL architecture
/// for high-performance streaming writes.
pub mem_wal_enabled: Option<bool>,
}
impl std::fmt::Debug for ClientConfig {
@@ -72,7 +67,6 @@ impl std::fmt::Debug for ClientConfig {
"header_provider",
&self.header_provider.as_ref().map(|_| "Some(...)"),
)
.field("mem_wal_enabled", &self.mem_wal_enabled)
.finish()
}
}
@@ -87,7 +81,6 @@ impl Default for ClientConfig {
id_delimiter: None,
tls_config: None,
header_provider: None,
mem_wal_enabled: None,
}
}
}
@@ -453,23 +446,13 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
})?,
);
}
// Map azure storage options to x-azure-* headers.
// The option key uses underscores (e.g. "azure_client_id") while the
// header uses hyphens (e.g. "x-azure-client-id").
let azure_opts: [(&str, &str); 3] = [
("azure_storage_account_name", "x-azure-storage-account-name"),
("azure_client_id", "x-azure-client-id"),
("azure_tenant_id", "x-azure-tenant-id"),
];
for (opt_key, header_name) in azure_opts {
if let Some(v) = options.0.get(opt_key) {
headers.insert(
HeaderName::from_static(header_name),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii value '{}' for option '{}'", v, opt_key),
})?,
);
}
if let Some(v) = options.0.get("azure_storage_account_name") {
headers.insert(
HeaderName::from_static("x-azure-storage-account-name"),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name),
})?,
);
}
for (key, value) in &config.extra_headers {
@@ -484,11 +467,6 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
);
}
// Add MemWAL header if enabled
if let Some(true) = config.mem_wal_enabled {
headers.insert(MEM_WAL_ENABLED_HEADER, HeaderValue::from_static("true"));
}
Ok(headers)
}
@@ -672,13 +650,14 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
pub fn extract_request_id(&self, request: &mut Request) -> String {
// Set a request id.
// TODO: allow the user to supply this, through middleware?
if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
self.set_request_id(request, &request_id);
request_id
}
};
request_id
}
/// Set the request ID header
@@ -1097,34 +1076,4 @@ mod tests {
_ => panic!("Expected Runtime error"),
}
}
#[test]
fn test_default_headers_azure_opts() {
let mut opts = HashMap::new();
opts.insert(
"azure_storage_account_name".to_string(),
"myaccount".to_string(),
);
opts.insert("azure_client_id".to_string(), "my-client-id".to_string());
opts.insert("azure_tenant_id".to_string(), "my-tenant-id".to_string());
let remote_opts = RemoteOptions::new(opts);
let headers = RestfulLanceDbClient::<Sender>::default_headers(
"test-key",
"us-east-1",
"testdb",
false,
&remote_opts,
None,
&ClientConfig::default(),
)
.unwrap();
assert_eq!(
headers.get("x-azure-storage-account-name").unwrap(),
"myaccount"
);
assert_eq!(headers.get("x-azure-client-id").unwrap(), "my-client-id");
assert_eq!(headers.get("x-azure-tenant-id").unwrap(), "my-tenant-id");
}
}

View File

@@ -16,7 +16,6 @@ use lance_namespace::models::{
ListNamespacesResponse, ListTablesRequest, ListTablesResponse,
};
use crate::Error;
use crate::database::{
CloneTableRequest, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
OpenTableRequest, ReadConsistency, TableNamesRequest,
@@ -24,11 +23,12 @@ use crate::database::{
use crate::error::Result;
use crate::remote::util::stream_as_body;
use crate::table::BaseTable;
use crate::Error;
use super::ARROW_STREAM_CONTENT_TYPE;
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::parse_server_version;
use super::ARROW_STREAM_CONTENT_TYPE;
// Request structure for the remote clone table API
#[derive(serde::Serialize)]
@@ -78,7 +78,6 @@ pub const OPT_REMOTE_PREFIX: &str = "remote_database_";
pub const OPT_REMOTE_API_KEY: &str = "remote_database_api_key";
pub const OPT_REMOTE_REGION: &str = "remote_database_region";
pub const OPT_REMOTE_HOST_OVERRIDE: &str = "remote_database_host_override";
pub const OPT_REMOTE_MEM_WAL_ENABLED: &str = "remote_database_mem_wal_enabled";
// TODO: add support for configuring client config via key/value options
#[derive(Clone, Debug, Default)]
@@ -99,12 +98,6 @@ pub struct RemoteDatabaseOptions {
/// These options are only used for LanceDB Enterprise and only a subset of options
/// are supported.
pub storage_options: HashMap<String, String>,
/// Enable MemWAL write path for high-performance streaming writes.
///
/// When enabled, write operations (insert, merge_insert, etc.) will use
/// the MemWAL architecture which buffers writes in memory and Write-Ahead Log
/// before asynchronously merging to the base table.
pub mem_wal_enabled: Option<bool>,
}
impl RemoteDatabaseOptions {
@@ -116,9 +109,6 @@ impl RemoteDatabaseOptions {
let api_key = map.get(OPT_REMOTE_API_KEY).cloned();
let region = map.get(OPT_REMOTE_REGION).cloned();
let host_override = map.get(OPT_REMOTE_HOST_OVERRIDE).cloned();
let mem_wal_enabled = map
.get(OPT_REMOTE_MEM_WAL_ENABLED)
.map(|v| v.to_lowercase() == "true");
let storage_options = map
.iter()
.filter(|(key, _)| !key.starts_with(OPT_REMOTE_PREFIX))
@@ -129,7 +119,6 @@ impl RemoteDatabaseOptions {
region,
host_override,
storage_options,
mem_wal_enabled,
})
}
}
@@ -148,12 +137,6 @@ impl DatabaseOptions for RemoteDatabaseOptions {
if let Some(host_override) = &self.host_override {
map.insert(OPT_REMOTE_HOST_OVERRIDE.to_string(), host_override.clone());
}
if let Some(mem_wal_enabled) = &self.mem_wal_enabled {
map.insert(
OPT_REMOTE_MEM_WAL_ENABLED.to_string(),
mem_wal_enabled.to_string(),
);
}
}
}
@@ -198,20 +181,6 @@ impl RemoteDatabaseOptionsBuilder {
self.options.host_override = Some(host_override);
self
}
/// Enable MemWAL write path for high-performance streaming writes.
///
/// When enabled, write operations will use the MemWAL architecture
/// which buffers writes in memory and Write-Ahead Log before
/// asynchronously merging to the base table.
///
/// # Arguments
///
/// * `enabled` - Whether to enable MemWAL writes
pub fn mem_wal_enabled(mut self, enabled: bool) -> Self {
self.options.mem_wal_enabled = Some(enabled);
self
}
}
#[derive(Debug)]
@@ -280,9 +249,9 @@ impl RemoteDatabase {
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
use crate::remote::ClientConfig;
use crate::remote::client::test_utils::MockSender;
use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config};
use crate::remote::ClientConfig;
impl RemoteDatabase<MockSender> {
pub fn new_mock<F, T>(handler: F) -> Self
@@ -808,12 +777,7 @@ impl RemoteOptions {
impl From<StorageOptions> for RemoteOptions {
fn from(options: StorageOptions) -> Self {
let supported_opts = vec![
"account_name",
"azure_storage_account_name",
"azure_client_id",
"azure_tenant_id",
];
let supported_opts = vec!["account_name", "azure_storage_account_name"];
let mut filtered = HashMap::new();
for opt in supported_opts {
if let Some(v) = options.0.get(opt) {
@@ -835,9 +799,9 @@ mod tests {
use crate::connection::ConnectBuilder;
use crate::{
Connection, Error,
database::CreateTableMode,
remote::{ARROW_STREAM_CONTENT_TYPE, ClientConfig, HeaderProvider, JSON_CONTENT_TYPE},
remote::{ClientConfig, HeaderProvider, ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
Connection, Error,
};
#[test]

View File

@@ -1,8 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::Error;
use crate::remote::RetryConfig;
use crate::Error;
use log::debug;
use std::time::Duration;

View File

@@ -6,14 +6,15 @@ pub mod insert;
use self::insert::RemoteInsertExec;
use crate::expr::expr_to_sql_string;
use super::ARROW_STREAM_CONTENT_TYPE;
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
use crate::index::waiter::wait_for_index;
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::index::waiter::wait_for_index;
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::query::create_multi_vector_plan;
use crate::table::AddColumnsResult;
use crate::table::AddResult;
use crate::table::AlterColumnsResult;
@@ -22,20 +23,19 @@ use crate::table::DropColumnsResult;
use crate::table::MergeResult;
use crate::table::Tags;
use crate::table::UpdateResult;
use crate::table::query::create_multi_vector_plan;
use crate::table::{AnyQuery, Filter, TableStatistics};
use crate::utils::background_cache::BackgroundCache;
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error};
use crate::{
error::Result,
index::{IndexBuilder, IndexConfig},
query::QueryExecutionOptions,
table::{
AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats, TableDefinition, UpdateBuilder,
merge::MergeInsertBuilder,
merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats,
TableDefinition, UpdateBuilder,
},
};
use crate::{DistanceType, Error};
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
@@ -50,7 +50,7 @@ use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::refs::TagContents;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
use lance_datafusion::exec::{OneShotExec, execute_plan};
use lance_datafusion::exec::{execute_plan, OneShotExec};
use reqwest::{RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use serde_json::Number;
@@ -62,7 +62,6 @@ use std::time::Duration;
use tokio::sync::RwLock;
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
const MEM_WAL_ENABLED_HEADER: HeaderName = HeaderName::from_static("x-lancedb-mem-wal-enabled");
const METRIC_TYPE_KEY: &str = "metric_type";
const INDEX_TYPE_KEY: &str = "index_type";
const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(30);
@@ -613,8 +612,8 @@ impl<S: HttpSend> RemoteTable<S> {
message: format!(
"Cannot mutate table reference fixed at version {}. Call checkout_latest() to get a mutable table reference.",
version
),
}),
)
})
}
}
@@ -698,10 +697,10 @@ impl<S: HttpSend> RemoteTable<S> {
Error::Retry { status_code, .. } => *status_code,
_ => None,
};
if let Some(status_code) = status_code
&& Self::should_invalidate_cache_for_status(status_code)
{
self.invalidate_schema_cache();
if let Some(status_code) = status_code {
if Self::should_invalidate_cache_for_status(status_code) {
self.invalidate_schema_cache();
}
}
}
}
@@ -784,9 +783,9 @@ impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
use crate::remote::ClientConfig;
use crate::remote::client::test_utils::client_with_handler;
use crate::remote::client::test_utils::{MockSender, client_with_handler_and_config};
use crate::remote::client::test_utils::{client_with_handler_and_config, MockSender};
use crate::remote::ClientConfig;
impl RemoteTable<MockSender> {
pub fn new_mock<F, T>(name: String, handler: F, version: Option<semver::Version>) -> Self
@@ -1252,13 +1251,13 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
0 => {
return Err(Error::InvalidInput {
message: "No columns specified".into(),
});
})
}
1 => index.columns.pop().unwrap(),
_ => {
return Err(Error::NotSupported {
message: "Indices over multiple columns not yet supported".into(),
});
})
}
};
let mut body = serde_json::json!({
@@ -1277,24 +1276,73 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
);
}
fn to_json(params: &impl serde::Serialize) -> crate::Result<serde_json::Value> {
serde_json::to_value(params).map_err(|e| Error::InvalidInput {
message: format!("failed to serialize index params {:?}", e),
})
}
// Map each Index variant to its wire type name and serializable params.
// Auto is special-cased since it needs schema inspection.
let (index_type_str, params) = match &index.index {
Index::IvfFlat(p) => ("IVF_FLAT", Some(to_json(p)?)),
Index::IvfPq(p) => ("IVF_PQ", Some(to_json(p)?)),
Index::IvfSq(p) => ("IVF_SQ", Some(to_json(p)?)),
Index::IvfHnswSq(p) => ("IVF_HNSW_SQ", Some(to_json(p)?)),
Index::IvfRq(p) => ("IVF_RQ", Some(to_json(p)?)),
Index::BTree(p) => ("BTREE", Some(to_json(p)?)),
Index::Bitmap(p) => ("BITMAP", Some(to_json(p)?)),
Index::LabelList(p) => ("LABEL_LIST", Some(to_json(p)?)),
Index::FTS(p) => ("FTS", Some(to_json(p)?)),
match index.index {
// TODO: Should we pass the actual index parameters? SaaS does not
// yet support them.
Index::IvfFlat(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_FLAT".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfPq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
if let Some(num_bits) = index.num_bits {
body["num_bits"] = serde_json::Value::Number(num_bits.into());
}
}
Index::IvfSq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_SQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfHnswSq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfRq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
if let Some(num_bits) = index.num_bits {
body["num_bits"] = serde_json::Value::Number(num_bits.into());
}
}
Index::BTree(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
}
Index::Bitmap(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BITMAP".to_string());
}
Index::LabelList(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("LABEL_LIST".to_string());
}
Index::FTS(fts) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("FTS".to_string());
let params = serde_json::to_value(&fts).map_err(|e| Error::InvalidInput {
message: format!("failed to serialize FTS index params {:?}", e),
})?;
for (key, value) in params.as_object().unwrap() {
body[key] = value.clone();
}
}
Index::Auto => {
let schema = self.schema().await?;
let field = schema
@@ -1303,11 +1351,11 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
message: format!("Column {} not found in schema", column),
})?;
if supported_vector_data_type(field.data_type()) {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(DistanceType::L2.to_string().to_lowercase());
("IVF_PQ", None)
} else if supported_btree_data_type(field.data_type()) {
("BTREE", None)
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
} else {
return Err(Error::NotSupported {
message: format!(
@@ -1321,17 +1369,10 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
_ => {
return Err(Error::NotSupported {
message: "Index type not supported".into(),
});
})
}
};
body[INDEX_TYPE_KEY] = index_type_str.into();
if let Some(params) = params {
for (key, value) in params.as_object().expect("params should be a JSON object") {
body[key] = value.clone();
}
}
let request = request.json(&body);
let (request_id, response) = self.send(request, true).await?;
@@ -1360,7 +1401,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
self.check_mutable().await?;
let timeout = params.timeout;
let mem_wal = params.mem_wal;
let query = MergeInsertRequest::try_from(params)?;
let mut request = self
@@ -1376,10 +1416,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
}
}
if mem_wal {
request = request.header(MEM_WAL_ENABLED_HEADER, "true");
}
let (request_id, response) = self.send_streaming(request, new_data, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -1777,8 +1813,8 @@ impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, pin::Pin};
@@ -1787,27 +1823,25 @@ mod tests {
use crate::table::AddDataMode;
use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type};
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, record_batch};
use arrow_array::{record_batch, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use chrono::{DateTime, Utc};
use futures::{StreamExt, TryFutureExt, future::BoxFuture};
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
use lance_index::scalar::inverted::query::MatchQuery;
use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams};
use reqwest::Body;
use rstest::rstest;
use serde_json::json;
use crate::index::vector::{
IvfFlatIndexBuilder, IvfHnswSqIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder,
};
use crate::remote::JSON_CONTENT_TYPE;
use crate::index::vector::{IvfFlatIndexBuilder, IvfHnswSqIndexBuilder};
use crate::remote::db::DEFAULT_SERVER_VERSION;
use crate::remote::JSON_CONTENT_TYPE;
use crate::utils::background_cache::clock;
use crate::{
DistanceType, Error, Table,
index::{Index, IndexStatistics, IndexType, vector::IvfPqIndexBuilder},
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase},
remote::ARROW_FILE_CONTENT_TYPE,
DistanceType, Error, Table,
};
#[tokio::test]
@@ -2036,13 +2070,11 @@ mod tests {
.unwrap(),
"/v1/table/my_table/insert/" => {
assert_eq!(request.method(), "POST");
assert!(
request
.url()
.query_pairs()
.filter(|(k, _)| k == "mode")
.all(|(_, v)| v == "append")
);
assert!(request
.url()
.query_pairs()
.filter(|(k, _)| k == "mode")
.all(|(_, v)| v == "append"));
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
@@ -2963,8 +2995,6 @@ mod tests {
"IVF_FLAT",
json!({
"metric_type": "hamming",
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
),
@@ -2973,8 +3003,6 @@ mod tests {
json!({
"metric_type": "hamming",
"num_partitions": 128,
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfFlat(
IvfFlatIndexBuilder::default()
@@ -2986,8 +3014,6 @@ mod tests {
"IVF_PQ",
json!({
"metric_type": "l2",
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfPq(Default::default()),
),
@@ -2997,8 +3023,6 @@ mod tests {
"metric_type": "cosine",
"num_partitions": 128,
"num_bits": 4,
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfPq(
IvfPqIndexBuilder::default()
@@ -3007,29 +3031,10 @@ mod tests {
.num_bits(4),
),
),
(
"IVF_PQ",
json!({
"metric_type": "l2",
"num_sub_vectors": 16,
"sample_rate": 512,
"max_iterations": 100,
}),
Index::IvfPq(
IvfPqIndexBuilder::default()
.num_sub_vectors(16)
.sample_rate(512)
.max_iterations(100),
),
),
(
"IVF_HNSW_SQ",
json!({
"metric_type": "l2",
"sample_rate": 256,
"max_iterations": 50,
"m": 20,
"ef_construction": 300,
}),
Index::IvfHnswSq(Default::default()),
),
@@ -3038,65 +3043,11 @@ mod tests {
json!({
"metric_type": "l2",
"num_partitions": 128,
"sample_rate": 256,
"max_iterations": 50,
"m": 40,
"ef_construction": 500,
}),
Index::IvfHnswSq(
IvfHnswSqIndexBuilder::default()
.distance_type(DistanceType::L2)
.num_partitions(128)
.num_edges(40)
.ef_construction(500),
),
),
(
"IVF_SQ",
json!({
"metric_type": "l2",
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfSq(Default::default()),
),
(
"IVF_SQ",
json!({
"metric_type": "cosine",
"num_partitions": 64,
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfSq(
IvfSqIndexBuilder::default()
.distance_type(DistanceType::Cosine)
.num_partitions(64),
),
),
(
"IVF_RQ",
json!({
"metric_type": "l2",
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfRq(Default::default()),
),
(
"IVF_RQ",
json!({
"metric_type": "cosine",
"num_partitions": 64,
"num_bits": 8,
"sample_rate": 256,
"max_iterations": 50,
}),
Index::IvfRq(
IvfRqIndexBuilder::default()
.distance_type(DistanceType::Cosine)
.num_partitions(64)
.num_bits(8),
.num_partitions(128),
),
),
// HNSW_PQ isn't yet supported on SaaS
@@ -3600,7 +3551,7 @@ mod tests {
}
fn _make_table_with_indices(unindexed_rows: usize) -> Table {
Table::new_with_handler("my_table", move |request| {
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
let response_body = match request.url().path() {
@@ -3644,7 +3595,8 @@ mod tests {
let body = serde_json::to_string(&response_body).unwrap();
let status = if body == "null" { 404 } else { 200 };
http::Response::builder().status(status).body(body).unwrap()
})
});
table
}
#[tokio::test]
@@ -3855,8 +3807,8 @@ mod tests {
#[tokio::test]
async fn test_uri_caching() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();

View File

@@ -16,12 +16,12 @@ use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, Plan
use futures::StreamExt;
use http::header::CONTENT_TYPE;
use crate::Error;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender};
use crate::remote::table::RemoteTable;
use crate::table::AddResult;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::table::datafusion::insert::COUNT_SCHEMA;
use crate::table::AddResult;
use crate::Error;
/// ExecutionPlan for inserting data into a remote LanceDB table.
///
@@ -309,12 +309,12 @@ mod tests {
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::Table;
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
use crate::table::datafusion::BaseTableAdapter;
use crate::Table;
fn schema_json() -> &'static str {
r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"#

View File

@@ -5,7 +5,7 @@ use arrow_ipc::CompressionType;
use futures::{Stream, StreamExt};
use reqwest::Response;
use crate::{Result, arrow::SendableRecordBatchStream};
use crate::{arrow::SendableRecordBatchStream, Result};
use super::db::ServerVersion;

View File

@@ -14,7 +14,7 @@ use async_trait::async_trait;
use lance::dataset::ROW_ID;
use crate::error::{Error, Result};
use crate::rerankers::{RELEVANCE_SCORE, Reranker};
use crate::rerankers::{Reranker, RELEVANCE_SCORE};
/// Reranks the results using Reciprocal Rank Fusion(RRF) algorithm based
/// on the scores of vector and FTS search.

View File

@@ -8,29 +8,29 @@ use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_execution::TaskContext;
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use futures::StreamExt;
use datafusion_physical_plan::ExecutionPlan;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use lance::dataset::builder::DatasetBuilder;
pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
pub use lance::dataset::Version;
use lance::dataset::WriteMode;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::{InsertBuilder, WriteParams};
use lance::index::vector::VectorIndexParams;
use lance::index::vector::utils::infer_vector_dim;
use lance::index::vector::VectorIndexParams;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance_datafusion::utils::StreamingWriteSource;
use lance_index::DatasetIndexExt;
use lance_index::IndexType;
use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
use lance_index::vector::bq::RQBuildParams;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
use lance_index::vector::sq::builder::SQBuildParams;
use lance_index::DatasetIndexExt;
use lance_index::IndexType;
use lance_io::object_store::{LanceNamespaceStorageOptionsProvider, StorageOptionsAccessor};
pub use query::AnyQuery;
@@ -43,19 +43,19 @@ use std::format;
use std::path::Path;
use std::sync::Arc;
use crate::data::scannable::{PeekedScannable, Scannable, estimate_write_partitions};
use crate::data::scannable::{estimate_write_partitions, PeekedScannable, Scannable};
use crate::database::Database;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::IndexStatistics;
use crate::index::vector::VectorIndex;
use crate::index::{Index, IndexBuilder, vector::suggested_num_sub_vectors};
use crate::index::IndexStatistics;
use crate::index::{vector::suggested_num_sub_vectors, Index, IndexBuilder};
use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{IntoQueryVector, Query, QueryExecutionOptions, TakeQuery, VectorQuery};
use crate::table::datafusion::insert::InsertExec;
use crate::utils::{
PatchReadParam, PatchWriteParam, supported_bitmap_data_type, supported_btree_data_type,
supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
supported_bitmap_data_type, supported_btree_data_type, supported_fts_data_type,
supported_label_list_data_type, supported_vector_data_type, PatchReadParam, PatchWriteParam,
};
use self::dataset::DatasetConsistencyWrapper;
@@ -2555,21 +2555,22 @@ pub struct FragmentSummaryStats {
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use arrow_array::{
builder::{ListBuilder, StringBuilder},
Array, BooleanArray, FixedSizeListArray, Int32Array, LargeStringArray, RecordBatch,
RecordBatchIterator, RecordBatchReader, StringArray,
builder::{ListBuilder, StringBuilder},
};
use arrow_array::{BinaryArray, LargeBinaryArray};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance::Dataset;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance::Dataset;
use rand::Rng;
use tempfile::tempdir;
use super::*;
@@ -2776,8 +2777,9 @@ mod tests {
false,
)]));
let mut rng = rand::thread_rng();
let float_arr = Float32Array::from(
repeat_with(rand::random::<f32>)
repeat_with(|| rng.gen::<f32>())
.take(512 * dimension as usize)
.collect::<Vec<f32>>(),
);
@@ -2882,8 +2884,8 @@ mod tests {
.await
.unwrap();
use lance::index::DatasetIndexInternalExt;
use lance::index::vector::ivf::v2::IvfPq as LanceIvfPq;
use lance::index::DatasetIndexInternalExt;
use lance_index::metrics::NoOpMetricsCollector;
use lance_index::vector::VectorIndex as LanceVectorIndex;
@@ -2931,8 +2933,9 @@ mod tests {
false,
)]));
let mut rng = rand::thread_rng();
let float_arr = Float32Array::from(
repeat_with(rand::random::<f32>)
repeat_with(|| rng.gen::<f32>())
.take(512 * dimension as usize)
.collect::<Vec<f32>>(),
);
@@ -2990,8 +2993,9 @@ mod tests {
false,
)]));
let mut rng = rand::thread_rng();
let float_arr = Float32Array::from(
repeat_with(rand::random::<f32>)
repeat_with(|| rng.gen::<f32>())
.take(512 * dimension as usize)
.collect::<Vec<f32>>(),
);
@@ -3252,20 +3256,16 @@ mod tests {
.unwrap();
// Can not create btree or bitmap index on list column
assert!(
table
.create_index(&["tags"], Index::BTree(Default::default()))
.execute()
.await
.is_err()
);
assert!(
table
.create_index(&["tags"], Index::Bitmap(Default::default()))
.execute()
.await
.is_err()
);
assert!(table
.create_index(&["tags"], Index::BTree(Default::default()))
.execute()
.await
.is_err());
assert!(table
.create_index(&["tags"], Index::Bitmap(Default::default()))
.execute()
.await
.is_err());
// Create bitmap index on the "category" column
table

View File

@@ -7,8 +7,8 @@ use arrow_schema::{DataType, Fields, Schema};
use lance::dataset::WriteMode;
use serde::{Deserialize, Serialize};
use crate::data::scannable::Scannable;
use crate::data::scannable::scannable_with_embeddings;
use crate::data::scannable::Scannable;
use crate::embeddings::EmbeddingRegistry;
use crate::table::datafusion::cast::cast_to_table_schema;
use crate::table::datafusion::reject_nan::reject_nan_vectors;
@@ -204,14 +204,13 @@ mod tests {
use arrow::datatypes::Float64Type;
use arrow_array::{
FixedSizeListArray, Float32Array, Int32Array, LargeStringArray, ListArray, RecordBatch,
RecordBatchIterator, record_batch,
record_batch, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray, ListArray,
RecordBatch, RecordBatchIterator,
};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use futures::TryStreamExt;
use lance::dataset::{WriteMode, WriteParams};
use crate::Error;
use crate::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use crate::connect;
use crate::data::scannable::Scannable;
@@ -221,8 +220,9 @@ mod tests {
use crate::query::{ExecutableQuery, QueryBase, Select};
use crate::table::add_data::NaNVectorBehavior;
use crate::table::{ColumnDefinition, ColumnKind, Table, TableDefinition, WriteOptions};
use crate::test_utils::TestCustomError;
use crate::test_utils::embeddings::MockEmbed;
use crate::test_utils::TestCustomError;
use crate::Error;
use super::AddDataMode;

View File

@@ -17,17 +17,17 @@ use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType, dml::InsertOp};
use datafusion_expr::{dml::InsertOp, Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, stream::RecordBatchStreamAdapter,
stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use futures::{TryFutureExt, TryStreamExt};
use lance::dataset::{WriteMode, WriteParams};
use super::{AnyQuery, BaseTable};
use crate::{
Result,
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
Result,
};
use arrow_schema::{DataType, Field};
use lance_index::scalar::FullTextSearchQuery;
@@ -268,7 +268,7 @@ impl TableProvider for BaseTableAdapter {
InsertOp::Replace => {
return Err(DataFusionError::NotImplemented(
"Replace mode is not supported for LanceDB tables".to_string(),
));
))
}
};
@@ -300,13 +300,13 @@ pub mod tests {
use datafusion_catalog::TableProvider;
use datafusion_common::stats::Precision;
use datafusion_execution::SendableRecordBatchStream;
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
use futures::{StreamExt, TryStreamExt};
use tempfile::tempdir;
use crate::{
connect,
index::{Index, scalar::BTreeIndexBuilder},
index::{scalar::BTreeIndexBuilder, Index},
table::datafusion::BaseTableAdapter,
};

View File

@@ -5,10 +5,10 @@ use std::sync::Arc;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
use datafusion::functions::core::{get_field, named_struct};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::{cast, Literal};
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr::expressions::{Literal, cast};
use datafusion_physical_plan::expressions::Column;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr};

View File

@@ -16,9 +16,9 @@ use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use lance::Dataset;
use lance::dataset::transaction::{Operation, Transaction};
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
use lance::Dataset;
use lance_table::format::Fragment;
use crate::table::dataset::DatasetConsistencyWrapper;
@@ -195,13 +195,13 @@ impl ExecutionPlan for InsertExec {
}
};
if let Some(transactions) = to_commit
&& let Some(merged_txn) = merge_transactions(transactions)
{
let new_dataset = CommitBuilder::new(dataset.clone())
.execute(merged_txn)
.await?;
ds_wrapper.update(new_dataset);
if let Some(transactions) = to_commit {
if let Some(merged_txn) = merge_transactions(transactions) {
let new_dataset = CommitBuilder::new(dataset.clone())
.execute(merged_txn)
.await?;
ds_wrapper.update(new_dataset);
}
}
Ok(RecordBatch::try_new(
@@ -222,7 +222,7 @@ mod tests {
use std::vec;
use super::*;
use arrow_array::{RecordBatchIterator, record_batch};
use arrow_array::{record_batch, RecordBatchIterator};
use datafusion::prelude::SessionContext;
use datafusion_catalog::MemTable;
use tempfile::tempdir;

View File

@@ -4,11 +4,11 @@
use core::fmt;
use std::sync::{Arc, Mutex};
use datafusion_common::{DataFusionError, Result as DFResult, Statistics, stats::Precision};
use datafusion_common::{stats::Precision, DataFusionError, Result as DFResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, execution_plan::EmissionType,
execution_plan::EmissionType, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use crate::{arrow::SendableRecordBatchStreamExt, data::scannable::Scannable};

View File

@@ -9,7 +9,7 @@ use std::sync::Arc;
use datafusion::catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue, plan_err};
use datafusion_common::{plan_err, DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion_expr::Expr;
use lance_index::scalar::FullTextSearchQuery;
@@ -93,9 +93,9 @@ pub fn from_json(json: &str) -> crate::Result<lance_index::scalar::inverted::que
mod tests {
use super::*;
use crate::{
Connection, Table,
index::{Index, scalar::FtsIndexBuilder},
index::{scalar::FtsIndexBuilder, Index},
table::datafusion::BaseTableAdapter,
Connection, Table,
};
use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
@@ -212,10 +212,10 @@ mod tests {
let explain_results = explain_df.collect().await.unwrap();
for batch in &explain_results {
for row_idx in 0..batch.num_rows() {
if let Some(col) = batch.column_by_name("plan")
&& let Some(plan_str) = col.as_any().downcast_ref::<StringArray>()
{
println!("{}", plan_str.value(row_idx));
if let Some(col) = batch.column_by_name("plan") {
if let Some(plan_str) = col.as_any().downcast_ref::<StringArray>() {
println!("{}", plan_str.value(row_idx));
}
}
}
}
@@ -229,10 +229,10 @@ mod tests {
let explain_analyze_results = explain_analyze_df.collect().await.unwrap();
for batch in &explain_analyze_results {
for row_idx in 0..batch.num_rows() {
if let Some(col) = batch.column_by_name("plan")
&& let Some(plan_str) = col.as_any().downcast_ref::<StringArray>()
{
println!("{}", plan_str.value(row_idx));
if let Some(col) = batch.column_by_name("plan") {
if let Some(plan_str) = col.as_any().downcast_ref::<StringArray>() {
println!("{}", plan_str.value(row_idx));
}
}
}
}

View File

@@ -6,9 +6,9 @@ use std::{
time::Duration,
};
use lance::{Dataset, dataset::refs};
use lance::{dataset::refs, Dataset};
use crate::{Error, error::Result, utils::background_cache::BackgroundCache};
use crate::{error::Result, utils::background_cache::BackgroundCache, Error};
/// A wrapper around a [Dataset] that provides consistency checks.
///

View File

@@ -36,7 +36,7 @@ pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Resu
#[cfg(test)]
mod tests {
use crate::connect;
use arrow_array::{Int32Array, RecordBatch, record_batch};
use arrow_array::{record_batch, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;

View File

@@ -55,7 +55,6 @@ pub struct MergeInsertBuilder {
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
pub(crate) timeout: Option<Duration>,
pub(crate) use_index: bool,
pub(crate) mem_wal: bool,
}
impl MergeInsertBuilder {
@@ -70,7 +69,6 @@ impl MergeInsertBuilder {
when_not_matched_by_source_delete_filt: None,
timeout: None,
use_index: true,
mem_wal: false,
}
}
@@ -150,65 +148,13 @@ impl MergeInsertBuilder {
self
}
/// Enables MemWAL (Memory Write-Ahead Log) mode for this merge insert operation.
///
/// When enabled, the merge insert will route data through a memory node service
/// that buffers writes before flushing to storage. This is only supported for
/// remote (LanceDB Cloud) tables.
///
/// If not set, defaults to `false`.
pub fn mem_wal(&mut self, enabled: bool) -> &mut Self {
self.mem_wal = enabled;
self
}
/// Executes the merge insert operation
///
/// Returns version and statistics about the merge operation including the number of rows
/// inserted, updated, and deleted.
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<MergeResult> {
// Validate MemWAL constraints before execution
if self.mem_wal {
self.validate_mem_wal_pattern()?;
}
self.table.clone().merge_insert(self, new_data).await
}
/// Validate that the merge insert pattern is supported by MemWAL.
///
/// MemWAL only supports the upsert pattern:
/// - when_matched_update_all (without filter)
/// - when_not_matched_insert_all
/// - NO when_not_matched_by_source_delete
fn validate_mem_wal_pattern(&self) -> Result<()> {
// Must have when_matched_update_all without filter
if !self.when_matched_update_all {
return Err(Error::InvalidInput {
message: "MemWAL requires when_matched_update_all() to be set".to_string(),
});
}
if self.when_matched_update_all_filt.is_some() {
return Err(Error::InvalidInput {
message: "MemWAL does not support conditional when_matched_update_all (no filter allowed)".to_string(),
});
}
// Must have when_not_matched_insert_all
if !self.when_not_matched_insert_all {
return Err(Error::InvalidInput {
message: "MemWAL requires when_not_matched_insert_all() to be set".to_string(),
});
}
// Must NOT have when_not_matched_by_source_delete
if self.when_not_matched_by_source_delete {
return Err(Error::InvalidInput {
message: "MemWAL does not support when_not_matched_by_source_delete()".to_string(),
});
}
Ok(())
}
}
/// Internal implementation of the merge insert logic
@@ -219,14 +165,6 @@ pub(crate) async fn execute_merge_insert(
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<MergeResult> {
if params.mem_wal {
return Err(Error::NotSupported {
message: "MemWAL is not supported for native (local) tables. \
MemWAL is only available for remote (LanceDB Cloud) tables."
.to_string(),
});
}
let dataset = table.dataset.get().await?;
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
match (
@@ -386,139 +324,4 @@ mod tests {
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 25);
}
#[tokio::test]
async fn test_mem_wal_validation_valid_pattern() {
let conn = connect("memory://").execute().await.unwrap();
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("mem_wal_test", batches)
.execute()
.await
.unwrap();
// Valid MemWAL pattern: when_matched_update_all + when_not_matched_insert_all
let new_batches = merge_insert_test_batches(5, 1);
let mut builder = table.merge_insert(&["i"]);
builder.when_matched_update_all(None);
builder.when_not_matched_insert_all();
builder.mem_wal(true);
// Should fail because native tables don't support MemWAL, but validation passes
let result = builder.execute(new_batches).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("MemWAL is not supported for native"),
"Expected native table error, got: {}",
err
);
}
#[tokio::test]
async fn test_mem_wal_validation_missing_when_matched() {
let conn = connect("memory://").execute().await.unwrap();
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("mem_wal_test2", batches)
.execute()
.await
.unwrap();
// Missing when_matched_update_all
let new_batches = merge_insert_test_batches(5, 1);
let mut builder = table.merge_insert(&["i"]);
builder.when_not_matched_insert_all();
builder.mem_wal(true);
let result = builder.execute(new_batches).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("requires when_matched_update_all"),
"Expected validation error, got: {}",
err
);
}
#[tokio::test]
async fn test_mem_wal_validation_missing_when_not_matched() {
let conn = connect("memory://").execute().await.unwrap();
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("mem_wal_test3", batches)
.execute()
.await
.unwrap();
// Missing when_not_matched_insert_all
let new_batches = merge_insert_test_batches(5, 1);
let mut builder = table.merge_insert(&["i"]);
builder.when_matched_update_all(None);
builder.mem_wal(true);
let result = builder.execute(new_batches).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("requires when_not_matched_insert_all"),
"Expected validation error, got: {}",
err
);
}
#[tokio::test]
async fn test_mem_wal_validation_with_filter() {
let conn = connect("memory://").execute().await.unwrap();
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("mem_wal_test4", batches)
.execute()
.await
.unwrap();
// With conditional filter - not allowed
let new_batches = merge_insert_test_batches(5, 1);
let mut builder = table.merge_insert(&["i"]);
builder.when_matched_update_all(Some("target.age > 0".to_string()));
builder.when_not_matched_insert_all();
builder.mem_wal(true);
let result = builder.execute(new_batches).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("does not support conditional"),
"Expected filter validation error, got: {}",
err
);
}
#[tokio::test]
async fn test_mem_wal_validation_with_delete() {
let conn = connect("memory://").execute().await.unwrap();
let batches = merge_insert_test_batches(0, 0);
let table = conn
.create_table("mem_wal_test5", batches)
.execute()
.await
.unwrap();
// With when_not_matched_by_source_delete - not allowed
let new_batches = merge_insert_test_batches(5, 1);
let mut builder = table.merge_insert(&["i"]);
builder.when_matched_update_all(None);
builder.when_not_matched_insert_all();
builder.when_not_matched_by_source_delete(None);
builder.mem_wal(true);
let result = builder.execute(new_batches).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("does not support when_not_matched_by_source_delete"),
"Expected delete validation error, got: {}",
err
);
}
}

View File

@@ -9,9 +9,9 @@
use std::sync::Arc;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{CompactionMetrics, IndexRemapperOptions, compact_files};
use lance_index::DatasetIndexExt;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
use lance_index::optimize::OptimizeOptions;
use lance_index::DatasetIndexExt;
use log::info;
pub use chrono::Duration;
@@ -213,7 +213,7 @@ mod tests {
use std::sync::Arc;
use crate::connect;
use crate::index::{Index, scalar::BTreeIndexBuilder};
use crate::index::{scalar::BTreeIndexBuilder, Index};
use crate::query::ExecutableQuery;
use crate::table::{CompactionOptions, OptimizeAction, OptimizeStats};
use futures::TryStreamExt;

View File

@@ -7,26 +7,26 @@ use super::NativeTable;
use crate::error::{Error, Result};
use crate::expr::expr_to_sql_string;
use crate::query::{
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest, DEFAULT_TOP_K,
};
use crate::utils::{TimeoutStream, default_vector_column};
use crate::utils::{default_vector_column, TimeoutStream};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::Array;
use arrow_schema::{DataType, Schema};
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::union::UnionExec;
use datafusion_physical_plan::ExecutionPlan;
use futures::future::try_join_all;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::scanner::Scanner;
use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan};
use lance_namespace::LanceNamespace;
use lance_namespace::models::{
QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns,
QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery,
};
use lance_namespace::LanceNamespace;
#[derive(Debug, Clone)]
pub enum AnyQuery {

View File

@@ -92,7 +92,7 @@ pub(crate) async fn execute_drop_columns(
#[cfg(test)]
mod tests {
use arrow_array::{Int32Array, StringArray, record_batch};
use arrow_array::{record_batch, Int32Array, StringArray};
use arrow_schema::DataType;
use futures::TryStreamExt;
use lance::dataset::ColumnAlteration;

View File

@@ -115,9 +115,9 @@ mod tests {
use crate::query::QueryBase;
use crate::query::{ExecutableQuery, Select};
use arrow_array::{
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
Int32Array, Int64Array, LargeStringArray, RecordBatch, StringArray,
TimestampMillisecondArray, TimestampNanosecondArray, UInt32Array, record_batch,
record_batch, Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array,
Float64Array, Int32Array, Int64Array, LargeStringArray, RecordBatch, StringArray,
TimestampMillisecondArray, TimestampNanosecondArray, UInt32Array,
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};

View File

@@ -10,9 +10,9 @@ use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, ChildStdout, Command};
use tokio::sync::mpsc;
use crate::{Connection, connect};
use anyhow::{Result, anyhow, bail};
use tempfile::{TempDir, tempdir};
use crate::{connect, Connection};
use anyhow::{anyhow, bail, Result};
use tempfile::{tempdir, TempDir};
pub struct TestConnection {
pub uri: String,
@@ -113,13 +113,8 @@ async fn new_remote_connection(script_path: &str) -> Result<TestConnection> {
let (port_sender, mut port_receiver) = mpsc::channel(5);
let _reader = spawn_stdout_reader(stdout, port_sender).await;
let port = match port_receiver.recv().await {
None => bail!(
"Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."
),
Some(Err(err)) => bail!(
"Unable to determine the port number used by the phalanx process we spawned, because of an error, {}",
err
),
None => bail!("Unable to determine the port number used by the phalanx process we spawned, because the reader thread was closed too soon."),
Some(Err(err)) => bail!("Unable to determine the port number used by the phalanx process we spawned, because of an error, {}", err),
Some(Ok(port)) => port,
};
let uri = "db://test";

View File

@@ -6,9 +6,8 @@ use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
use crate::{
Error, Table,
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
connect,
connect, Error, Table,
};
#[async_trait::async_trait]

View File

@@ -6,8 +6,8 @@ use std::{borrow::Cow, sync::Arc};
use arrow_array::{Array, FixedSizeListArray, Float32Array};
use arrow_schema::{DataType, Field};
use crate::Result;
use crate::embeddings::EmbeddingFunction;
use crate::Result;
#[derive(Debug, Clone)]
pub struct MockEmbed {

View File

@@ -9,8 +9,8 @@ use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use futures::FutureExt;
use futures::future::{BoxFuture, Shared};
use futures::FutureExt;
type SharedFut<V, E> = Shared<BoxFuture<'static, Result<V, Arc<E>>>>;

View File

@@ -53,7 +53,7 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
pub trait PatchWriteParam {
fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>)
-> Result<WriteParams>;
-> Result<WriteParams>;
}
impl PatchWriteParam for WriteParams {
@@ -340,7 +340,7 @@ mod tests {
use arrow_array::Int32Array;
use arrow_schema::Field;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{StreamExt, stream};
use futures::{stream, StreamExt};
use tokio::time::sleep;
use super::*;
@@ -351,12 +351,10 @@ mod tests {
Field::new("id", DataType::Int16, true),
Field::new("tag", DataType::Utf8, false),
]);
assert!(
default_vector_column(&schema_no_vector, None)
.unwrap_err()
.to_string()
.contains("No vector column")
);
assert!(default_vector_column(&schema_no_vector, None)
.unwrap_err()
.to_string()
.contains("No vector column"));
let schema_with_vec_col = Schema::new(vec![
Field::new("id", DataType::Int16, true),
@@ -384,12 +382,10 @@ mod tests {
false,
),
]);
assert!(
default_vector_column(&multi_vec_col, None)
.unwrap_err()
.to_string()
.contains("More than one")
);
assert!(default_vector_column(&multi_vec_col, None)
.unwrap_err()
.to_string()
.contains("More than one"));
}
#[test]
@@ -505,12 +501,10 @@ mod tests {
// Poll the stream again and ensure it returns a timeout error
let second_result = timeout_stream.next().await.unwrap();
assert!(second_result.is_err());
assert!(
second_result
.unwrap_err()
.to_string()
.contains("Query timeout")
);
assert!(second_result
.unwrap_err()
.to_string()
.contains("Query timeout"));
}
#[tokio::test]

View File

@@ -15,9 +15,10 @@ use arrow_array::{
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
Error, Result, connect,
connect,
embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
query::ExecutableQuery,
Error, Result,
};
#[tokio::test]
@@ -261,19 +262,17 @@ fn create_some_records() -> Result<Box<dyn arrow_array::RecordBatchReader + Send
// Create a RecordBatch stream.
let batches = RecordBatchIterator::new(
vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter(std::iter::repeat_n(
Some("hello world".to_string()),
TOTAL,
))),
],
)
.unwrap(),
]
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter(std::iter::repeat_n(
Some("hello world".to_string()),
TOTAL,
))),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),

View File

@@ -4,8 +4,8 @@
use std::{
borrow::Cow,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
@@ -16,8 +16,8 @@ use arrow_array::{
};
use arrow_schema::{DataType, Field, Schema};
use lancedb::{
Error, Result,
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
Error, Result,
};
#[derive(Debug)]

View File

@@ -8,7 +8,7 @@ use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig};
use aws_sdk_s3::{Client as S3Client, config::Credentials, types::ServerSideEncryption};
use aws_sdk_s3::{config::Credentials, types::ServerSideEncryption, Client as S3Client};
use lancedb::Result;
const CONFIG: &[(&str, &str)] = &[