feat: async UDF client ergonomics (AsyncConnection/AsyncTable + AsyncView/AsyncJobHandle)

Mirrors the sync ergonomics on the async surface: AsyncConnection
create_function(udf, replace=)/create_view/job; AsyncTable.add_computed_column;
AsyncView + AsyncJobHandle (await + asyncio.sleep; shared submission-prefix
matcher with the sync JobHandle). Decorator + REST routes are shared/already
validated; this is the async wrapper layer. Exported from the package root.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Wyatt Alt
2026-06-13 10:11:14 -07:00
committed by Jack Ye
parent 396d68e490
commit b20931b8f7
4 changed files with 186 additions and 14 deletions

View File

@@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote import ClientConfig
from .remote.db import RemoteDBConnection
from .expr import Expr, col, lit, func
from .udf import udf, table_udf, Udf, JobHandle, View
from .udf import udf, table_udf, Udf, JobHandle, View, AsyncJobHandle, AsyncView
from .schema import vector
from .table import AsyncTable, Table
from ._lancedb import Session
@@ -454,6 +454,8 @@ __all__ = [
"Udf",
"JobHandle",
"View",
"AsyncJobHandle",
"AsyncView",
"connect",
"connect_async",
"connect_namespace",

View File

@@ -1979,13 +1979,33 @@ class AsyncConnection(object):
async def create_function(
self,
name: str,
language: str,
return_type: str,
body: str,
name,
language: str = "python",
return_type: Optional[str] = None,
body: Optional[str] = None,
options: Optional[Dict[str, str]] = None,
*,
replace: bool = False,
):
"""Register a UDF (CREATE FUNCTION)."""
"""Register a UDF (CREATE FUNCTION). Accepts a ``@udf``/``@table_udf``
object (preferred) or the explicit (name, language, return_type, body,
options)."""
from .udf import Udf
if isinstance(name, Udf):
req = name.create_request()
name, language, return_type, body, options = (
req["name"],
req["language"],
req["return_type"],
req["body"],
req["options"],
)
if replace:
try:
await self.drop_function(name)
except Exception:
pass
await self._inner.create_function(name, language, return_type, body, options)
async def list_functions(self):
@@ -2010,6 +2030,37 @@ class AsyncConnection(object):
name, query, auto_refresh=auto_refresh, with_no_data=with_no_data
)
async def create_view(
self,
name: str,
source,
select,
*,
where: Optional[str] = None,
auto_refresh: bool = False,
replace: bool = False,
):
"""Create a materialized view from a source + select items; returns
an `AsyncView`. See the sync `create_view` for the select grammar."""
from .udf import build_view_query, AsyncView
query = build_view_query(source, select)
if where:
query += f" WHERE {where}"
if replace:
try:
await self.drop_materialized_view(name)
except Exception:
pass
await self.create_materialized_view(name, query, auto_refresh=auto_refresh)
return AsyncView(self, name)
def job(self, job_id: str):
"""An `AsyncJobHandle` for polling/cancelling an inflight job by id."""
from .udf import AsyncJobHandle
return AsyncJobHandle(self, job_id)
async def refresh_materialized_view(
self,
name: str,

View File

@@ -5546,6 +5546,43 @@ class AsyncTable:
await self._inner.add_computed_columns(cols, expression)
return result
async def add_computed_column(
self,
columns,
fn,
args: Optional[List[str]] = None,
types=None,
) -> None:
"""Declare computed column(s) bound to a UDF (async). See the sync
`Table.add_computed_column`. Server-backed (Enterprise / Cloud)."""
from .udf import Udf, struct_field_types
multi = isinstance(columns, (tuple, list))
if isinstance(fn, Udf):
expr = fn.expression(*(args or []))
if types is None:
if multi:
if not fn.returns.upper().startswith("STRUCT"):
raise ValueError(
"several columns need a STRUCT-returning function"
)
types = struct_field_types(fn.returns)
else:
types = fn.returns
else:
if types is None:
raise ValueError("pass types= when fn is a name string")
expr = f"{fn}({', '.join(args or [])})"
if multi:
if len(types) != len(columns):
raise ValueError(
f"{len(columns)} columns but {len(types)} output types"
)
computed = {c: (t, expr) for c, t in zip(columns, types)}
else:
computed = {columns: (types, expr)}
await self.add_columns(computed=computed)
async def alter_columns(
self, *alterations: Iterable[dict[str, Any]]
) -> AlterColumnsResult:

View File

@@ -29,6 +29,7 @@ remote connection.
from __future__ import annotations
import asyncio
import base64
import dataclasses
import functools
@@ -394,6 +395,17 @@ def build_view_query(source, select) -> str:
return f"SELECT {', '.join(items)} FROM {src}"
def _job_id_matches(handle_id: str, listed_id: str) -> bool:
# The refresh/backfill endpoints return the submission id (a uuid), but
# the agent names the manifest job "<table>-<type>-<first 8 of the
# submission id>" -- which is what list_jobs and cancel report. Match the
# canonical id directly, or by that submission prefix.
if listed_id == handle_id:
return True
prefix = handle_id[:8]
return len(prefix) >= 4 and prefix in listed_id
class View:
"""A reference to a materialized view (name + connection). View
operations are server-backed connection calls bound to the name."""
@@ -440,14 +452,7 @@ class JobHandle:
self._seen = False
def _matches(self, listed_id: str) -> bool:
# The refresh/backfill endpoints return the submission id (a uuid),
# but the agent names the manifest job "<table>-<type>-<first 8 of
# the submission id>" -- which is what list_jobs and cancel report.
# Match the canonical id directly, or by that submission prefix.
if listed_id == self.id:
return True
prefix = self.id[:8]
return len(prefix) >= 4 and prefix in listed_id
return _job_id_matches(self.id, listed_id)
def _job(self):
for j in self.conn.list_jobs():
@@ -493,3 +498,80 @@ class JobHandle:
# via the submission prefix; fall back to the raw id.
job = self._job()
self.conn.cancel_job(job.job_id if job is not None else self.id)
class AsyncView:
"""Async reference to a materialized view (name + async connection)."""
def __init__(self, conn, name: str):
self.conn = conn
self.name = name
async def refresh(self, full: bool = False):
if full:
raise NotImplementedError(
"full=True refresh is not supported yet (engine gap: the "
"refresh event has no full-rebuild flag)"
)
return await self.conn.refresh_materialized_view(self.name)
async def explain_refresh(self, full: bool = False):
return await self.conn.explain_refresh_materialized_view(self.name, full=full)
async def alter(self, auto_refresh: bool) -> None:
await self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh)
async def drop(self) -> None:
await self.conn.drop_materialized_view(self.name)
class AsyncJobHandle:
"""Async reference to an inflight server-side job, with polling helpers."""
GRACE_SECONDS = 20.0
def __init__(self, conn, job_id: str):
self.conn = conn
self.id = job_id
self._created = time.monotonic()
self._seen = False
async def _job(self):
for j in await self.conn.list_jobs():
if _job_id_matches(self.id, j.job_id):
return j
return None
async def status(self) -> str:
job = await self._job()
if job is not None:
self._seen = True
return job.state
if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
return "pending"
return "finished"
async def progress(self) -> "tuple[int, int] | None":
job = await self._job()
if job is not None and job.units_total is not None:
return job.units_done or 0, job.units_total
return None
async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
state = await self.status()
if state in ("finished", "stale"):
return state
if state == "pending":
await asyncio.sleep(min(poll, 0.5))
continue
job = await self._job()
if job is not None and job.committed:
return "finished"
await asyncio.sleep(poll)
raise TimeoutError(f"job {self.id} still {await self.status()} after {timeout}s")
async def cancel(self) -> None:
job = await self._job()
await self.conn.cancel_job(job.job_id if job is not None else self.id)