mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-03 02:50:41 +00:00
feat: async UDF client ergonomics (AsyncConnection/AsyncTable + AsyncView/AsyncJobHandle)
Mirrors the sync ergonomics on the async surface: AsyncConnection create_function(udf, replace=)/create_view/job; AsyncTable.add_computed_column; AsyncView + AsyncJobHandle (await + asyncio.sleep; shared submission-prefix matcher with the sync JobHandle). Decorator + REST routes are shared/already validated; this is the async wrapper layer. Exported from the package root. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .expr import Expr, col, lit, func
|
||||
from .udf import udf, table_udf, Udf, JobHandle, View
|
||||
from .udf import udf, table_udf, Udf, JobHandle, View, AsyncJobHandle, AsyncView
|
||||
from .schema import vector
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
@@ -454,6 +454,8 @@ __all__ = [
|
||||
"Udf",
|
||||
"JobHandle",
|
||||
"View",
|
||||
"AsyncJobHandle",
|
||||
"AsyncView",
|
||||
"connect",
|
||||
"connect_async",
|
||||
"connect_namespace",
|
||||
|
||||
@@ -1979,13 +1979,33 @@ class AsyncConnection(object):
|
||||
|
||||
async def create_function(
|
||||
self,
|
||||
name: str,
|
||||
language: str,
|
||||
return_type: str,
|
||||
body: str,
|
||||
name,
|
||||
language: str = "python",
|
||||
return_type: Optional[str] = None,
|
||||
body: Optional[str] = None,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Register a UDF (CREATE FUNCTION)."""
|
||||
"""Register a UDF (CREATE FUNCTION). Accepts a ``@udf``/``@table_udf``
|
||||
object (preferred) or the explicit (name, language, return_type, body,
|
||||
options)."""
|
||||
from .udf import Udf
|
||||
|
||||
if isinstance(name, Udf):
|
||||
req = name.create_request()
|
||||
name, language, return_type, body, options = (
|
||||
req["name"],
|
||||
req["language"],
|
||||
req["return_type"],
|
||||
req["body"],
|
||||
req["options"],
|
||||
)
|
||||
if replace:
|
||||
try:
|
||||
await self.drop_function(name)
|
||||
except Exception:
|
||||
pass
|
||||
await self._inner.create_function(name, language, return_type, body, options)
|
||||
|
||||
async def list_functions(self):
|
||||
@@ -2010,6 +2030,37 @@ class AsyncConnection(object):
|
||||
name, query, auto_refresh=auto_refresh, with_no_data=with_no_data
|
||||
)
|
||||
|
||||
async def create_view(
    self,
    name: str,
    source,
    select,
    *,
    where: Optional[str] = None,
    auto_refresh: bool = False,
    replace: bool = False,
):
    """Create a materialized view and return an `AsyncView` bound to it.

    The view query comes from ``build_view_query(source, select)``; a
    truthy ``where`` is appended to it as a ``WHERE`` clause. See the sync
    `create_view` for the select grammar.

    With ``replace=True`` a drop of ``name`` is attempted first, and any
    error raised by that drop is ignored. ``auto_refresh`` is forwarded
    to ``create_materialized_view``.
    """
    from .udf import AsyncView, build_view_query

    base_sql = build_view_query(source, select)
    view_sql = f"{base_sql} WHERE {where}" if where else base_sql

    if replace:
        # Best effort: a failed drop (presumably "view does not exist")
        # must not prevent the create that follows.
        try:
            await self.drop_materialized_view(name)
        except Exception:
            pass

    await self.create_materialized_view(name, view_sql, auto_refresh=auto_refresh)
    return AsyncView(self, name)
|
||||
|
||||
def job(self, job_id: str):
    """Return an `AsyncJobHandle` for polling/cancelling the job ``job_id``.

    The handle is bound to this connection; this method itself is
    synchronous and only constructs the handle.
    """
    from .udf import AsyncJobHandle as _Handle

    handle = _Handle(self, job_id)
    return handle
|
||||
|
||||
async def refresh_materialized_view(
|
||||
self,
|
||||
name: str,
|
||||
|
||||
@@ -5546,6 +5546,43 @@ class AsyncTable:
|
||||
await self._inner.add_computed_columns(cols, expression)
|
||||
return result
|
||||
|
||||
async def add_computed_column(
    self,
    columns,
    fn,
    args: Optional[List[str]] = None,
    types=None,
) -> None:
    """Declare computed column(s) bound to a UDF (async).

    See the sync `Table.add_computed_column`. Server-backed
    (Enterprise / Cloud).

    Parameters
    ----------
    columns:
        A single column name, or a list/tuple of column names when one
        function call fills several columns.
    fn:
        A `Udf` object, or the function name as a string.
    args:
        Arguments placed in the function call expression; defaults to
        no arguments.
    types:
        Output type for a single column, or one type per column when
        ``columns`` is a list/tuple. May be omitted when ``fn`` is a
        `Udf`: the type is then taken from ``fn.returns`` (expanded with
        ``struct_field_types`` for several columns).

    Raises
    ------
    ValueError
        If ``fn`` is a name string and ``types`` is omitted; if several
        columns are requested from a `Udf` that does not return a STRUCT;
        if several columns are given a single type string; or if the
        number of types differs from the number of columns.
    """
    from .udf import Udf, struct_field_types

    multi = isinstance(columns, (tuple, list))
    call_args = args or []

    if isinstance(fn, Udf):
        expr = fn.expression(*call_args)
        if types is None:
            if multi:
                if not fn.returns.upper().startswith("STRUCT"):
                    raise ValueError(
                        "several columns need a STRUCT-returning function"
                    )
                types = struct_field_types(fn.returns)
            else:
                types = fn.returns
    else:
        if types is None:
            raise ValueError("pass types= when fn is a name string")
        expr = f"{fn}({', '.join(call_args)})"

    if multi:
        # A bare string is the single-column form. Without this guard its
        # characters would be counted and zipped as if each were a type
        # (e.g. 3 columns + "INT" would silently yield "I", "N", "T").
        if isinstance(types, str):
            raise ValueError(
                f"{len(columns)} columns need a sequence of output types, "
                f"got the single type string {types!r}"
            )
        if len(types) != len(columns):
            raise ValueError(
                f"{len(columns)} columns but {len(types)} output types"
            )
        computed = {c: (t, expr) for c, t in zip(columns, types)}
    else:
        computed = {columns: (types, expr)}

    await self.add_columns(computed=computed)
|
||||
|
||||
async def alter_columns(
|
||||
self, *alterations: Iterable[dict[str, Any]]
|
||||
) -> AlterColumnsResult:
|
||||
|
||||
@@ -29,6 +29,7 @@ remote connection.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import dataclasses
|
||||
import functools
|
||||
@@ -394,6 +395,17 @@ def build_view_query(source, select) -> str:
|
||||
return f"SELECT {', '.join(items)} FROM {src}"
|
||||
|
||||
|
||||
def _job_id_matches(handle_id: str, listed_id: str) -> bool:
|
||||
# The refresh/backfill endpoints return the submission id (a uuid), but
|
||||
# the agent names the manifest job "<table>-<type>-<first 8 of the
|
||||
# submission id>" -- which is what list_jobs and cancel report. Match the
|
||||
# canonical id directly, or by that submission prefix.
|
||||
if listed_id == handle_id:
|
||||
return True
|
||||
prefix = handle_id[:8]
|
||||
return len(prefix) >= 4 and prefix in listed_id
|
||||
|
||||
|
||||
class View:
|
||||
"""A reference to a materialized view (name + connection). View
|
||||
operations are server-backed connection calls bound to the name."""
|
||||
@@ -440,14 +452,7 @@ class JobHandle:
|
||||
self._seen = False
|
||||
|
||||
def _matches(self, listed_id: str) -> bool:
|
||||
# The refresh/backfill endpoints return the submission id (a uuid),
|
||||
# but the agent names the manifest job "<table>-<type>-<first 8 of
|
||||
# the submission id>" -- which is what list_jobs and cancel report.
|
||||
# Match the canonical id directly, or by that submission prefix.
|
||||
if listed_id == self.id:
|
||||
return True
|
||||
prefix = self.id[:8]
|
||||
return len(prefix) >= 4 and prefix in listed_id
|
||||
return _job_id_matches(self.id, listed_id)
|
||||
|
||||
def _job(self):
|
||||
for j in self.conn.list_jobs():
|
||||
@@ -493,3 +498,80 @@ class JobHandle:
|
||||
# via the submission prefix; fall back to the raw id.
|
||||
job = self._job()
|
||||
self.conn.cancel_job(job.job_id if job is not None else self.id)
|
||||
|
||||
|
||||
class AsyncView:
    """Async handle for a materialized view: a connection plus a view name.

    Each method awaits the matching ``*_materialized_view`` call on the
    connection, passing the stored name.
    """

    def __init__(self, conn, name: str):
        self.conn, self.name = conn, name

    async def refresh(self, full: bool = False):
        """Refresh the view and return the connection call's result.

        Raises ``NotImplementedError`` when ``full`` is true.
        """
        if not full:
            return await self.conn.refresh_materialized_view(self.name)
        raise NotImplementedError(
            "full=True refresh is not supported yet (engine gap: the "
            "refresh event has no full-rebuild flag)"
        )

    async def explain_refresh(self, full: bool = False):
        """Return the connection's refresh explanation for this view."""
        view_name = self.name
        return await self.conn.explain_refresh_materialized_view(
            view_name, full=full
        )

    async def alter(self, auto_refresh: bool) -> None:
        """Set the view's ``auto_refresh`` option via the connection."""
        view_name = self.name
        await self.conn.alter_materialized_view(view_name, auto_refresh=auto_refresh)

    async def drop(self) -> None:
        """Drop the view via the connection."""
        view_name = self.name
        await self.conn.drop_materialized_view(view_name)
|
||||
|
||||
|
||||
class AsyncJobHandle:
    """Async reference to an inflight server-side job, with polling helpers.

    The handle stores the connection and the job id it was created with.
    Jobs are looked up in ``conn.list_jobs()`` using ``_job_id_matches``.
    """

    # Seconds after handle creation during which a job that has never
    # appeared in list_jobs is reported as "pending" instead of "finished".
    GRACE_SECONDS = 20.0

    def __init__(self, conn, job_id: str):
        self.conn = conn
        self.id = job_id
        self._created = time.monotonic()
        # Becomes True once the job has been found in list_jobs at least once.
        self._seen = False

    async def _job(self):
        """Return the listed job matching this handle's id, or None."""
        for j in await self.conn.list_jobs():
            if _job_id_matches(self.id, j.job_id):
                return j
        return None

    def _state_of(self, job) -> str:
        """Map a ``_job()`` snapshot to a state string.

        A listed job reports its own ``state`` (and marks the handle as
        seen). An unlisted job is "pending" while it has never been seen
        and the grace period is still running, otherwise "finished".
        """
        if job is not None:
            self._seen = True
            return job.state
        if not self._seen and time.monotonic() - self._created < self.GRACE_SECONDS:
            return "pending"
        return "finished"

    async def status(self) -> str:
        """Return the job's current state string (one list_jobs call)."""
        return self._state_of(await self._job())

    async def progress(self) -> "tuple[int, int] | None":
        """Return ``(units_done, units_total)``, or None when unavailable.

        None is returned when the job is not listed or has no
        ``units_total``; a missing ``units_done`` counts as 0.
        """
        job = await self._job()
        if job is not None and job.units_total is not None:
            return job.units_done or 0, job.units_total
        return None

    async def wait(self, timeout: float = 3600.0, poll: float = 2.0) -> str:
        """Poll until the job ends; return the final state.

        Returns "finished" or "stale" as reported, and also "finished" as
        soon as a listed job has ``committed`` set. While the state is
        "pending" the poll interval is capped at 0.5s.

        Raises ``TimeoutError`` if ``timeout`` seconds pass first.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # One list_jobs round-trip per iteration: the state and the
            # committed flag are read from the same snapshot.
            job = await self._job()
            state = self._state_of(job)
            if state in ("finished", "stale"):
                return state
            if state == "pending":
                await asyncio.sleep(min(poll, 0.5))
                continue
            if job is not None and job.committed:
                return "finished"
            await asyncio.sleep(poll)
        raise TimeoutError(f"job {self.id} still {await self.status()} after {timeout}s")

    async def cancel(self) -> None:
        """Cancel the job, using the listed job id when one is found.

        Falls back to the handle's raw id when the job is not listed.
        """
        job = await self._job()
        await self.conn.cancel_job(job.job_id if job is not None else self.id)
|
||||
|
||||
Reference in New Issue
Block a user