diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py
index 4a39ce6f9..9c909d341 100644
--- a/python/python/lancedb/remote/table.py
+++ b/python/python/lancedb/remote/table.py
@@ -932,6 +932,89 @@ class RemoteTable(Table):
             )
         )
 
+    def load_columns(
+        self,
+        source: Union[str, Iterable[str]],
+        pk: str,
+        columns: Union[Iterable[str], Dict[str, str]],
+        *,
+        source_format: str = "parquet",
+        source_pk: Optional[str] = None,
+        on_missing: str = "carry",
+        source_storage_options: Optional[Dict[str, str]] = None,
+        num_workers: Optional[int] = None,
+        max_workers: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        commit_granularity: Optional[int] = None,
+        priority: Optional[str] = None,
+    ) -> str:
+        """Fill existing columns from an external source by primary-key join.
+
+        The distributed-job equivalent of Geneva's ``Table.load_columns()``:
+        imports precomputed values (e.g. embeddings) from Parquet/Lance/IPC into
+        this table, matching on a primary key. Returns the load job id.
+        Server-backed feature (LanceDB Enterprise / Cloud).
+
+        Parameters
+        ----------
+        source: str | list[str]
+            One source URI or a list of URIs.
+        pk: str
+            Destination primary-key column. Also the source key unless
+            ``source_pk`` is given.
+        columns: list[str] | dict[str, str]
+            Value columns to load. A list loads same-named columns; a dict maps
+            ``{target: source}``. A single column name passed as a plain ``str``
+            is treated as a one-element list.
+        source_format: str
+            ``"parquet"`` (default), ``"lance"``, or ``"ipc"``.
+        source_pk: str, optional
+            Source primary-key column when it differs from ``pk``.
+        on_missing: str
+            Behavior for destination rows with no source match:
+            ``"carry"`` (default, keep existing), ``"null"``, or ``"error"``.
+        source_storage_options: dict[str, str], optional
+            Storage options that apply only to reading the source (e.g. cloud
+            credentials).
+        num_workers, max_workers, batch_size, commit_granularity: int, optional
+            Optional job settings, forwarded to the server unchanged. A value
+            left as ``None`` is omitted from the request.
+        priority: str, optional
+            Forwarded to the server unchanged; omitted from the request when
+            ``None``.
+
+        Returns
+        -------
+        str
+            The id of the load job.
+        """
+        if isinstance(source, str):
+            source = [source]
+        if isinstance(columns, str):
+            # A bare ``str`` is itself an iterable of characters; without this
+            # guard "embedding" would expand into nine one-letter mappings.
+            columns = [columns]
+        if isinstance(columns, dict):
+            mappings = list(columns.items())
+        else:
+            mappings = [(c, None) for c in columns]
+        return LOOP.run(
+            self._table.load_columns(
+                list(source),
+                source_format,
+                pk,
+                mappings,
+                source_key=source_pk,
+                source_storage_options=source_storage_options,
+                on_missing=on_missing,
+                num_workers=num_workers,
+                max_workers=max_workers,
+                batch_size=batch_size,
+                commit_granularity=commit_granularity,
+                priority=priority,
+            )
+        )
+
     def alter_columns(
         self, *alterations: Iterable[Dict[str, str]]
     ) -> AlterColumnsResult:
