mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-18 03:30:40 +00:00
Compare commits
1 Commits
main
...
codex/upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
999b31b4a6 |
@@ -23,8 +23,6 @@ allow_dirty = true
|
||||
commit = true
|
||||
message = "Bump version: {current_version} → {new_version}"
|
||||
commit_args = ""
|
||||
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
|
||||
allow_shell_hooks = true
|
||||
|
||||
# Java maven files
|
||||
pre_commit_hooks = [
|
||||
|
||||
80
Cargo.lock
generated
80
Cargo.lock
generated
@@ -3432,8 +3432,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.4",
|
||||
@@ -4735,8 +4735,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"arrow",
|
||||
@@ -4810,8 +4810,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4832,7 +4832,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lance-arrow-scalar"
|
||||
version = "58.0.0"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4846,7 +4846,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lance-arrow-stats"
|
||||
version = "58.0.0"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -4855,8 +4855,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"paste",
|
||||
@@ -4865,8 +4865,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4904,8 +4904,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4935,8 +4935,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4953,8 +4953,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-derive"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -4963,8 +4963,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4999,8 +4999,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -5030,8 +5030,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"arrow",
|
||||
@@ -5096,8 +5096,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -5138,8 +5138,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5154,8 +5154,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5167,8 +5167,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
@@ -5222,8 +5222,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-select"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -5238,8 +5238,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -5278,8 +5278,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -5292,8 +5292,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-tokenizer"
|
||||
version = "8.0.0-beta.17"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
|
||||
version = "8.0.0-beta.18"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.18#909dea18b1de21a84f7574fab8335bab02dc48b8"
|
||||
dependencies = [
|
||||
"icu_segmenter",
|
||||
"jieba-rs",
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -13,20 +13,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=8.0.0-beta.17", default-features = false, "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=8.0.0-beta.17", default-features = false, "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=8.0.0-beta.17", default-features = false, "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { "version" = "=8.0.0-beta.18", default-features = false, "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=8.0.0-beta.18", default-features = false, "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=8.0.0-beta.18", default-features = false, "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=8.0.0-beta.18", "tag" = "v8.0.0-beta.18", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "58.0.0", optional = false }
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>8.0.0-beta.17</lance-core.version>
|
||||
<lance-core.version>8.0.0-beta.18</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -23,8 +23,6 @@ allow_dirty = true
|
||||
commit = true
|
||||
message = "Bump version: {current_version} → {new_version}"
|
||||
commit_args = ""
|
||||
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
|
||||
allow_shell_hooks = true
|
||||
|
||||
# Update Cargo.lock after version bump
|
||||
pre_commit_hooks = [
|
||||
|
||||
@@ -13,7 +13,7 @@ use serde_json::{Value, json};
|
||||
use super::EmbeddingFunction;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use tokio::runtime::{Handle, RuntimeFlavor};
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::task::block_in_place;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -148,12 +148,6 @@ impl BedrockEmbeddingFunction {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// Bedrock's SDK is async but this trait method is synchronous, so we
|
||||
// bridge with `block_in_place` + `block_on`. That requires a
|
||||
// multi-threaded Tokio runtime; return a typed error instead of
|
||||
// panicking when no compatible runtime is available.
|
||||
let handle = current_multi_thread_handle()?;
|
||||
|
||||
for text in texts {
|
||||
let request_body = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
@@ -169,28 +163,24 @@ impl BedrockEmbeddingFunction {
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize before entering the blocking section so a serialization
|
||||
// failure surfaces as a typed error rather than an `unwrap` panic.
|
||||
let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to serialize Bedrock request: {e}"),
|
||||
})?;
|
||||
|
||||
let client = self.client.clone();
|
||||
let model_id = self.model.model_id().to_string();
|
||||
let request_body = request_body.clone();
|
||||
|
||||
let response = block_in_place(|| {
|
||||
handle.block_on(async move {
|
||||
let response = block_in_place(move || {
|
||||
Handle::current().block_on(async move {
|
||||
client
|
||||
.invoke_model()
|
||||
.model_id(model_id)
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(body))
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
|
||||
serde_json::to_vec(&request_body).unwrap(),
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Bedrock invoke_model request failed: {e}"),
|
||||
})
|
||||
.map_err(Box::new)
|
||||
})
|
||||
})?;
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let response_json: Value =
|
||||
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
|
||||
@@ -198,12 +188,22 @@ impl BedrockEmbeddingFunction {
|
||||
})?;
|
||||
|
||||
let embedding = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
json_array_to_f32(&response_json["embedding"], "embedding")?
|
||||
}
|
||||
BedrockEmbeddingModel::CohereLarge => {
|
||||
json_array_to_f32(&response_json["embeddings"][0], "embeddings")?
|
||||
}
|
||||
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embedding in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embeddings in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
};
|
||||
|
||||
builder.append_slice(&embedding);
|
||||
@@ -212,86 +212,3 @@ impl BedrockEmbeddingFunction {
|
||||
Ok(builder.finish())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a handle to the current multi-threaded Tokio runtime, or a typed
|
||||
/// [`Error::Runtime`] when called outside a runtime or on the current-thread
|
||||
/// runtime. This keeps the synchronous-over-async bridge in
|
||||
/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime
|
||||
/// configurations that cannot support `block_in_place`.
|
||||
fn current_multi_thread_handle() -> Result<Handle> {
|
||||
let handle = Handle::try_current().map_err(|e| Error::Runtime {
|
||||
message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"),
|
||||
})?;
|
||||
if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
|
||||
return Err(Error::Runtime {
|
||||
message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \
|
||||
current-thread runtime cannot use `block_in_place`"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`.
|
||||
///
|
||||
/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is
|
||||
/// not an array or contains a non-numeric element, so malformed provider
|
||||
/// responses degrade gracefully.
|
||||
fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> {
|
||||
let arr = value.as_array().ok_or_else(|| Error::Runtime {
|
||||
message: format!("Missing or non-array '{field}' field in Bedrock response"),
|
||||
})?;
|
||||
arr.iter()
|
||||
.map(|v| {
|
||||
v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime {
|
||||
message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn json_array_to_f32_parses_numbers() {
|
||||
let v = json!([1.0, 2, -3.5]);
|
||||
let out = json_array_to_f32(&v, "embedding").unwrap();
|
||||
assert_eq!(out, vec![1.0_f32, 2.0, -3.5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_array_to_f32_rejects_non_array() {
|
||||
// Missing field indexes to `Value::Null`; a malformed payload should be
|
||||
// a typed error, not a panic.
|
||||
let v = json!({"unexpected": "shape"});
|
||||
let err = json_array_to_f32(&v["embedding"], "embedding").unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_array_to_f32_rejects_non_numeric_element() {
|
||||
let v = json!([1.0, "not-a-number", 3.0]);
|
||||
let err = json_array_to_f32(&v, "embedding").unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_errors_without_runtime() {
|
||||
// No Tokio runtime in scope -> typed error instead of a panic.
|
||||
let err = current_multi_thread_handle().unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn handle_errors_on_current_thread_runtime() {
|
||||
let err = current_multi_thread_handle().unwrap_err();
|
||||
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn handle_ok_on_multi_thread_runtime() {
|
||||
current_multi_thread_handle().expect("multi-threaded runtime should be accepted");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user