mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-02 02:20:41 +00:00
feat: computed columns as a param on add_columns
Per the interface design: computed columns are parameters on the
existing add_columns operation, not a separate method.
- BaseTable::add_computed_columns takes (name, sql_type) pairs plus an
f(args) expression; the default implementation returns NotSupported, while
RemoteTable posts 'computed' entries to the existing
/v1/table/{id}/add_columns route.
- python add_columns gains computed= on LanceTable, RemoteTable, and
AsyncTable: tbl.add_columns(computed={'doubled': ('FLOAT',
'double_it(val)')}); grouped by expression so struct-returning
functions' columns land adjacently.
This commit is contained in:
@@ -884,8 +884,18 @@ class RemoteTable(Table):
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
    """Return the result of the wrapped table's ``count_rows`` call.

    The call on ``self._table`` is driven to completion through
    ``LOOP.run`` so this method can be used synchronously.

    Parameters
    ----------
    filter: str, optional
        Passed through unchanged to the wrapped table's ``count_rows``.
    """
    pending = self._table.count_rows(filter)
    return LOOP.run(pending)
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
def add_columns(
    self,
    transforms: Optional[Dict[str, str]] = None,
    *,
    computed: Optional[Dict[str, tuple]] = None,
) -> Optional[AddColumnsResult]:
    """Add columns via ``transforms`` and/or declare ``computed`` columns.

    Parameters
    ----------
    transforms: dict, optional
        Forwarded positionally to the wrapped table's ``add_columns``
        when it is not ``None``.
    computed: dict, optional
        Keyword-only. Forwarded as ``computed=`` to the wrapped table's
        ``add_columns`` when truthy. Presumably shaped as
        ``{column: (sql_type, expression)}`` -- confirm against the
        wrapped table's signature.

    Returns
    -------
    The result of the ``transforms`` call, or ``None`` when no
    ``transforms`` were given. The result of the ``computed`` call is
    discarded.
    """
    # Only the transforms call contributes to the return value.
    outcome = (
        LOOP.run(self._table.add_columns(transforms))
        if transforms is not None
        else None
    )
    if computed:
        LOOP.run(self._table.add_columns(computed=computed))
    return outcome
|
||||
|
||||
def refresh_column(
|
||||
self,
|
||||
|
||||
@@ -702,6 +702,22 @@ def _normalize_progress(progress):
|
||||
return progress, False
|
||||
|
||||
|
||||
|
||||
def _computed_groups(computed):
|
||||
"""Group {column: (sql_type, expression)} by expression, preserving
|
||||
declaration order (struct-returning functions need their columns
|
||||
adjacent so schema order matches field order)."""
|
||||
groups = []
|
||||
for name, (sql_type, expression) in computed.items():
|
||||
for expr, cols in groups:
|
||||
if expr == expression:
|
||||
cols.append((name, sql_type))
|
||||
break
|
||||
else:
|
||||
groups.append((expression, [(name, sql_type)]))
|
||||
return groups
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
@@ -3710,9 +3726,20 @@ class LanceTable(Table):
|
||||
return LOOP.run(self._table.index_stats(index_name))
|
||||
|
||||
def add_columns(
|
||||
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
self,
|
||||
transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema | None = None,
|
||||
*,
|
||||
computed: Optional[Dict[str, tuple]] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
result = None
|
||||
if transforms is not None:
|
||||
result = LOOP.run(self._table.add_columns(transforms))
|
||||
if computed:
|
||||
# computed: {column: (sql_type, expression)} -- declares the
|
||||
# binding only; the server fills the values (server-backed).
|
||||
result_unused = LOOP.run(self._table.add_columns(computed=computed))
|
||||
del result_unused
|
||||
return result
|
||||
|
||||
def refresh_column(
|
||||
self,
|
||||
@@ -5437,8 +5464,11 @@ class AsyncTable:
|
||||
)
|
||||
|
||||
async def add_columns(
|
||||
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
) -> AddColumnsResult:
|
||||
self,
|
||||
transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema | None = None,
|
||||
*,
|
||||
computed: Optional[Dict[str, tuple]] = None,
|
||||
) -> Optional[AddColumnsResult]:
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
|
||||
@@ -5457,6 +5487,7 @@ class AsyncTable:
|
||||
version: the new version number of the table after adding columns.
|
||||
|
||||
"""
|
||||
result = None
|
||||
if isinstance(transforms, pa.Field):
|
||||
transforms = [transforms]
|
||||
if isinstance(transforms, list) and all(
|
||||
@@ -5464,9 +5495,15 @@ class AsyncTable:
|
||||
):
|
||||
transforms = pa.schema(transforms)
|
||||
if isinstance(transforms, pa.Schema):
|
||||
return await self._inner.add_columns_with_schema(transforms)
|
||||
else:
|
||||
return await self._inner.add_columns(list(transforms.items()))
|
||||
result = await self._inner.add_columns_with_schema(transforms)
|
||||
elif transforms is not None:
|
||||
result = await self._inner.add_columns(list(transforms.items()))
|
||||
if computed:
|
||||
# computed: {column: (sql_type, expression)} -- declares the
|
||||
# binding only; the server fills the values (server-backed).
|
||||
for expression, cols in _computed_groups(computed):
|
||||
await self._inner.add_computed_columns(cols, expression)
|
||||
return result
|
||||
|
||||
async def alter_columns(
|
||||
self, *alterations: Iterable[dict[str, Any]]
|
||||
|
||||
@@ -1060,6 +1060,20 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_computed_columns(
|
||||
self_: PyRef<'_, Self>,
|
||||
columns: Vec<(String, String)>,
|
||||
expression: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner
|
||||
.add_computed_columns(&columns, &expression)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (columns, where_clause=None, num_workers=None, max_workers=None))]
|
||||
pub fn refresh_column(
|
||||
self_: PyRef<'_, Self>,
|
||||
|
||||
Reference in New Issue
Block a user