diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 62f80e01e..4c23b6e21 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -704,11 +704,14 @@ def _normalize_progress(progress): def _computed_groups(computed): - """Group {column: (sql_type, expression)} by expression, preserving - declaration order (struct-returning functions need their columns - adjacent so schema order matches field order).""" + """Group computed columns by expression, preserving declaration order + (struct-returning functions need their columns adjacent so schema order + matches field order). Accepts the ergonomic forms -- `fn("col")` values + and tuple keys for struct fan-out -- via `_normalize_computed`.""" + from .udf import _normalize_computed + groups = [] - for name, (sql_type, expression) in computed.items(): + for name, (sql_type, expression) in _normalize_computed(computed).items(): for expr, cols in groups: if expr == expression: cols.append((name, sql_type)) @@ -831,11 +834,23 @@ class Table(ABC): types=None, ) -> None: """Declare computed column(s) bound to a UDF -- no compute happens - here (the agent fills them lazily, or refresh_column() triggers a - run). Sugar over add_columns(computed=): column types come from the - UDF's declared return type (a STRUCT return maps its fields to the - columns positionally); pass `types` when `fn` is a bare name string. - Register the function first. Server-backed (Enterprise / Cloud).""" + here (the agent fills them lazily, or refresh_column() triggers a run). + + .. deprecated:: + A computed column is an expression over a registered function, so + bind it as one: ``add_columns(computed={"vec": embed("data")})``. + ``embed("data")`` applies the function to the `data` column and + infers the type from the function's return signature -- the + function never couples to a particular column. Prefer that form. 
+ """ + import warnings + + warnings.warn( + 'add_computed_column is deprecated; use add_columns(computed=' + '{"vec": embed("data")}).', + DeprecationWarning, + stacklevel=2, + ) from .udf import Udf, struct_field_types multi = isinstance(columns, (tuple, list)) @@ -3770,14 +3785,18 @@ class LanceTable(Table): self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema | None = None, *, - computed: Optional[Dict[str, tuple]] = None, + computed: Optional[Dict] = None, ) -> Optional[AddColumnsResult]: result = None if transforms is not None: result = LOOP.run(self._table.add_columns(transforms)) if computed: - # computed: {column: (sql_type, expression)} -- declares the - # binding only; the server fills the values (server-backed). + # computed binds an expression over a registered function to a + # column: {col: fn("input_col")} -- fn("input_col") yields the + # expression and carries the inferred type; a tuple key fans a + # STRUCT return out to several columns. Declares the binding only; + # the server fills the values (server-backed). The legacy + # {col: (sql_type, expression)} tuple form is still accepted. result_unused = LOOP.run(self._table.add_columns(computed=computed)) del result_unused return result @@ -5508,7 +5527,7 @@ class AsyncTable: self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema | None = None, *, - computed: Optional[Dict[str, tuple]] = None, + computed: Optional[Dict] = None, ) -> Optional[AddColumnsResult]: """ Add new columns with defined values. @@ -5540,8 +5559,12 @@ class AsyncTable: elif transforms is not None: result = await self._inner.add_columns(list(transforms.items())) if computed: - # computed: {column: (sql_type, expression)} -- declares the - # binding only; the server fills the values (server-backed). 
+ # computed binds an expression over a registered function to a + # column: {col: fn("input_col")} -- fn("input_col") yields the + # expression and carries the inferred type; a tuple key fans a + # STRUCT return out to several columns. Declares the binding only; + # the server fills the values (server-backed). The legacy + # {col: (sql_type, expression)} tuple form is still accepted. for expression, cols in _computed_groups(computed): await self._inner.add_computed_columns(cols, expression) return result @@ -5553,8 +5576,21 @@ class AsyncTable: args: Optional[List[str]] = None, types=None, ) -> None: - """Declare computed column(s) bound to a UDF (async). See the sync - `Table.add_computed_column`. Server-backed (Enterprise / Cloud).""" + """Declare computed column(s) bound to a UDF (async). + + .. deprecated:: + Use ``add_columns(computed={"col": fn("input_col")})`` -- a computed + column is an expression over a registered function, so bind it that + way instead of coupling the UDF to the column here. 
+ """ + import warnings + + warnings.warn( + 'add_computed_column is deprecated; use add_columns(computed=' + '{"col": fn("input_col")}).', + DeprecationWarning, + stacklevel=2, + ) from .udf import Udf, struct_field_types multi = isinstance(columns, (tuple, list)) diff --git a/python/python/lancedb/udf.py b/python/python/lancedb/udf.py index 2ac0c2db1..ae38aa069 100644 --- a/python/python/lancedb/udf.py +++ b/python/python/lancedb/udf.py @@ -16,12 +16,16 @@ Register and use them through the existing connection/table API: def embed(text: str) -> list[float]: return model.encode(text).tolist() - db.create_function(embed) # CREATE FUNCTION + db.create_function(embed) # CREATE FUNCTION (once) tbl = db.open_table("docs") - tbl.add_computed_column("vec", embed) # declare (no compute yet) + tbl.add_columns(computed={"vec": embed("text")}) # bind embed(text) -> vec db.job(tbl.refresh_column("vec")).wait() # materialize view = db.create_view("chunks", tbl, ["id", chunk_fn]) +`embed("text")` applies the registered function to the `text` column and yields +the expression `embed(text)`; the function itself stays decoupled from any +column, so the same `embed` works on any column or table. + These operations are server-backed (LanceDB Enterprise / Cloud); the decorator itself works locally (define + call), only registration needs a remote connection. @@ -140,6 +144,68 @@ def param_types(fn) -> "list[tuple[str, str]]": return out +# -- column expressions ------------------------------------------------- + + +class ColumnExpr(str): + """A computed-column expression produced by applying a registered + function to column names, e.g. ``embed("data") -> "embed(data)"``. + + It IS the expression string everywhere a string is expected (views, SQL, + logging), and additionally carries the function's declared return type so + ``add_columns(computed=...)`` can declare the column without a hand-written + type. 
``field_types`` holds the per-field SQL types of a STRUCT return, for + fanning one expression out to several columns. + """ + + data_type: "str | None" + field_types: "list[str] | None" + + def __new__(cls, expr: str, data_type=None, field_types=None): + obj = super().__new__(cls, expr) + obj.data_type = data_type + obj.field_types = field_types + return obj + + +def _normalize_computed(computed: dict) -> dict: + """Normalize the user-facing ``computed=`` mapping to the canonical + ``{name: (sql_type, expression)}`` form. + + Accepts, per entry: + - value is a `ColumnExpr` (from ``fn("col")``): the column's SQL type + comes from the function's return type -- no hand-written type needed. A + tuple key (``("chunk", "idx")``) fans a STRUCT return out to one + (type, expression) entry per field, in declared order. + - value is a legacy ``(sql_type, expression)`` tuple: passed through (the + escape hatch, e.g. bare-name function strings). + """ + out: dict = {} + for key, val in computed.items(): + if isinstance(val, ColumnExpr): + expr = str(val) + if isinstance(key, (tuple, list)): + if not val.field_types: + raise ValueError( + f"columns {tuple(key)} need a STRUCT-returning function; " + f"{expr} returns a single value" + ) + if len(val.field_types) != len(key): + raise ValueError( + f"{len(key)} columns but {len(val.field_types)} struct fields " + f"in {expr}" + ) + for name, t in zip(key, val.field_types): + out[name] = (t, expr) + else: + if val.data_type is None: + raise ValueError(f"cannot infer a type for {expr}; pass types=") + out[key] = (val.data_type, expr) + else: + out[key] = val + return out + + # -- the @udf / @table_udf decorators ----------------------------------- @@ -221,14 +287,23 @@ class Udf: def __call__(self, *args, **kwargs): """Call with real values to run locally; call with column-name - strings to build an expression for backfills and views.""" + strings to build an expression for backfills and views, e.g. 
def expression(self, *columns: str) -> ColumnExpr:
    """Build the call expression ``name(col, ...)`` for this function.

    When no ``columns`` are given, the first item of each ``self.params``
    pair supplies the argument list. The result is a `ColumnExpr` whose
    ``data_type`` is ``self.returns``; ``field_types`` comes from
    ``struct_field_types(self.returns)`` only when ``self.returns`` starts
    with ``STRUCT`` (case-insensitive), and is ``None`` otherwise.
    """
    arg_names = columns if columns else [name for name, _unused in self.params]
    joined = ", ".join(arg_names)
    call_text = f"{self.name}({joined})"
    returns_struct = self.returns.upper().startswith("STRUCT")
    struct_types = struct_field_types(self.returns) if returns_struct else None
    return ColumnExpr(call_text, data_type=self.returns, field_types=struct_types)