diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 2359d2a76..2a3b67c93 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -17,7 +17,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection from .remote import ClientConfig from .remote.db import RemoteDBConnection from .expr import Expr, col, lit, func -from .udf import udf, table_udf, Udf, JobHandle, View +from .udf import udf, table_udf, Udf, JobHandle, View, AsyncJobHandle, AsyncView from .schema import vector from .table import AsyncTable, Table from ._lancedb import Session @@ -454,6 +454,8 @@ __all__ = [ "Udf", "JobHandle", "View", + "AsyncJobHandle", + "AsyncView", "connect", "connect_async", "connect_namespace", diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 66e6d1102..cee4af6a1 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -1979,13 +1979,33 @@ class AsyncConnection(object): async def create_function( self, - name: str, - language: str, - return_type: str, - body: str, + name, + language: str = "python", + return_type: Optional[str] = None, + body: Optional[str] = None, options: Optional[Dict[str, str]] = None, + *, + replace: bool = False, ): - """Register a UDF (CREATE FUNCTION).""" + """Register a UDF (CREATE FUNCTION). Accepts a ``@udf``/``@table_udf`` + object (preferred) or the explicit (name, language, return_type, body, + options).""" + from .udf import Udf + + if isinstance(name, Udf): + req = name.create_request() + name, language, return_type, body, options = ( + req["name"], + req["language"], + req["return_type"], + req["body"], + req["options"], + ) + if replace: + try: + await self.drop_function(name) + except Exception: + pass await self._inner.create_function(name, language, return_type, body, options) async def list_functions(self): @@ -2010,6 +2030,37 @@ class AsyncConnection(object): name, query, auto_refresh=auto_refresh, with_no_data=with_no_data ) + async def create_view( + self, + name: str, + source, + select, + *, + where: Optional[str] = None, + auto_refresh: bool = False, + replace: bool = False, + ): + """Create a materialized view from a source + select items; returns + an `AsyncView`. See the sync `create_view` for the select grammar.""" + from .udf import build_view_query, AsyncView + + query = build_view_query(source, select) + if where: + query += f" WHERE {where}" + if replace: + try: + await self.drop_materialized_view(name) + except Exception: + pass + await self.create_materialized_view(name, query, auto_refresh=auto_refresh) + return AsyncView(self, name) + + def job(self, job_id: str): + """An `AsyncJobHandle` for polling/cancelling an inflight job by id.""" + from .udf import AsyncJobHandle + + return AsyncJobHandle(self, job_id) + async def refresh_materialized_view( self, name: str, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 0a76883f0..62f80e01e 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -5546,6 +5546,43 @@ class AsyncTable: await self._inner.add_computed_columns(cols, expression) return result + async def add_computed_column( + self, + columns, + fn, + args: Optional[List[str]] = None, + types=None, + ) -> None: + """Declare computed column(s) bound to a UDF (async). See the sync + `Table.add_computed_column`. Server-backed (Enterprise / Cloud).""" + from .udf import Udf, struct_field_types + + multi = isinstance(columns, (tuple, list)) + if isinstance(fn, Udf): + expr = fn.expression(*(args or [])) + if types is None: + if multi: + if not fn.returns.upper().startswith("STRUCT"): + raise ValueError( + "several columns need a STRUCT-returning function" + ) + types = struct_field_types(fn.returns) + else: + types = fn.returns + else: + if types is None: + raise ValueError("pass types= when fn is a name string") + expr = f"{fn}({', '.join(args or [])})" + if multi: + if len(types) != len(columns): + raise ValueError( + f"{len(columns)} columns but {len(types)} output types" + ) + computed = {c: (t, expr) for c, t in zip(columns, types)} + else: + computed = {columns: (types, expr)} + await self.add_columns(computed=computed) + async def alter_columns( self, *alterations: Iterable[dict[str, Any]] ) -> AlterColumnsResult: diff --git a/python/python/lancedb/udf.py b/python/python/lancedb/udf.py index 84c17929d..2ac0c2db1 100644 --- a/python/python/lancedb/udf.py +++ b/python/python/lancedb/udf.py @@ -29,6 +29,7 @@ remote connection. from __future__ import annotations +import asyncio import base64 import dataclasses import functools @@ -394,6 +395,17 @@ def build_view_query(source, select) -> str: return f"SELECT {', '.join(items)} FROM {src}" +def _job_id_matches(handle_id: str, listed_id: str) -> bool: + # The refresh/backfill endpoints return the submission id (a uuid), but + # the agent names the manifest job "