Compare commits

...

10 Commits

Author SHA1 Message Date
Lance Release
d59f64b5a3 Bump version: 0.22.0-beta.4 → 0.22.0-beta.5 2025-04-04 21:49:34 +00:00
fzowl
30ed8c4c43 fix: voyageai regression multimodal supercedes text models (#2268)
fix #2160
2025-04-04 14:45:56 -07:00
Will Jones
4a2cdbf299 ci: provide token for deprecate call (#2309)
This should prevent the failures we are seeing in Node release.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Chore**
- Enhanced the package deprecation process with improved security
measures, ensuring smoother and more reliable updates during package
deprecation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-04 14:44:58 -07:00
Will Jones
657843d9e9 perf: remove redundant checkout latest (#2310)
This bug was introduced in https://github.com/lancedb/lancedb/pull/2281

Likely introduced during a rebase when fixing merge conflicts.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Refactor**
- Updated the refresh process so that reloading now uses the existing
dataset version instead of automatically updating to the latest version.
This change may affect workflows that rely on immediate data updates
during refresh.
  
- **New Features**
- Introduced a new module for tracking I/O statistics in object store
operations, enhancing monitoring capabilities.
- Added a new test module to validate the functionality of the dataset
operations.

- **Bug Fixes**
- Reintroduced the `write_options` method in the `CreateTableBuilder`,
ensuring consistent functionality across different builder variants.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-04 12:56:02 -07:00
Will Jones
1cd76b8498 feat: add timeout to query execution options (#2288)
Closes #2287


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added configurable timeout support for query executions. Users can now
specify maximum wait times for queries, enhancing control over
long-running operations across various integrations.
- **Tests**
- Expanded test coverage to validate timeout behavior in both
synchronous and asynchronous query flows, ensuring timely error
responses when query execution exceeds the specified limit.
- Introduced a new test suite to verify query operations when a timeout
is reached, checking for appropriate error handling.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-04 12:34:41 -07:00
Lei Xu
a38f784081 chore: add numpy as dependency (#2308) 2025-04-04 10:33:39 -07:00
Will Jones
647dee4e94 ci: check release builds when we change dependencies (#2299)
The issue we fixed in https://github.com/lancedb/lancedb/pull/2296 was
caused by an upgrade in dependencies. This could have been caught if we
had run these CI jobs when we did the dependency change.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Chores**
- Updated our automated pipeline to trigger additional stability checks
when dependency configurations change, ensuring smoother build and
release processes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-03 16:19:00 -07:00
Lance Release
0844c2dd64 Updating package-lock.json 2025-04-02 21:23:50 +00:00
Lance Release
fd2692295c Updating package-lock.json 2025-04-02 21:23:34 +00:00
Lance Release
d4ea50fba1 Bump version: 0.19.0-beta.3 → 0.19.0-beta.4 2025-04-02 21:23:19 +00:00
45 changed files with 1186 additions and 197 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.19.0-beta.3"
current_version = "0.19.0-beta.4"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -18,6 +18,7 @@ on:
# This should trigger a dry run (we skip the final publish step)
paths:
- .github/workflows/npm-publish.yml
- Cargo.toml # Change in dependency frequently breaks builds
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -531,6 +532,8 @@ jobs:
npm publish $PUBLISH_ARGS $filename
done
- name: Deprecate
env:
NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }}
# We need to deprecate the old package to avoid confusion.
# Each time we publish a new version, it gets undeprecated.
run: npm deprecate vectordb "Use @lancedb/lancedb instead."

View File

@@ -8,6 +8,7 @@ on:
# This should trigger a dry run (we skip the final publish step)
paths:
- .github/workflows/pypi-publish.yml
- Cargo.toml # Change in dependency frequently breaks builds
jobs:
linux:

8
Cargo.lock generated
View File

@@ -4110,7 +4110,7 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
dependencies = [
"arrow",
"arrow-array",
@@ -4197,7 +4197,7 @@ dependencies = [
[[package]]
name = "lancedb-node"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4222,7 +4222,7 @@ dependencies = [
[[package]]
name = "lancedb-nodejs"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4240,7 +4240,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.22.0-beta.3"
version = "0.22.0-beta.4"
dependencies = [
"arrow",
"env_logger",

View File

@@ -20,3 +20,13 @@ The maximum number of rows to return in a single batch
Batches may have fewer rows if the underlying data is stored
in smaller chunks.
***
### timeoutMs?
```ts
optional timeoutMs: number;
```
Timeout for query execution in milliseconds

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.3</version>
<version>0.19.0-beta.4</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.3</version>
<version>0.19.0-beta.4</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>

74
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"cpu": [
"x64",
"arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.3",
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.3",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.3",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.3",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.3"
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.4",
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.4",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.4",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.4",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.4"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -326,6 +326,66 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.4.tgz",
"integrity": "sha512-uS5AuT3Q4swrtM9JAhF8mM8Nt+kvewmB3DQWGiuYbhmMismSu8WlOHQAs9Yyh8N7NBdWENSTjroSExqjHPdFhQ==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.4.tgz",
"integrity": "sha512-kjn3iTqZSx57ek9PN2AdPvJMx14tFkXc8sUFd3MLhY7FdWafx7Wvl0SLz2LubotJVFd6LMxvsPPNJEM5bEgMOw==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.4.tgz",
"integrity": "sha512-iZlR7ffKC+XA1mGuuwXJojgFcUvXkgMt6pKR6lP3hsxXh8UOTWDljN7jkI8jKHcJez3rrqoqt1VjH3xD69fwtA==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.4.tgz",
"integrity": "sha512-uxLeerlT5FuWzuvHlTDLdLCakyUJ+qJitReoCKT6tKhfcjIkbr+NEoLZEHifJC4dRFPtbddVgiYN6VHlnPPD/w==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.19.0-beta.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.4.tgz",
"integrity": "sha512-QSugxudXooLCF7trudaAo9PfOzX7SFBIiHOoL4N6nwjC61u/JAsoiytw1Xjs/+0pOG5cT2WUMufBzBPgJyOxbw==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
]
},
"node_modules/@neon-rs/cli": {
"version": "0.0.160",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"description": " Serverless, low-latency vector database for AI applications",
"private": false,
"main": "dist/index.js",
@@ -89,10 +89,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.3",
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.3",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.3",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.3",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.3"
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.4",
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.4",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.4",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.4",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.4"
}
}

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -867,6 +867,44 @@ describe("When creating an index", () => {
});
});
describe("When querying a table", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should throw an error when timeout is reached", async () => {
const db = await connect(tmpDir.name);
const data = makeArrowTable([
{ text: "a", vector: [0.1, 0.2] },
{ text: "b", vector: [0.3, 0.4] },
]);
const table = await db.createTable("test", data);
await table.createIndex("text", { config: Index.fts() });
await expect(
table.query().where("text != 'a'").toArray({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table.query().nearestTo([0.0, 0.0]).toArrow({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table.search("a", "fts").toArray({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
await expect(
table
.query()
.nearestToText("a")
.nearestTo([0.0, 0.0])
.toArrow({ timeoutMs: 0 }),
).rejects.toThrow("Query timeout");
});
});
describe("Read consistency interval", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {

View File

@@ -63,7 +63,7 @@ class RecordBatchIterable<
// biome-ignore lint/suspicious/noExplicitAny: skip
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
return new RecordBatchIterator(
this.inner.execute(this.options?.maxBatchLength),
this.inner.execute(this.options?.maxBatchLength, this.options?.timeoutMs),
);
}
}
@@ -79,6 +79,11 @@ export interface QueryExecutionOptions {
* in smaller chunks.
*/
maxBatchLength?: number;
/**
* Timeout for query execution in milliseconds
*/
timeoutMs?: number;
}
/**
@@ -283,9 +288,11 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
if (this.inner instanceof Promise) {
return this.inner.then((inner) => inner.execute(options?.maxBatchLength));
return this.inner.then((inner) =>
inner.execute(options?.maxBatchLength, options?.timeoutMs),
);
} else {
return this.inner.execute(options?.maxBatchLength);
return this.inner.execute(options?.maxBatchLength, options?.timeoutMs);
}
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.19.0-beta.3",
"version": "0.19.0-beta.4",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -131,11 +131,15 @@ impl Query {
pub async fn execute(
&self,
max_batch_length: Option<u32>,
timeout_ms: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
if let Some(timeout_ms) = timeout_ms {
execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64))
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
@@ -330,11 +334,15 @@ impl VectorQuery {
pub async fn execute(
&self,
max_batch_length: Option<u32>,
timeout_ms: Option<u32>,
) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
if let Some(timeout_ms) = timeout_ms {
execution_opts.timeout = Some(std::time::Duration::from_millis(timeout_ms as u64))
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.0-beta.4"
current_version = "0.22.0-beta.5"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.22.0-beta.4"
version = "0.22.0-beta.5"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -4,11 +4,12 @@ name = "lancedb"
dynamic = ["version"]
dependencies = [
"deprecation",
"tqdm>=4.27.0",
"numpy",
"overrides>=0.7",
"packaging",
"pyarrow>=14",
"pydantic>=1.10",
"packaging",
"overrides>=0.7",
"tqdm>=4.27.0",
]
description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -55,6 +56,7 @@ tests = [
"tantivy",
"pyarrow-stubs",
"pylance>=0.23.2",
"requests",
]
dev = [
"ruff",

View File

@@ -1,3 +1,4 @@
from datetime import timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa
@@ -94,7 +95,9 @@ class Query:
def postfilter(self): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> FTSQuery: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
async def explain_plan(self, verbose: Optional[bool]) -> str: ...
async def analyze_plan(self) -> str: ...
def to_query_request(self) -> PyQueryRequest: ...
@@ -110,7 +113,9 @@ class FTSQuery:
def get_query(self) -> str: ...
def add_query_vector(self, query_vec: pa.Array) -> None: ...
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class VectorQuery:

View File

@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import base64
import os
from typing import ClassVar, TYPE_CHECKING, List, Union
from typing import ClassVar, TYPE_CHECKING, List, Union, Any
from pathlib import Path
from urllib.parse import urlparse
from io import BytesIO
import numpy as np
import pyarrow as pa
@@ -11,12 +14,100 @@ import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, IMAGES
from .utils import api_key_not_found_help, IMAGES, TEXT
if TYPE_CHECKING:
import PIL
def is_valid_url(text):
try:
parsed = urlparse(text)
return bool(parsed.scheme) and bool(parsed.netloc)
except Exception:
return False
def transform_input(input_data: Union[str, bytes, Path]):
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(input_data, str):
if is_valid_url(input_data):
content = {"type": "image_url", "image_url": input_data}
else:
content = {"type": "text", "text": input_data}
elif isinstance(input_data, PIL.Image.Image):
buffered = BytesIO()
input_data.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, bytes):
img = PIL.Image.open(BytesIO(input_data))
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
elif isinstance(input_data, Path):
img = PIL.Image.open(input_data)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
content = {
"type": "image_base64",
"image_base64": "data:image/jpeg;base64," + img_str,
}
else:
raise ValueError("Each input should be either str, bytes, Path or Image.")
return {"content": [content]}
def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
"""
Sanitize the input to the embedding function.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
else:
raise ValueError(
f"Input type {type(inputs)} not allowed with multimodal model."
)
if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
raise ValueError("Each input should be either str, bytes, Path or Image.")
return [transform_input(i) for i in inputs]
def sanitize_text_input(inputs: TEXT) -> List[str]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(inputs, str):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
else:
raise ValueError(f"Input type {type(inputs)} not allowed with text model.")
if not all(isinstance(x, str) for x in inputs):
raise ValueError("Each input should be str.")
return inputs
@register("voyageai")
class VoyageAIEmbeddingFunction(EmbeddingFunction):
"""
@@ -74,6 +165,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
]
multimodal_embedding_models: list = ["voyage-multimodal-3"]
def _is_multimodal_model(self, model_name: str):
return (
model_name in self.multimodal_embedding_models or "multimodal" in model_name
)
def ndims(self):
if self.name == "voyage-3-lite":
return 512
@@ -85,55 +181,12 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
"voyage-finance-2",
"voyage-multilingual-2",
"voyage-law-2",
"voyage-multimodal-3",
]:
return 1024
else:
raise ValueError(f"Model {self.name} not supported")
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def generate_text_embeddings(self, text: str, **kwargs) -> np.ndarray:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
input_type: Optional[str]
truncation: Optional[bool]
"""
client = VoyageAIEmbeddingFunction._get_client()
if self.name in self.text_embedding_models:
rs = client.embed(texts=[text], model=self.name, **kwargs)
elif self.name in self.multimodal_embedding_models:
rs = client.multimodal_embed(inputs=[[text]], model=self.name, **kwargs)
else:
raise ValueError(
f"Model {self.name} not supported to generate text embeddings"
)
return rs.embeddings[0]
def generate_image_embedding(
self, image: "PIL.Image.Image", **kwargs
) -> np.ndarray:
rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed(
inputs=[[image]], model=self.name, **kwargs
)
return rs.embeddings[0]
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
@@ -144,23 +197,52 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
Returns
-------
List[np.array]: the list of embeddings
"""
if isinstance(query, str):
return [self.generate_text_embeddings(query, input_type="query")]
client = VoyageAIEmbeddingFunction._get_client()
if self._is_multimodal_model(self.name):
result = client.multimodal_embed(
inputs=[[query]], model=self.name, input_type="query", **kwargs
)
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query, input_type="query")]
else:
raise TypeError("Only text PIL images supported as query")
result = client.embed(
texts=[query], model=self.name, input_type="query", **kwargs
)
return [result.embeddings[0]]
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
self, inputs: Union[TEXT, IMAGES], *args, **kwargs
) -> List[np.array]:
images = self.sanitize_input(images)
return [
self.generate_image_embedding(img, input_type="document") for img in images
]
"""
Compute the embeddings for the inputs
Parameters
----------
inputs : Union[TEXT, IMAGES]
The inputs to embed. The input can be either str, bytes, Path (to an image),
PIL.Image or list of these.
Returns
-------
List[np.array]: the list of embeddings
"""
client = VoyageAIEmbeddingFunction._get_client()
if self._is_multimodal_model(self.name):
inputs = sanitize_multimodal_input(inputs)
result = client.multimodal_embed(
inputs=inputs, model=self.name, input_type="document", **kwargs
)
else:
inputs = sanitize_text_input(inputs)
result = client.embed(
texts=inputs, model=self.name, input_type="document", **kwargs
)
return result.embeddings
@staticmethod
def _get_client():

View File

@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
import abc
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Dict,
@@ -650,7 +651,12 @@ class LanceQueryBuilder(ABC):
"""
return self.to_pandas()
def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
def to_pandas(
self,
flatten: Optional[Union[int, bool]] = None,
*,
timeout: Optional[timedelta] = None,
) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
@@ -664,12 +670,15 @@ class LanceQueryBuilder(ABC):
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
tbl = flatten_columns(self.to_arrow(), flatten)
tbl = flatten_columns(self.to_arrow(timeout=timeout), flatten)
return tbl.to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
@@ -677,34 +686,65 @@ class LanceQueryBuilder(ABC):
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
raise NotImplementedError
@abstractmethod
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
def to_batches(
self,
/,
batch_size: Optional[int] = None,
*,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
"""
Execute the query and return the results as a pyarrow
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
Parameters
----------
batch_size: int
The maximum number of selected records in a RecordBatch object.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
raise NotImplementedError
def to_list(self) -> List[dict]:
def to_list(self, *, timeout: Optional[timedelta] = None) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected.
"""
return self.to_arrow().to_pylist()
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
return self.to_arrow(timeout=timeout).to_pylist()
def to_pydantic(
self, model: Type[LanceModel], *, timeout: Optional[timedelta] = None
) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
Returns
-------
@@ -712,19 +752,25 @@ class LanceQueryBuilder(ABC):
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
for row in self.to_arrow(timeout=timeout).to_pylist()
]
def to_polars(self) -> "pl.DataFrame":
def to_polars(self, *, timeout: Optional[timedelta] = None) -> "pl.DataFrame":
"""
Execute the query and return the results as a Polars DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
import polars as pl
return pl.from_arrow(self.to_arrow())
return pl.from_arrow(self.to_arrow(timeout=timeout))
def limit(self, limit: Union[int, None]) -> Self:
"""Set the maximum number of results to return.
@@ -1139,7 +1185,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor
return self
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
@@ -1147,8 +1193,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If None, wait indefinitely.
"""
return self.to_batches().read_all()
return self.to_batches(timeout=timeout).read_all()
def to_query_object(self) -> Query:
"""
@@ -1178,7 +1230,13 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
bypass_vector_index=self._bypass_vector_index,
)
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
def to_batches(
self,
/,
batch_size: Optional[int] = None,
*,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
"""
Execute the query and return the result as a RecordBatchReader object.
@@ -1186,6 +1244,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
----------
batch_size: int
The maximum number of selected records in a RecordBatch object.
timeout: timedelta, default None
The maximum time to wait for the query to complete.
If None, wait indefinitely.
Returns
-------
@@ -1195,7 +1256,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector]
query = self.to_query_object()
result_set = self._table._execute_query(query, batch_size)
result_set = self._table._execute_query(
query, batch_size=batch_size, timeout=timeout
)
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
@@ -1334,7 +1397,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
path, fs, exist = self._table._get_fts_index_path()
if exist:
return self.tantivy_to_arrow()
@@ -1346,14 +1409,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
"Use tantivy-based index instead for now."
)
query = self.to_query_object()
results = self._table._execute_query(query)
results = self._table._execute_query(query, timeout=timeout)
results = results.read_all()
if self._reranker is not None:
results = self._reranker.rerank_fts(self._query, results)
check_reranker_result(results)
return results
def to_batches(self, /, batch_size: Optional[int] = None):
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
):
raise NotImplementedError("to_batches on an FTS query")
def tantivy_to_arrow(self) -> pa.Table:
@@ -1458,8 +1523,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
return self.to_batches().read_all()
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
return self.to_batches(timeout=timeout).read_all()
def to_query_object(self) -> Query:
return Query(
@@ -1470,9 +1535,11 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
) -> pa.RecordBatchReader:
query = self.to_query_object()
return self._table._execute_query(query, batch_size)
return self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
@@ -1560,7 +1627,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def to_query_object(self) -> Query:
raise NotImplementedError("to_query_object not yet supported on a hybrid query")
def to_arrow(self) -> pa.Table:
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
)
@@ -1603,9 +1670,11 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._reranker = RRFReranker()
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
fts_future = executor.submit(
self._fts_query.with_row_id(True).to_arrow, timeout=timeout
)
vector_future = executor.submit(
self._vector_query.with_row_id(True).to_arrow
self._vector_query.with_row_id(True).to_arrow, timeout=timeout
)
fts_results = fts_future.result()
vector_results = vector_future.result()
@@ -1692,7 +1761,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return results
def to_batches(self):
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
):
raise NotImplementedError("to_batches not yet supported on a hybrid query")
@staticmethod
@@ -2056,7 +2127,10 @@ class AsyncQueryBase(object):
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
@@ -2069,34 +2143,56 @@ class AsyncQueryBase(object):
If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
return AsyncRecordBatchReader(
await self._inner.execute(max_batch_length, timeout)
)
async def to_arrow(self) -> pa.Table:
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use
[to_batches][lancedb.query.AsyncQueryBase.to_batches]
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
batch_iter = await self.to_batches()
batch_iter = await self.to_batches(timeout=timeout)
return pa.Table.from_batches(
await batch_iter.read_all(), schema=batch_iter.schema
)
async def to_list(self) -> List[dict]:
async def to_list(self, timeout: Optional[timedelta] = None) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
return (await self.to_arrow()).to_pylist()
return (await self.to_arrow(timeout=timeout)).to_pylist()
async def to_pandas(
self, flatten: Optional[Union[int, bool]] = None
self,
flatten: Optional[Union[int, bool]] = None,
timeout: Optional[timedelta] = None,
) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
@@ -2125,10 +2221,19 @@ class AsyncQueryBase(object):
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
"""
return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
return (
flatten_columns(await self.to_arrow(timeout=timeout), flatten)
).to_pandas()
async def to_polars(self) -> "pl.DataFrame":
async def to_polars(
self,
timeout: Optional[timedelta] = None,
) -> "pl.DataFrame":
"""
Execute the query and collect the results into a Polars DataFrame.
@@ -2137,6 +2242,13 @@ class AsyncQueryBase(object):
[to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
polars separately.
Parameters
----------
timeout: Optional[timedelta]
The maximum time to wait for the query to complete.
If not specified, no timeout is applied. If the query does not
complete within the specified time, an error will be raised.
Examples
--------
@@ -2152,7 +2264,7 @@ class AsyncQueryBase(object):
"""
import polars as pl
return pl.from_arrow(await self.to_arrow())
return pl.from_arrow(await self.to_arrow(timeout=timeout))
async def explain_plan(self, verbose: Optional[bool] = False):
"""Return the execution plan for this query.
@@ -2423,9 +2535,12 @@ class AsyncFTSQuery(AsyncQueryBase):
)
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
reader = await super().to_batches()
reader = await super().to_batches(timeout=timeout)
results = pa.Table.from_batches(await reader.read_all(), reader.schema)
if self._reranker:
results = self._reranker.rerank_fts(self.get_query(), results)
@@ -2649,9 +2764,12 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
reader = await super().to_batches()
reader = await super().to_batches(timeout=timeout)
results = pa.Table.from_batches(await reader.read_all(), reader.schema)
if self._reranker:
results = self._reranker.rerank_vector(self._query_string, results)
@@ -2707,7 +2825,10 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
self,
*,
max_batch_length: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> AsyncRecordBatchReader:
fts_query = AsyncFTSQuery(self._inner.to_fts_query())
vec_query = AsyncVectorQuery(self._inner.to_vector_query())
@@ -2719,8 +2840,8 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
vec_query.with_row_id()
fts_results, vector_results = await asyncio.gather(
fts_query.to_arrow(),
vec_query.to_arrow(),
fts_query.to_arrow(timeout=timeout),
vec_query.to_arrow(timeout=timeout),
)
result = LanceHybridQueryBuilder._combine_hybrid_results(

View File

@@ -355,9 +355,15 @@ class RemoteTable(Table):
)
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
async_iter = LOOP.run(self._table._execute_query(query, batch_size=batch_size))
async_iter = LOOP.run(
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
)
def iter_sync():
try:

View File

@@ -1007,7 +1007,11 @@ class Table(ABC):
@abstractmethod
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader: ...
@abstractmethod
@@ -2312,9 +2316,15 @@ class LanceTable(Table):
LOOP.run(self._table.update(values, where=where, updates_sql=values_sql))
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
async_iter = LOOP.run(self._table._execute_query(query, batch_size))
async_iter = LOOP.run(
self._table._execute_query(query, batch_size=batch_size, timeout=timeout)
)
def iter_sync():
try:
@@ -3390,7 +3400,11 @@ class AsyncTable:
return async_query
async def _execute_query(
self, query: Query, batch_size: Optional[int] = None
self,
query: Query,
*,
batch_size: Optional[int] = None,
timeout: Optional[timedelta] = None,
) -> pa.RecordBatchReader:
# The sync table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
@@ -3398,7 +3412,9 @@ class AsyncTable:
async_query = self._sync_query_to_async(query)
return await async_query.to_batches(max_batch_length=batch_size)
return await async_query.to_batches(
max_batch_length=batch_size, timeout=timeout
)
async def _explain_plan(self, query: Query, verbose: Optional[bool]) -> str:
# This method is used by the sync table

View File

@@ -12,6 +12,7 @@ import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
import requests
# These are integration tests for embedding functions.
# They are slow because they require downloading models
@@ -516,3 +517,61 @@ def test_voyageai_embedding_function():
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_function():
voyageai = (
get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
)
class Images(LanceModel):
label: str
image_uri: str = voyageai.SourceField() # image uri as the source
image_bytes: bytes = voyageai.SourceField() # image bytes as the source
vector: Vector(voyageai.ndims()) = voyageai.VectorField() # vector column
vec_from_bytes: Vector(voyageai.ndims()) = (
voyageai.VectorField()
) # Another vector column
db = lancedb.connect("~/lancedb")
table = db.create_table("test", schema=Images, mode="overwrite")
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
)
assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_text_function():
voyageai = (
get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()

View File

@@ -511,7 +511,8 @@ def test_query_builder_with_different_vector_column():
columns=["b"],
vector_column="foo_vector",
),
None,
batch_size=None,
timeout=None,
)
@@ -1076,3 +1077,67 @@ async def test_query_serialization_async(table_async: AsyncTable):
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
with_row_id=False,
)
def test_query_timeout(tmp_path):
# Use local directory instead of memory:// to add a bit of latency to
# operations so a timeout of zero will trigger exceptions.
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": ["a", "b"],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(4).cast(pa.float32()), 2
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", use_tantivy=False)
with pytest.raises(Exception, match="Query timeout"):
table.search().where("text = 'a'").to_list(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search("a", query_type="fts").to_pandas(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
table.search(query_type="hybrid").vector([0.0, 0.0]).text("a").to_arrow(
timeout=timedelta(0)
)
@pytest.mark.asyncio
async def test_query_timeout_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
data = pa.table(
{
"text": ["a", "b"],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(4).cast(pa.float32()), 2
),
}
)
table = await db.create_table("test", data)
await table.create_index("text", config=FTS())
with pytest.raises(Exception, match="Query timeout"):
await table.query().where("text != 'a'").to_list(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
await table.vector_search([0.0, 0.0]).to_arrow(timeout=timedelta(0))
with pytest.raises(Exception, match="Query timeout"):
await (await table.search("a", query_type="fts")).to_pandas(
timeout=timedelta(0)
)
with pytest.raises(Exception, match="Query timeout"):
await (
table.query()
.nearest_to_text("a")
.nearest_to([0.0, 0.0])
.to_list(timeout=timedelta(0))
)

View File

@@ -2,6 +2,7 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use std::time::Duration;
use arrow::array::make_array;
use arrow::array::Array;
@@ -294,10 +295,11 @@ impl Query {
})
}
#[pyo3(signature = (max_batch_length=None))]
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
@@ -305,6 +307,9 @@ impl Query {
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
@@ -376,10 +381,11 @@ impl FTSQuery {
self.inner = self.inner.clone().postfilter();
}
#[pyo3(signature = (max_batch_length=None))]
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_
.inner
@@ -391,6 +397,9 @@ impl FTSQuery {
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
@@ -513,10 +522,11 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index()
}
#[pyo3(signature = (max_batch_length=None))]
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
max_batch_length: Option<u32>,
timeout: Option<Duration>,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
@@ -524,6 +534,9 @@ impl VectorQuery {
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
if let Some(timeout) = timeout {
opts.timeout = Some(timeout);
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.19.0-beta.3"
version = "0.19.0-beta.4"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -142,12 +142,6 @@ impl CreateTableBuilder<true> {
}
}
/// Apply the given write options when writing the initial data
pub fn write_options(mut self, write_options: WriteOptions) -> Self {
self.request.write_options = write_options;
self
}
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
let embedding_registry = self.embedding_registry.clone();
@@ -229,6 +223,12 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
self
}
/// Apply the given write options when writing the initial data
pub fn write_options(mut self, write_options: WriteOptions) -> Self {
self.request.write_options = write_options;
self
}
/// Set an option for the storage layer.
///
/// Options already set on the connection will be inherited by the table,

View File

@@ -14,6 +14,9 @@ use object_store::{
use async_trait::async_trait;
#[cfg(test)]
pub mod io_tracking;
#[derive(Debug)]
struct MirroringObjectStore {
primary: Arc<dyn ObjectStore>,

View File

@@ -0,0 +1,237 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
fmt::{Display, Formatter},
sync::{Arc, Mutex},
};
use bytes::Bytes;
use futures::stream::BoxStream;
use lance::io::WrappingObjectStore;
use object_store::{
path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOpts, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
};
#[derive(Debug, Default)]
pub struct IoStats {
pub read_iops: u64,
pub read_bytes: u64,
pub write_iops: u64,
pub write_bytes: u64,
}
impl Display for IoStats {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
#[derive(Debug, Clone)]
pub struct IoTrackingStore {
target: Arc<dyn ObjectStore>,
stats: Arc<Mutex<IoStats>>,
}
impl Display for IoTrackingStore {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:#?}", self)
}
}
#[derive(Debug, Default, Clone)]
pub struct IoStatsHolder(Arc<Mutex<IoStats>>);
impl IoStatsHolder {
pub fn incremental_stats(&self) -> IoStats {
std::mem::take(&mut self.0.lock().expect("failed to lock IoStats"))
}
}
impl WrappingObjectStore for IoStatsHolder {
fn wrap(&self, target: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
Arc::new(IoTrackingStore {
target,
stats: self.0.clone(),
})
}
}
impl IoTrackingStore {
pub fn new_wrapper() -> (Arc<dyn WrappingObjectStore>, Arc<Mutex<IoStats>>) {
let stats = Arc::new(Mutex::new(IoStats::default()));
(Arc::new(IoStatsHolder(stats.clone())), stats)
}
fn record_read(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap();
stats.read_iops += 1;
stats.read_bytes += num_bytes;
}
fn record_write(&self, num_bytes: u64) {
let mut stats = self.stats.lock().unwrap();
stats.write_iops += 1;
stats.write_bytes += num_bytes;
}
}
#[async_trait::async_trait]
#[deny(clippy::missing_trait_methods)]
impl ObjectStore for IoTrackingStore {
async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
self.record_write(bytes.content_length() as u64);
self.target.put(location, bytes).await
}
async fn put_opts(
&self,
location: &Path,
bytes: PutPayload,
opts: PutOptions,
) -> OSResult<PutResult> {
self.record_write(bytes.content_length() as u64);
self.target.put_opts(location, bytes, opts).await
}
async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
let target = self.target.put_multipart(location).await?;
Ok(Box::new(IoTrackingMultipartUpload {
target,
stats: self.stats.clone(),
}))
}
async fn put_multipart_opts(
&self,
location: &Path,
opts: PutMultipartOpts,
) -> OSResult<Box<dyn MultipartUpload>> {
let target = self.target.put_multipart_opts(location, opts).await?;
Ok(Box::new(IoTrackingMultipartUpload {
target,
stats: self.stats.clone(),
}))
}
async fn get(&self, location: &Path) -> OSResult<GetResult> {
let result = self.target.get(location).await;
if let Ok(result) = &result {
let num_bytes = result.range.end - result.range.start;
self.record_read(num_bytes as u64);
}
result
}
async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
let result = self.target.get_opts(location, options).await;
if let Ok(result) = &result {
let num_bytes = result.range.end - result.range.start;
self.record_read(num_bytes as u64);
}
result
}
async fn get_range(&self, location: &Path, range: std::ops::Range<usize>) -> OSResult<Bytes> {
let result = self.target.get_range(location, range).await;
if let Ok(result) = &result {
self.record_read(result.len() as u64);
}
result
}
async fn get_ranges(
&self,
location: &Path,
ranges: &[std::ops::Range<usize>],
) -> OSResult<Vec<Bytes>> {
let result = self.target.get_ranges(location, ranges).await;
if let Ok(result) = &result {
self.record_read(result.iter().map(|b| b.len() as u64).sum());
}
result
}
async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
self.record_read(0);
self.target.head(location).await
}
async fn delete(&self, location: &Path) -> OSResult<()> {
self.record_write(0);
self.target.delete(location).await
}
fn delete_stream<'a>(
&'a self,
locations: BoxStream<'a, OSResult<Path>>,
) -> BoxStream<'a, OSResult<Path>> {
self.target.delete_stream(locations)
}
fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, OSResult<ObjectMeta>> {
self.record_read(0);
self.target.list(prefix)
}
fn list_with_offset(
&self,
prefix: Option<&Path>,
offset: &Path,
) -> BoxStream<'_, OSResult<ObjectMeta>> {
self.record_read(0);
self.target.list_with_offset(prefix, offset)
}
async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
self.record_read(0);
self.target.list_with_delimiter(prefix).await
}
async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.copy(from, to).await
}
async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.rename(from, to).await
}
async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.rename_if_not_exists(from, to).await
}
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
self.record_write(0);
self.target.copy_if_not_exists(from, to).await
}
}
#[derive(Debug)]
struct IoTrackingMultipartUpload {
target: Box<dyn MultipartUpload>,
stats: Arc<Mutex<IoStats>>,
}
#[async_trait::async_trait]
impl MultipartUpload for IoTrackingMultipartUpload {
async fn abort(&mut self) -> OSResult<()> {
self.target.abort().await
}
async fn complete(&mut self) -> OSResult<PutResult> {
self.target.complete().await
}
fn put_part(&mut self, payload: PutPayload) -> UploadPart {
{
let mut stats = self.stats.lock().unwrap();
stats.write_iops += 1;
stats.write_bytes += payload.content_length() as u64;
}
self.target.put_part(payload)
}
}

View File

@@ -1,8 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::future::Future;
use std::sync::Arc;
use std::{future::Future, time::Duration};
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
@@ -25,6 +25,7 @@ use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::BaseTable;
use crate::utils::TimeoutStream;
use crate::DistanceType;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
@@ -525,12 +526,15 @@ pub struct QueryExecutionOptions {
///
/// By default, this is 1024
pub max_batch_length: u32,
/// Max duration to wait for the query to execute before timing out.
pub timeout: Option<Duration>,
}
impl Default for QueryExecutionOptions {
fn default() -> Self {
Self {
max_batch_length: 1024,
timeout: None,
}
}
}
@@ -1007,7 +1011,10 @@ impl VectorQuery {
self
}
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
pub async fn execute_hybrid(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
// clone query and specify we want to include row IDs, which can be needed for reranking
let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone();
@@ -1016,7 +1023,10 @@ impl VectorQuery {
let mut vector_query = self.clone().with_row_id();
vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
let (fts_results, vec_results) = try_join!(
fts_query.execute_with_options(options.clone()),
vector_query.inner_execute_with_options(options)
)?;
let (fts_results, vec_results) = try_join!(
fts_results.try_collect::<Vec<_>>(),
@@ -1074,6 +1084,20 @@ impl VectorQuery {
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
))
}
async fn inner_execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let plan = self.create_plan(options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
inner
};
Ok(DatasetRecordBatchStream::new(inner).into())
}
}
impl ExecutableQuery for VectorQuery {
@@ -1087,16 +1111,13 @@ impl ExecutableQuery for VectorQuery {
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
if self.request.base.full_text_search.is_some() {
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
let hybrid_result = async move { self.execute_hybrid(options).await }
.boxed()
.await?;
return Ok(hybrid_result);
}
Ok(SendableRecordBatchStream::from(
DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
Default::default(),
)?),
))
self.inner_execute_with_options(options).await
}
async fn explain_plan(&self, verbose: bool) -> Result<String> {

View File

@@ -13,7 +13,7 @@ use reqwest::{
use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions;
const REQUEST_ID_HEADER: &str = "x-request-id";
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
/// Configuration for the LanceDB Cloud HTTP client.
#[derive(Clone, Debug)]
@@ -299,7 +299,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
"x-api-key",
HeaderName::from_static("x-api-key"),
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidInput {
message: "non-ascii api key provided".to_string(),
})?,
@@ -307,7 +307,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if region == "local" {
let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert(
"Host",
http::header::HOST,
HeaderValue::from_str(&host).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
@@ -315,7 +315,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
if has_host_override {
headers.insert(
"x-lancedb-database",
HeaderName::from_static("x-lancedb-database"),
HeaderValue::from_str(db_name).map_err(|_| Error::InvalidInput {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
@@ -323,7 +323,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
if db_prefix.is_some() {
headers.insert(
"x-lancedb-database-prefix",
HeaderName::from_static("x-lancedb-database-prefix"),
HeaderValue::from_str(db_prefix.unwrap()).map_err(|_| Error::InvalidInput {
message: format!(
"non-ascii database prefix '{}' provided",
@@ -335,7 +335,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
if let Some(v) = options.0.get("account_name") {
headers.insert(
"x-azure-storage-account-name",
HeaderName::from_static("x-azure-storage-account-name"),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name),
})?,
@@ -343,7 +343,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
if let Some(v) = options.0.get("azure_storage_account_name") {
headers.insert(
"x-azure-storage-account-name",
HeaderName::from_static("x-azure-storage-account-name"),
HeaderValue::from_str(v).map_err(|_| Error::InvalidInput {
message: format!("non-ascii storage account name '{}' provided", db_name),
})?,

View File

@@ -20,7 +20,7 @@ use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use futures::TryStreamExt;
use http::header::CONTENT_TYPE;
use http::StatusCode;
use http::{HeaderName, StatusCode};
use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
@@ -44,6 +44,8 @@ use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
#[allow(dead_code)]
@@ -332,9 +334,19 @@ impl<S: HttpSend> RemoteTable<S> {
async fn execute_query(
&self,
query: &AnyQuery,
_options: QueryExecutionOptions,
options: &QueryExecutionOptions,
) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));
if let Some(timeout) = options.timeout {
// Client side timeout
request = request.timeout(timeout);
// Also send to server, so it can abort the query if it takes too long.
// (If it doesn't fit into u64, it's not worth sending anyways.)
if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) {
request = request.header(REQUEST_TIMEOUT_HEADER, timeout_ms);
}
}
let query_bodies = self.prepare_query_bodies(query).await?;
let requests: Vec<reqwest::RequestBuilder> = query_bodies
@@ -543,7 +555,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let streams = self.execute_query(query, options).await?;
let streams = self.execute_query(query, &options).await?;
if streams.len() == 1 {
let stream = streams.into_iter().next().unwrap();
Ok(Arc::new(OneShotExec::new(stream)))
@@ -559,9 +571,9 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
async fn query(
&self,
query: &AnyQuery,
_options: QueryExecutionOptions,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let streams = self.execute_query(query, _options).await?;
let streams = self.execute_query(query, &options).await?;
if streams.len() == 1 {
Ok(DatasetRecordBatchStream::new(

View File

@@ -68,7 +68,7 @@ use crate::query::{
use crate::utils::{
default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
PatchReadParam, PatchWriteParam,
PatchReadParam, PatchWriteParam, TimeoutStream,
};
use self::dataset::DatasetConsistencyWrapper;
@@ -1775,11 +1775,14 @@ impl NativeTable {
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let plan = self.create_plan(query, options).await?;
Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
Default::default(),
)?))
let plan = self.create_plan(query, options.clone()).await?;
let inner = execute_plan(plan, Default::default())?;
let inner = if let Some(timeout) = options.timeout {
TimeoutStream::new_boxed(inner, timeout)
} else {
inner
};
Ok(DatasetRecordBatchStream::new(inner))
}
/// Check whether the table uses V2 manifest paths.

View File

@@ -48,7 +48,6 @@ impl DatasetRef {
refresh_task,
..
} => {
dataset.checkout_latest().await?;
// Replace the refresh task
if let Some(refresh_task) = refresh_task {
refresh_task.abort();
@@ -372,3 +371,48 @@ impl DerefMut for DatasetWriteGuard<'_> {
}
}
}
#[cfg(test)]
mod tests {
use arrow_schema::{DataType, Field, Schema};
use lance::{dataset::WriteParams, io::ObjectStoreParams};
use super::*;
use crate::{connect, io::object_store::io_tracking::IoStatsHolder, table::WriteOptions};
#[tokio::test]
async fn test_iops_open_strong_consistency() {
let db = connect("memory://")
.read_consistency_interval(Some(Duration::ZERO))
.execute()
.await
.expect("Failed to connect to database");
let io_stats = IoStatsHolder::default();
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
let table = db
.create_empty_table("test", schema)
.write_options(WriteOptions {
lance_write_params: Some(WriteParams {
store_params: Some(ObjectStoreParams {
object_store_wrapper: Some(Arc::new(io_stats.clone())),
..Default::default()
}),
..Default::default()
}),
})
.execute()
.await
.unwrap();
io_stats.incremental_stats();
// We should only need 1 read IOP to check the schema: looking for the
// latest version.
table.schema().await.unwrap();
let stats = io_stats.incremental_stats();
assert_eq!(stats.read_iops, 1);
}
}

View File

@@ -3,14 +3,20 @@
use std::sync::Arc;
use arrow_schema::{DataType, Schema};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema, SchemaRef};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_execution::RecordBatchStream;
use futures::{FutureExt, Stream};
use lance::arrow::json::JsonDataType;
use lance::dataset::{ReadParams, WriteParams};
use lance::index::vector::utils::infer_vector_dim;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;
use std::pin::Pin;
use crate::error::{Error, Result};
use datafusion_physical_plan::SendableRecordBatchStream;
lazy_static! {
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
@@ -178,11 +184,97 @@ pub fn string_to_datatype(s: &str) -> Option<DataType> {
(&json_type).try_into().ok()
}
enum TimeoutState {
NotStarted {
timeout: std::time::Duration,
},
Started {
deadline: Pin<Box<tokio::time::Sleep>>,
timeout: std::time::Duration,
},
Completed,
}
/// A `Stream` wrapper that implements a timeout.
///
/// The timeout starts when the first `poll_next` is called. As soon as the timeout
/// duration has passed, the stream will return an `Err` indicating a timeout error
/// for the next poll.
pub struct TimeoutStream {
inner: SendableRecordBatchStream,
state: TimeoutState,
}
impl TimeoutStream {
pub fn new(inner: SendableRecordBatchStream, timeout: std::time::Duration) -> Self {
Self {
inner,
state: TimeoutState::NotStarted { timeout },
}
}
pub fn new_boxed(
inner: SendableRecordBatchStream,
timeout: std::time::Duration,
) -> SendableRecordBatchStream {
Box::pin(Self::new(inner, timeout))
}
fn timeout_error(timeout: &std::time::Duration) -> DataFusionError {
DataFusionError::Execution(format!("Query timeout after {} ms", timeout.as_millis()))
}
}
impl RecordBatchStream for TimeoutStream {
fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}
impl Stream for TimeoutStream {
type Item = DataFusionResult<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match &mut self.state {
TimeoutState::NotStarted { timeout } => {
if timeout.is_zero() {
return std::task::Poll::Ready(Some(Err(Self::timeout_error(timeout))));
}
let deadline = Box::pin(tokio::time::sleep(*timeout));
self.state = TimeoutState::Started {
deadline,
timeout: *timeout,
};
self.poll_next(cx)
}
TimeoutState::Started { deadline, timeout } => match deadline.poll_unpin(cx) {
std::task::Poll::Ready(_) => {
let err = Self::timeout_error(timeout);
self.state = TimeoutState::Completed;
std::task::Poll::Ready(Some(Err(err)))
}
std::task::Poll::Pending => {
let inner = Pin::new(&mut self.inner);
inner.poll_next(cx)
}
},
TimeoutState::Completed => std::task::Poll::Ready(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::Int32Array;
use arrow_schema::Field;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{stream, StreamExt};
use tokio::time::sleep;
use arrow_schema::{DataType, Field};
use super::*;
#[test]
fn test_guess_default_column() {
@@ -249,4 +341,85 @@ mod tests {
let expected = DataType::Int32;
assert_eq!(string_to_datatype(string), Some(expected));
}
fn sample_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new(
"col1",
DataType::Int32,
false,
)]));
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap()
}
#[tokio::test]
async fn test_timeout_stream() {
let batch = sample_batch();
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
let timeout_duration = std::time::Duration::from_millis(10);
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
// Poll the stream to get the first batch
let first_result = timeout_stream.next().await;
assert!(first_result.is_some());
assert!(first_result.unwrap().is_ok());
// Sleep for the timeout duration
sleep(timeout_duration).await;
// Poll the stream again and ensure it returns a timeout error
let second_result = timeout_stream.next().await.unwrap();
assert!(second_result.is_err());
assert!(second_result
.unwrap_err()
.to_string()
.contains("Query timeout"));
}
#[tokio::test]
async fn test_timeout_stream_zero_duration() {
let batch = sample_batch();
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
// Setup similar to test_timeout_stream
let timeout_duration = std::time::Duration::from_secs(0);
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
// First poll should immediately return a timeout error
let result = timeout_stream.next().await.unwrap();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Query timeout"));
}
#[tokio::test]
async fn test_timeout_stream_completes_normally() {
let batch = sample_batch();
let schema = batch.schema();
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
let sendable_stream: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
// Setup a stream with 2 batches
// Use a longer timeout that won't trigger
let timeout_duration = std::time::Duration::from_secs(1);
let mut timeout_stream = TimeoutStream::new(sendable_stream, timeout_duration);
// Both polls should return data normally
assert!(timeout_stream.next().await.unwrap().is_ok());
assert!(timeout_stream.next().await.unwrap().is_ok());
// Stream should be empty now
assert!(timeout_stream.next().await.is_none());
}
}