mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-02 18:40:40 +00:00
feat: fold UDF authoring into lancedb (udf module + connection/table ergonomics)
Brings the @udf/@table_udf decorator + type inference into lancedb as lancedb.udf (Apache-2.0), and adds the ergonomic glue to the existing connection/table so there's no separate object model: - create_function() accepts a Udf (and a replace= flag) - Table.add_computed_column(column, udf) - create_view(name, source, select, ...) -> View (assembles the SELECT) - Connection.job(job_id) -> JobHandle - View / JobHandle are thin references over a connection Exports udf/table_udf/Udf/JobHandle/View from the package root. The operations stay the existing remote-only methods (enterprise/cloud); the decorator works locally. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .expr import Expr, col, lit, func
|
||||
from .udf import udf, table_udf, Udf, JobHandle, View
|
||||
from .schema import vector
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
@@ -448,6 +449,11 @@ async def connect_async(
|
||||
|
||||
|
||||
__all__ = [
|
||||
"udf",
|
||||
"table_udf",
|
||||
"Udf",
|
||||
"JobHandle",
|
||||
"View",
|
||||
"connect",
|
||||
"connect_async",
|
||||
"connect_namespace",
|
||||
|
||||
@@ -569,18 +569,26 @@ class DBConnection(EnforceOverrides):
|
||||
|
||||
def create_function(
|
||||
self,
|
||||
name: str,
|
||||
language: str,
|
||||
return_type: str,
|
||||
body: str,
|
||||
name,
|
||||
language: str = "python",
|
||||
return_type: Optional[str] = None,
|
||||
body: Optional[str] = None,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Register a UDF (CREATE FUNCTION).
|
||||
|
||||
Pass a ``@udf`` / ``@table_udf``-decorated function (preferred):
|
||||
|
||||
db.create_function(embed)
|
||||
|
||||
or the explicit fields:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Function name.
|
||||
name: str or Udf
|
||||
A decorated UDF object, or the function name.
|
||||
language: str
|
||||
Implementation language (currently "python").
|
||||
return_type: str
|
||||
@@ -592,7 +600,25 @@ class DBConnection(EnforceOverrides):
|
||||
options: dict, optional
|
||||
input_columns, pip, num_gpus, batch_size, timeout,
|
||||
error_policy, docker_image, body_format, ...
|
||||
replace: bool
|
||||
Drop an existing function of the same name first.
|
||||
"""
|
||||
from .udf import Udf
|
||||
|
||||
if isinstance(name, Udf):
|
||||
req = name.create_request()
|
||||
name, language, return_type, body, options = (
|
||||
req["name"],
|
||||
req["language"],
|
||||
req["return_type"],
|
||||
req["body"],
|
||||
req["options"],
|
||||
)
|
||||
if replace:
|
||||
try:
|
||||
self.drop_function(name)
|
||||
except Exception:
|
||||
pass
|
||||
LOOP.run(self._conn.create_function(name, language, return_type, body, options))
|
||||
|
||||
def list_functions(self):
|
||||
@@ -624,6 +650,45 @@ class DBConnection(EnforceOverrides):
|
||||
)
|
||||
)
|
||||
|
||||
def create_view(
|
||||
self,
|
||||
name: str,
|
||||
source,
|
||||
select,
|
||||
*,
|
||||
where: Optional[str] = None,
|
||||
auto_refresh: bool = False,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Create a materialized view from a source and select items, and
|
||||
return a `View` handle.
|
||||
|
||||
`source` is a table name or table; `select` items are column names,
|
||||
expression strings ("embed(body)"), (alias, expression) tuples, or
|
||||
``@udf`` / ``@table_udf`` objects. Sugar over create_materialized_view:
|
||||
it assembles the SELECT, which the server parses (one parser, shared
|
||||
with SQL).
|
||||
"""
|
||||
from .udf import build_view_query, View
|
||||
|
||||
query = build_view_query(source, select)
|
||||
if where:
|
||||
query += f" WHERE {where}"
|
||||
if replace:
|
||||
try:
|
||||
self.drop_materialized_view(name)
|
||||
except Exception:
|
||||
pass
|
||||
self.create_materialized_view(name, query, auto_refresh=auto_refresh)
|
||||
return View(self, name)
|
||||
|
||||
def job(self, job_id: str):
|
||||
"""A `JobHandle` for polling/cancelling an inflight job by id (e.g.
|
||||
``db.job(tbl.refresh_column("c")).wait()``)."""
|
||||
from .udf import JobHandle
|
||||
|
||||
return JobHandle(self, job_id)
|
||||
|
||||
def refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
|
||||
@@ -823,6 +823,47 @@ class Table(ABC):
|
||||
"""The number of rows in this Table"""
|
||||
return self.count_rows(None)
|
||||
|
||||
def add_computed_column(
|
||||
self,
|
||||
columns,
|
||||
fn,
|
||||
args: Optional[List[str]] = None,
|
||||
types=None,
|
||||
) -> None:
|
||||
"""Declare computed column(s) bound to a UDF -- no compute happens
|
||||
here (the agent fills them lazily, or refresh_column() triggers a
|
||||
run). Sugar over add_columns(computed=): column types come from the
|
||||
UDF's declared return type (a STRUCT return maps its fields to the
|
||||
columns positionally); pass `types` when `fn` is a bare name string.
|
||||
Register the function first. Server-backed (Enterprise / Cloud)."""
|
||||
from .udf import Udf, struct_field_types
|
||||
|
||||
multi = isinstance(columns, (tuple, list))
|
||||
if isinstance(fn, Udf):
|
||||
expr = fn.expression(*(args or []))
|
||||
if types is None:
|
||||
if multi:
|
||||
if not fn.returns.upper().startswith("STRUCT"):
|
||||
raise ValueError(
|
||||
"several columns need a STRUCT-returning function"
|
||||
)
|
||||
types = struct_field_types(fn.returns)
|
||||
else:
|
||||
types = fn.returns
|
||||
else:
|
||||
if types is None:
|
||||
raise ValueError("pass types= when fn is a name string")
|
||||
expr = f"{fn}({', '.join(args or [])})"
|
||||
if multi:
|
||||
if len(types) != len(columns):
|
||||
raise ValueError(
|
||||
f"{len(columns)} columns but {len(types)} output types"
|
||||
)
|
||||
computed = {c: (t, expr) for c, t in zip(columns, types)}
|
||||
else:
|
||||
computed = {columns: (types, expr)}
|
||||
self.add_columns(computed=computed)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
|
||||
|
||||
482
python/python/lancedb/udf.py
Normal file
482
python/python/lancedb/udf.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
"""UDF authoring for LanceDB derived compute (server-backed).
|
||||
|
||||
`@udf` / `@table_udf` turn a plain Python function into a registrable
|
||||
server-side UDF: a cloudpickled (or source) body, a SQL signature inferred
|
||||
from type hints, and the runtime options (pip deps, GPUs, batching, ...).
|
||||
Register and use them through the existing connection/table API:
|
||||
|
||||
import lancedb
|
||||
from lancedb import udf, table_udf
|
||||
|
||||
db = lancedb.connect("db://my_db", api_key="...", host_override="...")
|
||||
|
||||
@udf(pip=["torch>=2.0"], num_gpus=1)
|
||||
def embed(text: str) -> list[float]:
|
||||
return model.encode(text).tolist()
|
||||
|
||||
db.create_function(embed) # CREATE FUNCTION
|
||||
tbl = db.open_table("docs")
|
||||
tbl.add_computed_column("vec", embed) # declare (no compute yet)
|
||||
db.job(tbl.refresh_column("vec")).wait() # materialize
|
||||
view = db.create_view("chunks", tbl, ["id", chunk_fn])
|
||||
|
||||
These operations are server-backed (LanceDB Enterprise / Cloud); the
|
||||
decorator itself works locally (define + call), only registration needs a
|
||||
remote connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import typing
|
||||
|
||||
# -- type hints -> SQL type strings -------------------------------------
|
||||
|
||||
_SCALARS = {
|
||||
int: "BIGINT",
|
||||
# Pragmatic default for ML workloads: python float maps to FLOAT
|
||||
# (Float32). Use an explicit `returns=` for DOUBLE.
|
||||
float: "FLOAT",
|
||||
str: "VARCHAR",
|
||||
bool: "BOOLEAN",
|
||||
bytes: "BLOB",
|
||||
}
|
||||
|
||||
|
||||
class TypeInferenceError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
def sql_type(hint) -> str:
|
||||
"""SQL type string for a python type hint."""
|
||||
if hint in _SCALARS:
|
||||
return _SCALARS[hint]
|
||||
origin = typing.get_origin(hint)
|
||||
if origin in (list, typing.List):
|
||||
(item,) = typing.get_args(hint) or (None,)
|
||||
if item in _SCALARS:
|
||||
return f"{_SCALARS[item]}[]"
|
||||
raise TypeInferenceError(
|
||||
f"unsupported list item type {item!r}; use an explicit returns="
|
||||
)
|
||||
fields = _struct_fields(hint)
|
||||
if fields is not None:
|
||||
inner = ", ".join(f"{name} {sql_type(h)}" for name, h in fields)
|
||||
return f"STRUCT({inner})"
|
||||
raise TypeInferenceError(
|
||||
f"cannot infer a SQL type for {hint!r}; pass an explicit type string"
|
||||
)
|
||||
|
||||
|
||||
def _struct_fields(hint):
|
||||
"""(name, hint) pairs for a TypedDict or dataclass, else None."""
|
||||
if dataclasses.is_dataclass(hint):
|
||||
return [(f.name, f.type) for f in dataclasses.fields(hint)]
|
||||
# TypedDict detection: a dict subclass with __annotations__.
|
||||
if isinstance(hint, type) and issubclass(hint, dict) and typing.get_type_hints(hint):
|
||||
return list(typing.get_type_hints(hint).items())
|
||||
return None
|
||||
|
||||
|
||||
def return_type(fn, override: "str | None", table: bool) -> str:
|
||||
"""SQL return type for a function: explicit override wins, else the
|
||||
return annotation. Table functions render as TABLE(...) and accept
|
||||
struct-shaped hints (TypedDict/dataclass, optionally list-wrapped)."""
|
||||
if override is not None:
|
||||
s = override.strip()
|
||||
if table and not s.upper().startswith("TABLE"):
|
||||
if s.upper().startswith("STRUCT"):
|
||||
return "TABLE" + s[len("STRUCT") :]
|
||||
raise TypeInferenceError(
|
||||
"a table function's returns= must be TABLE(...) or STRUCT(...)"
|
||||
)
|
||||
return s
|
||||
|
||||
hints = typing.get_type_hints(fn)
|
||||
ret = hints.get("return")
|
||||
if ret is None:
|
||||
raise TypeInferenceError(
|
||||
f"function {fn.__name__!r} needs a return annotation or returns="
|
||||
)
|
||||
if table:
|
||||
# Accept list[Row] / Row where Row is a TypedDict or dataclass.
|
||||
if typing.get_origin(ret) in (list, typing.List):
|
||||
(ret,) = typing.get_args(ret)
|
||||
fields = _struct_fields(ret)
|
||||
if fields is None:
|
||||
raise TypeInferenceError(
|
||||
"a table function must return rows shaped as a TypedDict or "
|
||||
"dataclass (optionally list-wrapped); or pass returns=..."
|
||||
)
|
||||
inner = ", ".join(f"{name} {sql_type(h)}" for name, h in fields)
|
||||
return f"TABLE({inner})"
|
||||
return sql_type(ret)
|
||||
|
||||
|
||||
def param_types(fn) -> "list[tuple[str, str]]":
|
||||
"""(name, sql type) per parameter, from annotations. Each UDF
|
||||
parameter binds to a source column of the same name by default."""
|
||||
hints = typing.get_type_hints(fn)
|
||||
out = []
|
||||
for name, p in inspect.signature(fn).parameters.items():
|
||||
if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
|
||||
raise TypeInferenceError("*args/**kwargs are not supported in UDFs")
|
||||
hint = hints.get(name)
|
||||
if hint is None:
|
||||
raise TypeInferenceError(
|
||||
f"parameter {name!r} of {fn.__name__!r} needs a type annotation"
|
||||
)
|
||||
out.append((name, sql_type(hint)))
|
||||
return out
|
||||
|
||||
|
||||
# -- the @udf / @table_udf decorators -----------------------------------
|
||||
|
||||
|
||||
class Udf:
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
*,
|
||||
returns: "str | None" = None,
|
||||
table: bool = False,
|
||||
name: "str | None" = None,
|
||||
pip: "list[str] | None" = None,
|
||||
pip_index_url: "str | None" = None,
|
||||
pip_extra_index_urls: "list[str] | None" = None,
|
||||
find_links: "list[str] | None" = None,
|
||||
requirements: "str | list[str] | None" = None,
|
||||
conda: "list[str] | None" = None,
|
||||
conda_channels: "list[str] | None" = None,
|
||||
env: "dict[str, str] | list[str] | None" = None,
|
||||
num_cpus: "int | None" = None,
|
||||
num_gpus: "int | None" = None,
|
||||
batch_size: "int | None" = None,
|
||||
timeout: "float | None" = None,
|
||||
error_policy: "str | None" = None,
|
||||
max_skip_ratio: "float | None" = None,
|
||||
retries: "int | None" = None,
|
||||
docker_image: "str | None" = None,
|
||||
description: "str | None" = None,
|
||||
prefer_source: bool = False,
|
||||
):
|
||||
functools.update_wrapper(self, fn)
|
||||
self.fn = fn
|
||||
self.name = name or fn.__name__
|
||||
self.table = table
|
||||
self.params = param_types(fn)
|
||||
self.returns = return_type(fn, returns, table)
|
||||
self.prefer_source = prefer_source
|
||||
self.options: "dict[str, str]" = {}
|
||||
if conda and (pip or requirements):
|
||||
raise ValueError("pass conda or pip/requirements, not both")
|
||||
if conda_channels and not conda:
|
||||
raise ValueError("conda_channels requires conda")
|
||||
if pip:
|
||||
self.options["pip"] = ",".join(pip)
|
||||
if pip_extra_index_urls:
|
||||
self.options["pip_extra_index_urls"] = ",".join(pip_extra_index_urls)
|
||||
if find_links:
|
||||
self.options["find_links"] = ",".join(find_links)
|
||||
if requirements:
|
||||
self.options["requirements"] = _format_requirements(requirements)
|
||||
if conda:
|
||||
self.options["conda"] = ",".join(conda)
|
||||
if conda_channels:
|
||||
self.options["conda_channels"] = ",".join(conda_channels)
|
||||
if env:
|
||||
self.options["env"] = _format_env(env)
|
||||
for key, val in [
|
||||
("pip_index_url", pip_index_url),
|
||||
("num_cpus", num_cpus),
|
||||
("num_gpus", num_gpus),
|
||||
("batch_size", batch_size),
|
||||
("timeout", timeout),
|
||||
("error_policy", error_policy),
|
||||
("max_skip_ratio", max_skip_ratio),
|
||||
("retries", retries),
|
||||
("docker_image", docker_image),
|
||||
]:
|
||||
if val is not None:
|
||||
self.options[key] = str(val)
|
||||
# Keep the source in the description (when available) so the
|
||||
# catalog stays inspectable even for pickled bodies.
|
||||
if description is not None:
|
||||
self.options["description"] = description
|
||||
else:
|
||||
try:
|
||||
self.options["description"] = textwrap.dedent(inspect.getsource(fn))
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Call with real values to run locally; call with column-name
|
||||
strings to build an expression for backfills and views."""
|
||||
if args and all(isinstance(a, str) for a in args) and not kwargs:
|
||||
return f"{self.name}({', '.join(args)})"
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
def expression(self, *columns: str) -> str:
|
||||
cols = columns or [p for p, _ in self.params]
|
||||
return f"{self.name}({', '.join(cols)})"
|
||||
|
||||
def _body(self) -> "tuple[str, str]":
|
||||
"""(body literal, body_format). Source when requested and
|
||||
retrievable; cloudpickle otherwise (handles closures)."""
|
||||
if self.prefer_source:
|
||||
try:
|
||||
src = textwrap.dedent(inspect.getsource(self.fn))
|
||||
# Strip the decorator line(s) so the stored body is a
|
||||
# plain function definition.
|
||||
lines = src.splitlines(keepends=True)
|
||||
while lines and lines[0].lstrip().startswith("@"):
|
||||
lines.pop(0)
|
||||
return "".join(lines), "source"
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
import cloudpickle
|
||||
|
||||
raw = cloudpickle.dumps(self.fn)
|
||||
return base64.b64encode(raw).decode("ascii"), "cloudpickle"
|
||||
|
||||
def _body_and_options(self) -> "tuple[str, dict[str, str]]":
|
||||
"""The body literal plus the finalized options (body_format /
|
||||
python_version / cloudpickle-pip bookkeeping for a non-source
|
||||
body)."""
|
||||
body, body_format = self._body()
|
||||
options = dict(self.options)
|
||||
if body_format != "source":
|
||||
options["body_format"] = body_format
|
||||
# Pickled code objects only load under the same interpreter
|
||||
# minor version; record ours so the worker can fail with a
|
||||
# clear message instead of a bytecode error.
|
||||
options["python_version"] = self.pickle_environment()
|
||||
# The worker deserializes the body with cloudpickle; make sure
|
||||
# the job's pip environment provides it. Conda bakes inject
|
||||
# cloudpickle server-side, so do not create an invalid pip+conda
|
||||
# declaration here.
|
||||
if "conda" not in options:
|
||||
pip = [d for d in options.get("pip", "").split(",") if d]
|
||||
if not any(d.startswith("cloudpickle") for d in pip):
|
||||
pip.append("cloudpickle")
|
||||
options["pip"] = ",".join(pip)
|
||||
return body, options
|
||||
|
||||
def create_request(self) -> dict:
|
||||
"""Keyword arguments for `connection.create_function`."""
|
||||
body, options = self._body_and_options()
|
||||
return {
|
||||
"name": self.name,
|
||||
"language": "python",
|
||||
"return_type": self.returns,
|
||||
"body": body,
|
||||
"options": options,
|
||||
}
|
||||
|
||||
def create_statement(self) -> str:
|
||||
"""The equivalent `CREATE FUNCTION` SQL (for SQL-surface callers)."""
|
||||
params = ", ".join(f"{n} {t}" for n, t in self.params)
|
||||
body, options = self._body_and_options()
|
||||
with_clause = ""
|
||||
if options:
|
||||
rendered = ", ".join(
|
||||
f"{k} = '{_escape(v)}'" for k, v in sorted(options.items())
|
||||
)
|
||||
with_clause = f" WITH ({rendered})"
|
||||
return (
|
||||
f"CREATE FUNCTION {self.name}({params}) RETURNS {self.returns} "
|
||||
f"LANGUAGE python AS '{_escape_body(body)}'{with_clause}"
|
||||
)
|
||||
|
||||
def pickle_environment(self) -> str:
|
||||
"""Python version the body pickles under -- workers should match
|
||||
the minor version for cloudpickle compatibility."""
|
||||
return f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
|
||||
def _escape(s: str) -> str:
|
||||
return str(s).replace("'", "''")
|
||||
|
||||
|
||||
def _format_requirements(requirements: "str | list[str]") -> str:
|
||||
if isinstance(requirements, str):
|
||||
return requirements
|
||||
return "\n".join(str(req) for req in requirements)
|
||||
|
||||
|
||||
def _format_env(env: "dict[str, str] | list[str]") -> str:
|
||||
if isinstance(env, dict):
|
||||
return "; ".join(f"{key}={value}" for key, value in env.items())
|
||||
return "; ".join(str(entry) for entry in env)
|
||||
|
||||
|
||||
def _escape_body(body: str) -> str:
|
||||
# The server unescapes \n / \t in single-quoted bodies; encode real
|
||||
# newlines accordingly and escape quotes.
|
||||
return body.replace("\\", "\\\\").replace("'", "''").replace("\n", "\\n").replace("\t", "\\t")
|
||||
|
||||
|
||||
def udf(fn=None, **kwargs):
|
||||
"""Decorate a function as a scalar (or struct-returning) UDF.
|
||||
|
||||
@udf
|
||||
def doubled(val: int) -> float: ...
|
||||
|
||||
@udf(pip=["torch>=2"], num_gpus=1)
|
||||
def embed(body: str) -> list[float]: ...
|
||||
"""
|
||||
if fn is not None:
|
||||
return Udf(fn, **kwargs)
|
||||
return lambda f: Udf(f, **kwargs)
|
||||
|
||||
|
||||
def table_udf(fn=None, **kwargs):
|
||||
"""Decorate a table function (UDTF): each input row may emit zero or
|
||||
more output rows. Only usable in materialized views.
|
||||
|
||||
class Chunk(TypedDict):
|
||||
chunk: str
|
||||
chunk_idx: int
|
||||
|
||||
@table_udf
|
||||
def chunker(body: str) -> list[Chunk]: ...
|
||||
"""
|
||||
kwargs["table"] = True
|
||||
if fn is not None:
|
||||
return Udf(fn, **kwargs)
|
||||
return lambda f: Udf(f, **kwargs)
|
||||
|
||||
|
||||
# -- view / job handles (thin references over a connection) -------------
|
||||
|
||||
|
||||
def struct_field_types(returns: str) -> "list[str]":
|
||||
"""Field type strings of a STRUCT(...) SQL type, in declared order."""
|
||||
inner = returns.strip()[len("STRUCT(") : -1]
|
||||
fields, depth, start = [], 0, 0
|
||||
for i, c in enumerate(inner):
|
||||
if c in "([":
|
||||
depth += 1
|
||||
elif c in ")]":
|
||||
depth -= 1
|
||||
elif c == "," and depth == 0:
|
||||
fields.append(inner[start:i].strip())
|
||||
start = i + 1
|
||||
fields.append(inner[start:].strip())
|
||||
# Each field is "name TYPE"; drop the name.
|
||||
return [f.split(None, 1)[1] for f in fields]
|
||||
|
||||
|
||||
def build_view_query(source, select) -> str:
|
||||
"""Assemble a view SELECT from a source (name or table) and select
|
||||
items: a column name, an expression string, a (alias, expression)
|
||||
tuple, or a @udf/@table_udf object."""
|
||||
src = source.name if hasattr(source, "name") else source
|
||||
items = []
|
||||
for item in select:
|
||||
if isinstance(item, Udf):
|
||||
items.append(item.expression())
|
||||
elif isinstance(item, tuple):
|
||||
alias, expr = item
|
||||
expr = expr.expression() if isinstance(expr, Udf) else expr
|
||||
items.append(f"{expr} AS {alias}")
|
||||
else:
|
||||
items.append(item)
|
||||
return f"SELECT {', '.join(items)} FROM {src}"
|
||||
|
||||
|
||||
class View:
|
||||
"""A reference to a materialized view (name + connection). View
|
||||
operations are server-backed connection calls bound to the name."""
|
||||
|
||||
def __init__(self, conn, name: str):
|
||||
self.conn = conn
|
||||
self.name = name
|
||||
|
||||
def refresh(self, full: bool = False):
|
||||
if full:
|
||||
# full/force-rebuild is not honored on any surface yet (the
|
||||
# refresh event carries no `full` flag) -- do not pretend.
|
||||
raise NotImplementedError(
|
||||
"full=True refresh is not supported yet (engine gap: the "
|
||||
"refresh event has no full-rebuild flag)"
|
||||
)
|
||||
return self.conn.refresh_materialized_view(self.name)
|
||||
|
||||
def explain_refresh(self, full: bool = False):
|
||||
"""Plan a refresh without running it (EXPLAIN REFRESH)."""
|
||||
return self.conn.explain_refresh_materialized_view(self.name, full=full)
|
||||
|
||||
def alter(self, auto_refresh: bool) -> None:
|
||||
self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)
|
||||
|
||||
def drop(self) -> None:
|
||||
self.conn.drop_materialized_view(self.name)
|
||||
|
||||
|
||||
_PROGRESS = re.compile(r"(\d+)/(\d+)")
|
||||
|
||||
|
||||
class JobHandle:
|
||||
"""A reference to an inflight server-side job, with polling helpers."""
|
||||
|
||||
#: How long an unseen job is treated as still materializing (submission
|
||||
#: -> agent cycle -> manifest write is async).
|
||||
GRACE_SECONDS = 20.0
|
||||
|
||||
def __init__(self, conn, job_id: str):
|
||||
self.conn = conn
|
||||
self.id = job_id
|
||||
self._created = time.monotonic()
|
||||
self._seen = False
|
||||
|
||||
def _job(self):
|
||||
for j in self.conn.list_jobs():
|
||||
if j.job_id == self.id:
|
||||
return j
|
||||
return None
|
||||
|
||||
def status(self) -> str:
|
||||
"""pending / running / cancelling / stale, or 'finished' once the
|
||||
job has left the inflight listing."""
|
||||
job = self._job()
|
||||
if job is not None:
|
||||
self._seen = True
|
||||
return job.state
|
||||
if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
|
||||
return "pending"
|
||||
return "finished"
|
||||
|
||||
def progress(self) -> "tuple[int, int] | None":
|
||||
"""(units_done, units_total) while running, else None."""
|
||||
job = self._job()
|
||||
if job is not None and job.units_total is not None:
|
||||
return job.units_done or 0, job.units_total
|
||||
return None
|
||||
|
||||
def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
state = self.status()
|
||||
if state in ("finished", "stale"):
|
||||
return state
|
||||
if state == "pending":
|
||||
time.sleep(min(poll, 0.5))
|
||||
continue
|
||||
job = self._job()
|
||||
if job is not None and job.committed:
|
||||
return "finished"
|
||||
time.sleep(poll)
|
||||
raise TimeoutError(f"job {self.id} still {self.status()} after {timeout}s")
|
||||
|
||||
def cancel(self) -> None:
|
||||
self.conn.cancel_job(self.id)
|
||||
Reference in New Issue
Block a user