From b20931b8f74ec5cd9a0f3163e506a7c4a0fe5241 Mon Sep 17 00:00:00 2001 From: Wyatt Alt Date: Sat, 13 Jun 2026 10:11:14 -0700 Subject: [PATCH] feat: async UDF client ergonomics (AsyncConnection/AsyncTable + AsyncView/AsyncJobHandle) Mirrors the sync ergonomics on the async surface: AsyncConnection create_function(udf, replace=)/create_view/job; AsyncTable.add_computed_column; AsyncView + AsyncJobHandle (await + asyncio.sleep; shared submission-prefix matcher with the sync JobHandle). Decorator + REST routes are shared/already validated; this is the async wrapper layer. Exported from the package root. Co-Authored-By: Claude Opus 4.8 (1M context) --- python/python/lancedb/__init__.py | 4 +- python/python/lancedb/db.py | 61 +++++++++++++++++-- python/python/lancedb/table.py | 37 ++++++++++++ python/python/lancedb/udf.py | 98 ++++++++++++++++++++++++++++--- 4 files changed, 186 insertions(+), 14 deletions(-) diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 2359d2a76..2a3b67c93 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection from .remote import ClientConfig from .remote.db import RemoteDBConnection from .expr import Expr, col, lit, func -from .udf import udf, table_udf, Udf, JobHandle, View +from .udf import udf, table_udf, Udf, JobHandle, View, AsyncJobHandle, AsyncView from .schema import vector from .table import AsyncTable, Table from ._lancedb import Session @@ -454,6 +454,8 @@ __all__ = [ "Udf", "JobHandle", "View", + "AsyncJobHandle", + "AsyncView", "connect", "connect_async", "connect_namespace", diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 66e6d1102..cee4af6a1 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -1979,13 +1979,33 @@ class AsyncConnection(object): async def create_function( self, - name: str, - language: str, - return_type: str, - body: str, + name, + language: str = "python", + return_type: Optional[str] = None, + body: Optional[str] = None, options: Optional[Dict[str, str]] = None, + *, + replace: bool = False, ): - """Register a UDF (CREATE FUNCTION).""" + """Register a UDF (CREATE FUNCTION). Accepts a ``@udf``/``@table_udf`` + object (preferred) or the explicit (name, language, return_type, body, + options).""" + from .udf import Udf + + if isinstance(name, Udf): + req = name.create_request() + name, language, return_type, body, options = ( + req["name"], + req["language"], + req["return_type"], + req["body"], + req["options"], + ) + if replace: + try: + await self.drop_function(name) + except Exception: + pass await self._inner.create_function(name, language, return_type, body, options) async def list_functions(self): @@ -2010,6 +2030,37 @@ class AsyncConnection(object): name, query, auto_refresh=auto_refresh, with_no_data=with_no_data ) + async def create_view( + self, + name: str, + source, + select, + *, + where: Optional[str] = None, + auto_refresh: bool = False, + replace: bool = False, + ): + """Create a materialized view from a source + select items; returns + an `AsyncView`. See the sync `create_view` for the select grammar.""" + from .udf import build_view_query, AsyncView + + query = build_view_query(source, select) + if where: + query += f" WHERE {where}" + if replace: + try: + await self.drop_materialized_view(name) + except Exception: + pass + await self.create_materialized_view(name, query, auto_refresh=auto_refresh) + return AsyncView(self, name) + + def job(self, job_id: str): + """An `AsyncJobHandle` for polling/cancelling an inflight job by id.""" + from .udf import AsyncJobHandle + + return AsyncJobHandle(self, job_id) + async def refresh_materialized_view( self, name: str, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 0a76883f0..62f80e01e 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -5546,6 +5546,43 @@ class AsyncTable: await self._inner.add_computed_columns(cols, expression) return result + async def add_computed_column( + self, + columns, + fn, + args: Optional[List[str]] = None, + types=None, + ) -> None: + """Declare computed column(s) bound to a UDF (async). See the sync + `Table.add_computed_column`. Server-backed (Enterprise / Cloud).""" + from .udf import Udf, struct_field_types + + multi = isinstance(columns, (tuple, list)) + if isinstance(fn, Udf): + expr = fn.expression(*(args or [])) + if types is None: + if multi: + if not fn.returns.upper().startswith("STRUCT"): + raise ValueError( + "several columns need a STRUCT-returning function" + ) + types = struct_field_types(fn.returns) + else: + types = fn.returns + else: + if types is None: + raise ValueError("pass types= when fn is a name string") + expr = f"{fn}({', '.join(args or [])})" + if multi: + if len(types) != len(columns): + raise ValueError( + f"{len(columns)} columns but {len(types)} output types" + ) + computed = {c: (t, expr) for c, t in zip(columns, types)} + else: + computed = {columns: (types, expr)} + await self.add_columns(computed=computed) + async def alter_columns( self, *alterations: Iterable[dict[str, Any]] ) -> AlterColumnsResult: diff --git a/python/python/lancedb/udf.py b/python/python/lancedb/udf.py index 84c17929d..2ac0c2db1 100644 --- a/python/python/lancedb/udf.py +++ b/python/python/lancedb/udf.py @@ -29,6 +29,7 @@ remote connection. from __future__ import annotations +import asyncio import base64 import dataclasses import functools @@ -394,6 +395,17 @@ def build_view_query(source, select) -> str: return f"SELECT {', '.join(items)} FROM {src}" +def _job_id_matches(handle_id: str, listed_id: str) -> bool: + # The refresh/backfill endpoints return the submission id (a uuid), but + # the agent names the manifest job "--" -- which is what list_jobs and cancel report. Match the + # canonical id directly, or by that submission prefix. + if listed_id == handle_id: + return True + prefix = handle_id[:8] + return len(prefix) >= 4 and prefix in listed_id + + class View: """A reference to a materialized view (name + connection). View operations are server-backed connection calls bound to the name.""" @@ -440,14 +452,7 @@ class JobHandle: self._seen = False def _matches(self, listed_id: str) -> bool: - # The refresh/backfill endpoints return the submission id (a uuid), - # but the agent names the manifest job "
--" -- which is what list_jobs and cancel report. - # Match the canonical id directly, or by that submission prefix. - if listed_id == self.id: - return True - prefix = self.id[:8] - return len(prefix) >= 4 and prefix in listed_id + return _job_id_matches(self.id, listed_id) def _job(self): for j in self.conn.list_jobs(): @@ -493,3 +498,80 @@ class JobHandle: # via the submission prefix; fall back to the raw id. job = self._job() self.conn.cancel_job(job.job_id if job is not None else self.id) + + +class AsyncView: + """Async reference to a materialized view (name + async connection).""" + + def __init__(self, conn, name: str): + self.conn = conn + self.name = name + + async def refresh(self, full: bool = False): + if full: + raise NotImplementedError( + "full=True refresh is not supported yet (engine gap: the " + "refresh event has no full-rebuild flag)" + ) + return await self.conn.refresh_materialized_view(self.name) + + async def explain_refresh(self, full: bool = False): + return await self.conn.explain_refresh_materialized_view(self.name, full=full) + + async def alter(self, auto_refresh: bool) -> None: + await self.conn.alter_materialized_view(self.name, auto_refresh=auto_refresh) + + async def drop(self) -> None: + await self.conn.drop_materialized_view(self.name) + + +class AsyncJobHandle: + """Async reference to an inflight server-side job, with polling helpers.""" + + GRACE_SECONDS = 20.0 + + def __init__(self, conn, job_id: str): + self.conn = conn + self.id = job_id + self._created = time.monotonic() + self._seen = False + + async def _job(self): + for j in await self.conn.list_jobs(): + if _job_id_matches(self.id, j.job_id): + return j + return None + + async def status(self) -> str: + job = await self._job() + if job is not None: + self._seen = True + return job.state + if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS: + return "pending" + return "finished" + + async def progress(self) -> "tuple[int, int] | None": + job = await self._job() + if job is not None and job.units_total is not None: + return job.units_done or 0, job.units_total + return None + + async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + state = await self.status() + if state in ("finished", "stale"): + return state + if state == "pending": + await asyncio.sleep(min(poll, 0.5)) + continue + job = await self._job() + if job is not None and job.committed: + return "finished" + await asyncio.sleep(poll) + raise TimeoutError(f"job {self.id} still {await self.status()} after {timeout}s") + + async def cancel(self) -> None: + job = await self._job() + await self.conn.cancel_job(job.job_id if job is not None else self.id)