diff --git a/python/src/table.rs b/python/src/table.rs
index ff203e86e..8cab57722 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -17,8 +17,8 @@ use arrow::{
     pyarrow::{FromPyArrow, PyArrowType, ToPyArrow},
 };
 use lancedb::table::{
-    AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, NewColumnTransform,
-    OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
+    AddDataMode, ColumnAlteration, Duration, FieldMetadataUpdate, LoadColumnsRequest,
+    NewColumnTransform, OptimizeAction, OptimizeOptions, Ref, Table as LanceDbTable,
 };
 use pyo3::{
     Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -1100,6 +1100,43 @@ impl Table {
         })
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (source_uris, source_format, target_key, columns, source_key=None, source_storage_options=None, on_missing=None, num_workers=None, max_workers=None, batch_size=None, commit_granularity=None, priority=None))]
+    pub fn load_columns(
+        self_: PyRef<'_, Self>,
+        source_uris: Vec,
+        source_format: String,
+        target_key: String,
+        columns: Vec<(String, Option)>,
+        source_key: Option,
+        source_storage_options: Option>,
+        on_missing: Option,
+        num_workers: Option,
+        max_workers: Option,
+        batch_size: Option,
+        commit_granularity: Option,
+        priority: Option,
+    ) -> PyResult> {
+        let inner = self_.inner_ref()?.clone();
+        let request = LoadColumnsRequest {
+            source_uris,
+            source_format,
+            source_storage_options,
+            target_key,
+            source_key,
+            columns,
+            on_missing,
+            num_workers,
+            max_workers,
+            batch_size,
+            commit_granularity,
+            priority,
+        };
+        future_into_py(self_.py(), async move {
+            inner.load_columns(request).await.infer_error()
+        })
+    }
+
     pub fn add_columns(
         self_: PyRef<'_, Self>,
         definitions: Vec<(String, String)>,
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index 9d95e3930..7ddfe24bc 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -2371,6 +2371,64 @@ impl BaseTable for RemoteTable {
         Ok(body.job_id)
     }
 
+    async fn load_columns(&self, request: crate::table::LoadColumnsRequest) -> Result {
+        let columns: Vec = request
+            .columns
+            .iter()
+            .map(|(target, source)| {
+                serde_json::json!({
+                    "target": target,
+                    "source": source.clone().unwrap_or_else(|| target.clone()),
+                })
+            })
+            .collect();
+        let mut source = serde_json::json!({
+            "uris": request.source_uris,
+            "format": request.source_format,
+        });
+        if let Some(opts) = request.source_storage_options {
+            source["storage_options"] = serde_json::to_value(opts).unwrap_or_default();
+        }
+        let mut body = serde_json::json!({
+            "columns": columns,
+            "source": source,
+            "target_key": request.target_key,
+        });
+        if let Some(k) = request.source_key {
+            body["source_key"] = serde_json::Value::String(k);
+        }
+        if let Some(m) = request.on_missing {
+            body["on_missing"] = serde_json::Value::String(m);
+        }
+        if let Some(n) = request.num_workers {
+            body["num_workers"] = n.into();
+        }
+        if let Some(n) = request.max_workers {
+            body["max_workers"] = n.into();
+        }
+        if let Some(n) = request.batch_size {
+            body["batch_size"] = n.into();
+        }
+        if let Some(n) = request.commit_granularity {
+            body["commit_granularity"] = n.into();
+        }
+        if let Some(p) = request.priority {
+            body["priority"] = serde_json::Value::String(p);
+        }
+        let http_request = self
+            .client
+            .post(&format!("/v1/table/{}/load_columns", self.identifier))
+            .json(&body);
+        let (request_id, response) = self.send(http_request, true).await?;
+        let response = self.check_table_response(&request_id, response).await?;
+        #[derive(serde::Deserialize)]
+        struct LoadColumnsResponse {
+            job_id: String,
+        }
+        let body: LoadColumnsResponse = response.json().await.err_to_http(request_id)?;
+        Ok(body.job_id)
+    }
+
     async fn add_columns(
         &self,
         transforms: NewColumnTransform,
@@ -2887,6 +2945,51 @@ mod tests {
         assert_eq!(job_id, "j-9");
     }
 
+    #[tokio::test]
+    async fn test_load_columns() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/v1/table/my_table/load_columns");
+            let body: serde_json::Value =
+                serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
+            assert_eq!(
+                body["columns"],
+                serde_json::json!([{"target": "embedding", "source": "emb"}])
+            );
+            assert_eq!(body["source"]["format"], "parquet");
+            assert_eq!(
+                body["source"]["uris"],
+                serde_json::json!(["s3://b/x.parquet"])
+            );
+            assert_eq!(body["target_key"], "document_id");
+            assert_eq!(body["source_key"], "doc_id");
+            assert_eq!(body["on_missing"], "null");
+            assert_eq!(body["num_workers"], 4);
+
+            http::Response::builder()
+                .status(202)
+                .body(r#"{"job_id":"lc-7"}"#)
+                .unwrap()
+        });
+
+        let request = crate::table::LoadColumnsRequest {
+            source_uris: vec!["s3://b/x.parquet".to_string()],
+            source_format: "parquet".to_string(),
+            source_storage_options: None,
+            target_key: "document_id".to_string(),
+            source_key: Some("doc_id".to_string()),
+            columns: vec![("embedding".to_string(), Some("emb".to_string()))],
+            on_missing: Some("null".to_string()),
+            num_workers: Some(4),
+            max_workers: None,
+            batch_size: None,
+            commit_granularity: None,
+            priority: None,
+        };
+        let job_id = table.load_columns(request).await.unwrap();
+        assert_eq!(job_id, "lc-7");
+    }
+
     #[tokio::test]
     async fn test_version() {
         let table = Table::new_with_handler("my_table", |request| {
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index b12f70a4e..695814086 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -471,6 +471,33 @@ impl LsmWriteSpec {
     }
 }
 
+/// Request to fill existing table columns from an external source by
+/// primary-key join (Geneva `Table.load_columns()` parity). Server-backed
+/// feature (LanceDB Enterprise / Cloud).
+#[derive(Debug, Clone)]
+pub struct LoadColumnsRequest {
+    /// External source URIs.
+    pub source_uris: Vec,
+    /// Source format: "parquet" | "lance" | "ipc".
+    pub source_format: String,
+    /// Source-only storage options (e.g. cloud credentials).
+    pub source_storage_options: Option>,
+    /// Destination primary-key column.
+    pub target_key: String,
+    /// Source primary-key column. Defaults to `target_key` when None.
+    pub source_key: Option,
+    /// Value column mappings as `(target, source)`; a None source defaults to
+    /// the target name.
+    pub columns: Vec<(String, Option)>,
+    /// Missing-row policy: "carry" (default) | "null" | "error".
+    pub on_missing: Option,
+    pub num_workers: Option,
+    pub max_workers: Option,
+    pub batch_size: Option,
+    pub commit_granularity: Option,
+    pub priority: Option,
+}
+
 /// A trait for anything "table-like". This is used for both native tables (which target
 /// Lance datasets) and remote tables (which target LanceDB cloud)
 ///
@@ -653,6 +680,14 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
             message: "refresh_column is not supported by this table".into(),
         })
     }
+    /// Fill existing columns from an external source by primary-key join
+    /// (Geneva `load_columns`). Returns the load job id. Server-backed feature;
+    /// the default returns NotSupported.
+    async fn load_columns(&self, _request: LoadColumnsRequest) -> Result {
+        Err(Error::NotSupported {
+            message: "load_columns is not supported by this table".into(),
+        })
+    }
     /// Alter columns in the table.
     async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result;
     /// Drop columns from the table.
@@ -1530,6 +1565,12 @@ impl Table {
             .await
     }
 
+    /// Fill existing columns from an external Parquet/Lance/IPC source by
+    /// primary-key join (Geneva `Table.load_columns()`). Returns the job id.
+    pub async fn load_columns(&self, request: LoadColumnsRequest) -> Result {
+        self.inner.load_columns(request).await
+    }
+
     /// Change a column's name or nullability.
     pub async fn alter_columns(
         &self,