diff --git a/docs/src/guides/tables.md b/docs/src/guides/tables.md
index 1f9a1795..13cced31 100644
--- a/docs/src/guides/tables.md
+++ b/docs/src/guides/tables.md
@@ -790,6 +790,101 @@ Use the `drop_table()` method on the database to remove a table.
 This permanently removes the table and is not recoverable, unlike deleting
 rows. If the table does not exist an exception is raised.
 
+## Changing schemas
+
+While tables must have a schema specified when they are created, you can
+change the schema over time. There are three methods to alter the schema of
+a table:
+
+* `add_columns`: Add new columns to the table
+* `alter_columns`: Alter the name, nullability, or data type of a column
+* `drop_columns`: Drop columns from the table
+
+### Adding new columns
+
+You can add new columns to the table with the `add_columns` method. New columns
+are filled with values based on a SQL expression. For example, you can add a new
+column `double_price` to the table and fill it with the value of `price * 2`.
+
+=== "Python"
+
+    ```python
+    table.add_columns({"double_price": "price * 2"})
+    ```
+    **API Reference:** [lancedb.table.Table.add_columns][]
+
+=== "Typescript"
+
+    ```typescript
+    --8<-- "nodejs/examples/basic.test.ts:add_columns"
+    ```
+    **API Reference:** [lancedb.Table.addColumns](../js/classes/Table.md/#addcolumns)
+
+If you want to fill the new column with nulls, you can use `cast(NULL as <datatype>)`
+as the SQL expression, which fills the column with nulls while controlling the
+data type of the column. Available data types are based on the
+[DataFusion data types](https://datafusion.apache.org/user-guide/sql/data_types.html).
+You can use any of the SQL types, such as `BIGINT`:
+
+```sql
+cast(NULL as BIGINT)
+```
+
+Using Arrow data types and the `arrow_typeof` function is not yet supported.
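+
+As a minimal sketch using the Python API (the `null_int` column name is just
+illustrative), this adds a `BIGINT` column filled entirely with nulls:
+
+```python
+# The cast controls the data type of the otherwise empty column.
+table.add_columns({"null_int": "cast(NULL as BIGINT)"})
+```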
+ + + +=== "Python" + + ```python + table.drop_columns(["dbl_price"]) + ``` + **API Reference:** [lancedb.table.Table.drop_columns][] + +=== "Typescript" + + ```typescript + --8<-- "nodejs/examples/basic.test.ts:drop_columns" + ``` + **API Reference:** [lancedb.Table.dropColumns](../js/classes/Table.md/#altercolumns) + + ## Handling bad vectors In LanceDB Python, you can use the `on_bad_vectors` parameter to choose how diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 456fda15..fd146169 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -825,6 +825,18 @@ describe("schema evolution", function () { new Field("price", new Float64(), true), ]); expect(await table.schema()).toEqual(expectedSchema); + + await table.alterColumns([{ path: "new_id", dataType: "int32" }]); + const expectedSchema2 = new Schema([ + new Field("new_id", new Int32(), true), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float32(), true)), + true, + ), + new Field("price", new Float64(), true), + ]); + expect(await table.schema()).toEqual(expectedSchema2); }); it("can drop a column from the schema", async function () { diff --git a/nodejs/examples/basic.test.ts b/nodejs/examples/basic.test.ts index 829b446e..d0485142 100644 --- a/nodejs/examples/basic.test.ts +++ b/nodejs/examples/basic.test.ts @@ -116,6 +116,26 @@ test("basic table examples", async () => { await tbl.add(data); // --8<-- [end:add_data] } + + { + // --8<-- [start:add_columns] + await tbl.addColumns([{ name: "double_price", valueSql: "price * 2" }]); + // --8<-- [end:add_columns] + // --8<-- [start:alter_columns] + await tbl.alterColumns([ + { + path: "double_price", + rename: "dbl_price", + dataType: "float", + nullable: true, + }, + ]); + // --8<-- [end:alter_columns] + // --8<-- [start:drop_columns] + await tbl.dropColumns(["dbl_price"]); + // --8<-- [end:drop_columns] + } + { // --8<-- [start:vector_search] const res = await tbl.search([100, 100]).limit(2).toArray(); diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index a52f9fbc..5a9f4298 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -178,16 +178,20 @@ impl Table { #[napi(catch_unwind)] pub async fn alter_columns(&self, alterations: Vec) -> napi::Result<()> { for alteration in &alterations { - if alteration.rename.is_none() && alteration.nullable.is_none() { + if alteration.rename.is_none() + && alteration.nullable.is_none() + && alteration.data_type.is_none() + { return Err(napi::Error::from_reason( - "Alteration must have a 'rename' or 'nullable' field.", + "Alteration must have a 'rename', 'dataType', or 'nullable' field.", )); } } let alterations = alterations .into_iter() - .map(LanceColumnAlteration::from) - .collect::>(); + .map(LanceColumnAlteration::try_from) + .collect::, String>>() + .map_err(napi::Error::from_reason)?; self.inner_ref()? .alter_columns(&alterations) @@ -433,24 +437,43 @@ pub struct ColumnAlteration { /// The new name of the column. If not provided then the name will not be changed. /// This must be distinct from the names of all other columns in the table. pub rename: Option, + /// A new data type for the column. If not provided then the data type will not be changed. + /// Changing data types is limited to casting to the same general type. 
+
+### Dropping columns
+
+You can drop columns from the table with the `drop_columns` method. This will
+remove the column from the schema.
+
+=== "Python"
+
+    ```python
+    table.drop_columns(["dbl_price"])
+    ```
+    **API Reference:** [lancedb.table.Table.drop_columns][]
+
+=== "Typescript"
+
+    ```typescript
+    --8<-- "nodejs/examples/basic.test.ts:drop_columns"
+    ```
+    **API Reference:** [lancedb.Table.dropColumns](../js/classes/Table.md/#dropcolumns)
+
+
 ## Handling bad vectors
 
 In LanceDB Python, you can use the `on_bad_vectors` parameter to choose how
diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts
index 456fda15..fd146169 100644
--- a/nodejs/__test__/table.test.ts
+++ b/nodejs/__test__/table.test.ts
@@ -825,6 +825,18 @@ describe("schema evolution", function () {
       new Field("price", new Float64(), true),
     ]);
     expect(await table.schema()).toEqual(expectedSchema);
+
+    await table.alterColumns([{ path: "new_id", dataType: "int32" }]);
+    const expectedSchema2 = new Schema([
+      new Field("new_id", new Int32(), true),
+      new Field(
+        "vector",
+        new FixedSizeList(2, new Field("item", new Float32(), true)),
+        true,
+      ),
+      new Field("price", new Float64(), true),
+    ]);
+    expect(await table.schema()).toEqual(expectedSchema2);
   });
 
   it("can drop a column from the schema", async function () {
diff --git a/nodejs/examples/basic.test.ts b/nodejs/examples/basic.test.ts
index 829b446e..d0485142 100644
--- a/nodejs/examples/basic.test.ts
+++ b/nodejs/examples/basic.test.ts
@@ -116,6 +116,26 @@ test("basic table examples", async () => {
     await tbl.add(data);
     // --8<-- [end:add_data]
   }
+
+  {
+    // --8<-- [start:add_columns]
+    await tbl.addColumns([{ name: "double_price", valueSql: "price * 2" }]);
+    // --8<-- [end:add_columns]
+    // --8<-- [start:alter_columns]
+    await tbl.alterColumns([
+      {
+        path: "double_price",
+        rename: "dbl_price",
+        dataType: "float",
+        nullable: true,
+      },
+    ]);
+    // --8<-- [end:alter_columns]
+    // --8<-- [start:drop_columns]
+    await tbl.dropColumns(["dbl_price"]);
+    // --8<-- [end:drop_columns]
+  }
+
   {
     // --8<-- [start:vector_search]
     const res = await tbl.search([100, 100]).limit(2).toArray();
diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index a52f9fbc..5a9f4298 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -178,16 +178,20 @@ impl Table {
     #[napi(catch_unwind)]
     pub async fn alter_columns(&self, alterations: Vec<ColumnAlteration>) -> napi::Result<()> {
         for alteration in &alterations {
-            if alteration.rename.is_none() && alteration.nullable.is_none() {
+            if alteration.rename.is_none()
+                && alteration.nullable.is_none()
+                && alteration.data_type.is_none()
+            {
                 return Err(napi::Error::from_reason(
-                    "Alteration must have a 'rename' or 'nullable' field.",
+                    "Alteration must have a 'rename', 'dataType', or 'nullable' field.",
                 ));
             }
         }
         let alterations = alterations
             .into_iter()
-            .map(LanceColumnAlteration::from)
-            .collect::<Vec<_>>();
+            .map(LanceColumnAlteration::try_from)
+            .collect::<std::result::Result<Vec<_>, String>>()
+            .map_err(napi::Error::from_reason)?;
 
         self.inner_ref()?
             .alter_columns(&alterations)
@@ -433,24 +437,43 @@ pub struct ColumnAlteration {
     /// The new name of the column. If not provided then the name will not be changed.
     /// This must be distinct from the names of all other columns in the table.
     pub rename: Option<String>,
+    /// A new data type for the column. If not provided then the data type will not be changed.
+    /// Changing data types is limited to casting to the same general type. For example, these
+    /// changes are valid:
+    /// * `int32` -> `int64` (integers)
+    /// * `double` -> `float` (floats)
+    /// * `string` -> `large_string` (strings)
+    /// But these changes are not:
+    /// * `int32` -> `double` (mix integers and floats)
+    /// * `string` -> `int32` (mix strings and integers)
+    pub data_type: Option<String>,
     /// Set the new nullability. Note that a nullable column cannot be made non-nullable.
     pub nullable: Option<bool>,
 }
 
-impl From<ColumnAlteration> for LanceColumnAlteration {
-    fn from(js: ColumnAlteration) -> Self {
+impl TryFrom<ColumnAlteration> for LanceColumnAlteration {
+    type Error = String;
+    fn try_from(js: ColumnAlteration) -> std::result::Result<Self, Self::Error> {
         let ColumnAlteration {
             path,
             rename,
             nullable,
+            data_type,
         } = js;
-        Self {
+        let data_type = if let Some(data_type) = data_type {
+            Some(
+                lancedb::utils::string_to_datatype(&data_type)
+                    .ok_or_else(|| format!("Invalid data type: {}", data_type))?,
+            )
+        } else {
+            None
+        };
+        Ok(Self {
             path,
             rename,
             nullable,
-            // TODO: wire up this field
-            data_type: None,
-        }
+            data_type,
+        })
     }
 }
diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py
index 1dc0fbaa..9fb743c2 100644
--- a/python/python/lancedb/remote/table.py
+++ b/python/python/lancedb/remote/table.py
@@ -490,19 +490,13 @@ class RemoteTable(Table):
         return LOOP.run(self._table.count_rows(filter))
 
     def add_columns(self, transforms: Dict[str, str]):
-        raise NotImplementedError(
-            "add_columns() is not yet supported on the LanceDB cloud"
-        )
+        return LOOP.run(self._table.add_columns(transforms))
 
-    def alter_columns(self, alterations: Iterable[Dict[str, str]]):
-        raise NotImplementedError(
-            "alter_columns() is not yet supported on the LanceDB cloud"
-        )
+    def alter_columns(self, *alterations: Dict[str, Any]):
+        return LOOP.run(self._table.alter_columns(*alterations))
 
     def drop_columns(self, columns: Iterable[str]):
-        raise NotImplementedError(
-            "drop_columns() is not yet supported on the LanceDB cloud"
-        )
+        return LOOP.run(self._table.drop_columns(columns))
 
 
 def add_index(tbl: pa.Table, i: int) -> pa.Table:
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 15c67dc0..4edb4aa3 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -967,8 +967,6 @@ class Table(ABC):
         """
         Add new columns with defined values.
 
-        This is not yet available in LanceDB Cloud.
-
         Parameters
         ----------
         transforms: Dict[str, str]
@@ -978,20 +976,21 @@
         """
 
     @abstractmethod
-    def alter_columns(self, alterations: Iterable[Dict[str, str]]):
+    def alter_columns(self, *alterations: Dict[str, Any]):
         """
         Alter column names and nullability.
 
-        This is not yet available in LanceDB Cloud.
-
         alterations : Iterable[Dict[str, Any]]
             A sequence of dictionaries, each with the following keys:
             - "path": str
                 The column path to alter. For a top-level column, this is the name.
                 For a nested column, this is the dot-separated path, e.g. "a.b.c".
-            - "name": str, optional
+            - "rename": str, optional
                 The new name of the column. If not specified, the column name is
                 not changed.
+            - "data_type": pyarrow.DataType, optional
+                The new data type of the column. Existing values will be cast
+                to this type. If not specified, the column data type is not changed.
             - "nullable": bool, optional
                 Whether the column should be nullable. If not specified, the column
                 nullability is not changed. Only non-nullable columns can be changed
@@ -1004,8 +1003,6 @@
         """
         Drop columns from the table.
 
-        This is not yet available in LanceDB Cloud.
-
         Parameters
         ----------
         columns : Iterable[str]
@@ -2923,6 +2920,53 @@
 
         return await self._inner.update(updates_sql, where)
 
+    async def add_columns(self, transforms: Dict[str, str]):
+        """
+        Add new columns with defined values.
+
+        Parameters
+        ----------
+        transforms: Dict[str, str]
+            A map of column name to a SQL expression to use to calculate the
+            value of the new column. These expressions will be evaluated for
+            each row in the table, and can reference existing columns.
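+
+        Examples
+        --------
+        >>> # A sketch: assumes an open AsyncTable with a "price" column,
+        >>> # run inside an event loop.
+        >>> await table.add_columns({"double_price": "price * 2"})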
+        """
+        await self._inner.add_columns(list(transforms.items()))
+
+    async def alter_columns(self, *alterations: Dict[str, Any]):
+        """
+        Alter column names and nullability.
+
+        alterations : Iterable[Dict[str, Any]]
+            A sequence of dictionaries, each with the following keys:
+            - "path": str
+                The column path to alter. For a top-level column, this is the name.
+                For a nested column, this is the dot-separated path, e.g. "a.b.c".
+            - "rename": str, optional
+                The new name of the column. If not specified, the column name is
+                not changed.
+            - "data_type": pyarrow.DataType, optional
+                The new data type of the column. Existing values will be cast
+                to this type. If not specified, the column data type is not changed.
+            - "nullable": bool, optional
+                Whether the column should be nullable. If not specified, the column
+                nullability is not changed. Only non-nullable columns can be changed
+                to nullable. Currently, you cannot change a nullable column to
+                non-nullable.
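+
+        Examples
+        --------
+        >>> # A sketch: assumes an open AsyncTable with an "id" column.
+        >>> import pyarrow as pa
+        >>> await table.alter_columns({"path": "id", "rename": "new_id"})
+        >>> await table.alter_columns({"path": "new_id", "data_type": pa.int16()})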
+        """
+        await self._inner.alter_columns(alterations)
+
+    async def drop_columns(self, columns: Iterable[str]):
+        """
+        Drop columns from the table.
+
+        Parameters
+        ----------
+        columns : Iterable[str]
+            The names of the columns to drop.
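+
+        Examples
+        --------
+        >>> # A sketch: assumes an open AsyncTable with a "category" column.
+        >>> await table.drop_columns(["category"])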
+        """
+        await self._inner.drop_columns(columns)
+
     async def version(self) -> int:
         """
         Retrieve the version of the table
diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index 2b65a620..a4dd2e9d 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -1292,6 +1292,19 @@ def test_add_columns(tmp_path):
     assert table.to_arrow().column_names == ["id", "new_col"]
     assert table.to_arrow()["new_col"].to_pylist() == [2, 3]
 
+    table.add_columns({"null_int": "cast(null as bigint)"})
+    assert table.schema.field("null_int").type == pa.int64()
+
+
+@pytest.mark.asyncio
+async def test_add_columns_async(db_async: AsyncConnection):
+    data = pa.table({"id": [0, 1]})
+    table = await db_async.create_table("my_table", data=data)
+    await table.add_columns({"new_col": "id + 2"})
+    data = await table.to_arrow()
+    assert data.column_names == ["id", "new_col"]
+    assert data["new_col"].to_pylist() == [2, 3]
+
 
 def test_alter_columns(tmp_path):
     db = lancedb.connect(tmp_path)
@@ -1301,6 +1314,18 @@
     assert table.to_arrow().column_names == ["new_id"]
 
 
+@pytest.mark.asyncio
+async def test_alter_columns_async(db_async: AsyncConnection):
+    data = pa.table({"id": [0, 1]})
+    table = await db_async.create_table("my_table", data=data)
+    await table.alter_columns({"path": "id", "rename": "new_id"})
+    assert (await table.to_arrow()).column_names == ["new_id"]
+    await table.alter_columns(dict(path="new_id", data_type=pa.int16(), nullable=True))
+    data = await table.to_arrow()
+    assert data.column(0).type == pa.int16()
+    assert data.schema.field(0).nullable
+
+
 def test_drop_columns(tmp_path):
     db = lancedb.connect(tmp_path)
     data = pa.table({"id": [0, 1], "category": ["a", "b"]})
@@ -1309,6 +1334,14 @@ def test_drop_columns(tmp_path):
     assert table.to_arrow().column_names == ["id"]
 
 
+@pytest.mark.asyncio
+async def test_drop_columns_async(db_async: AsyncConnection):
+    data = pa.table({"id": [0, 1], "category": ["a", "b"]})
+    table = await db_async.create_table("my_table", data=data)
+    await table.drop_columns(["category"])
+    assert (await table.to_arrow()).column_names == ["id"]
+
+
 @pytest.mark.asyncio
 async def test_time_travel(db_async: AsyncConnection):
     # Setup
diff --git a/python/src/table.rs b/python/src/table.rs
index bd470a63..25a3b97e 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -1,14 +1,18 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
 use arrow::{
+    datatypes::DataType,
     ffi_stream::ArrowArrayStreamReader,
     pyarrow::{FromPyArrow, ToPyArrow},
 };
 use lancedb::table::{
-    AddDataMode, Duration, OptimizeAction, OptimizeOptions, Table as LanceDbTable,
+    AddDataMode, ColumnAlteration, Duration, NewColumnTransform, OptimizeAction, OptimizeOptions,
+    Table as LanceDbTable,
 };
 use pyo3::{
     exceptions::{PyRuntimeError, PyValueError},
     pyclass, pymethods,
-    types::{IntoPyDict, PyDict, PyDictMethods, PyString},
+    types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods, PyString},
     Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject,
 };
 use pyo3_asyncio_0_21::tokio::future_into_py;
@@ -406,6 +410,72 @@ impl Table {
             .infer_error()
         })
     }
+
+    pub fn add_columns(
+        self_: PyRef<'_, Self>,
+        definitions: Vec<(String, String)>,
+    ) -> PyResult<Bound<'_, PyAny>> {
+        let definitions = NewColumnTransform::SqlExpressions(definitions);
+
+        let inner = self_.inner_ref()?.clone();
+        future_into_py(self_.py(), async move {
+            inner.add_columns(definitions, None).await.infer_error()?;
+            Ok(())
+        })
+    }
+
+    pub fn alter_columns<'a>(
+        self_: PyRef<'a, Self>,
+        alterations: Vec<Bound<'a, PyDict>>,
+    ) -> PyResult<Bound<'a, PyAny>> {
+        let alterations = alterations
+            .iter()
+            .map(|alteration| {
+                let path = alteration
+                    .get_item("path")?
+                    .ok_or_else(|| PyValueError::new_err("Missing path"))?
+                    .extract()?;
+                let rename = {
+                    // We prefer rename, but support name for backwards compatibility
+                    let rename = if let Ok(Some(rename)) = alteration.get_item("rename") {
+                        Some(rename)
+                    } else {
+                        alteration.get_item("name")?
+                    };
+                    rename.map(|name| name.extract()).transpose()?
+                };
+                let nullable = alteration
+                    .get_item("nullable")?
+                    .map(|val| val.extract())
+                    .transpose()?;
+                let data_type = alteration
+                    .get_item("data_type")?
+                    .map(|val| DataType::from_pyarrow_bound(&val))
+                    .transpose()?;
+                Ok(ColumnAlteration {
+                    path,
+                    rename,
+                    nullable,
+                    data_type,
+                })
+            })
+            .collect::<PyResult<Vec<_>>>()?;
+
+        let inner = self_.inner_ref()?.clone();
+        future_into_py(self_.py(), async move {
+            inner.alter_columns(&alterations).await.infer_error()?;
+            Ok(())
+        })
+    }
+
+    pub fn drop_columns(self_: PyRef<'_, Self>, columns: Vec<String>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner_ref()?.clone();
+        future_into_py(self_.py(), async move {
+            let column_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
+            inner.drop_columns(&column_refs).await.infer_error()?;
+            Ok(())
+        })
+    }
 }
 
 #[derive(FromPyObject)]
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index c8ce2ec4..256ea2d9 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -9,7 +9,7 @@ use crate::utils::{supported_btree_data_type, supported_vector_data_type};
 use crate::{Error, Table};
 use arrow_array::RecordBatchReader;
 use arrow_ipc::reader::FileReader;
-use arrow_schema::{DataType, SchemaRef};
+use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef};
 use async_trait::async_trait;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
@@ -643,25 +643,85 @@ impl TableInternal for RemoteTable {
     }
     async fn add_columns(
         &self,
-        _transforms: NewColumnTransform,
+        transforms: NewColumnTransform,
         _read_columns: Option<Vec<String>>,
     ) -> Result<()> {
         self.check_mutable().await?;
-        Err(Error::NotSupported {
-            message: "add_columns is not yet supported.".into(),
-        })
+        match transforms {
+            NewColumnTransform::SqlExpressions(expressions) => {
+                let body = expressions
+                    .into_iter()
+                    .map(|(name, expression)| {
+                        serde_json::json!({
+                            "name": name,
+                            "expression": expression,
+                        })
+                    })
+                    .collect::<Vec<_>>();
+                let body = serde_json::json!({ "new_columns": body });
+                let request = self
+                    .client
+                    .post(&format!("/v1/table/{}/add_columns/", self.name))
+                    .json(&body);
+                let (request_id, response) = self.client.send(request, false).await?;
+                self.check_table_response(&request_id, response).await?;
+                Ok(())
+            }
+            _ => {
+                return Err(Error::NotSupported {
+                    message: "Only SQL expressions are supported for adding columns".into(),
+                });
+            }
+        }
     }
-    async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
+
+    async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()> {
         self.check_mutable().await?;
-        Err(Error::NotSupported {
-            message: "alter_columns is not yet supported.".into(),
-        })
+        let body = alterations
+            .iter()
+            .map(|alteration| {
+                let mut value = serde_json::json!({
+                    "path": alteration.path,
+                });
+                if let Some(rename) = &alteration.rename {
+                    value["rename"] = serde_json::Value::String(rename.clone());
+                }
+                if let Some(data_type) = &alteration.data_type {
+                    // TODO: we can later simplify this substantially, after getting:
+                    // https://github.com/lancedb/lance/pull/3161
+                    let dummy_schema =
+                        ArrowSchema::new(vec![ArrowField::new("dummy", data_type.clone(), false)]);
+                    let json_schema = JsonSchema::try_from(&dummy_schema).unwrap();
+                    let json_string = serde_json::to_string(&json_schema).unwrap();
+                    let json_value: serde_json::Value = serde_json::from_str(&json_string).unwrap();
+                    value["data_type"] = json_value["fields"][0]["type"].clone();
+                }
+                if let Some(nullable) = &alteration.nullable {
+                    value["nullable"] = serde_json::Value::Bool(*nullable);
+                }
+                value
+            })
+            .collect::<Vec<_>>();
+        let body = serde_json::json!({ "alterations": body });
+        let request = self
+            .client
+            .post(&format!("/v1/table/{}/alter_columns/", self.name))
+            .json(&body);
+        let (request_id, response) = self.client.send(request, false).await?;
+        self.check_table_response(&request_id, response).await?;
+        Ok(())
     }
-    async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
+
+    async fn drop_columns(&self, columns: &[&str]) -> Result<()> {
         self.check_mutable().await?;
-        Err(Error::NotSupported {
-            message: "drop_columns is not yet supported.".into(),
-        })
+        let body = serde_json::json!({ "columns": columns });
+        let request = self
+            .client
+            .post(&format!("/v1/table/{}/drop_columns/", self.name))
+            .json(&body);
+        let (request_id, response) = self.client.send(request, false).await?;
+        self.check_table_response(&request_id, response).await?;
+        Ok(())
     }
 
     async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
@@ -844,7 +904,17 @@ mod tests {
         Box::pin(table.update().column("a", "a + 1").execute().map_ok(|_| ())),
         Box::pin(table.add(example_data()).execute().map_ok(|_| ())),
         Box::pin(table.merge_insert(&["test"]).execute(example_data())),
-        Box::pin(table.delete("false")), // TODO: other endpoints.
+        Box::pin(table.delete("false")),
+        Box::pin(table.add_columns(
+            NewColumnTransform::SqlExpressions(vec![("x".into(), "y".into())]),
+            None,
+        )),
+        Box::pin(async {
+            let alterations = vec![ColumnAlteration::new("x".into()).rename("y".into())];
+            table.alter_columns(&alterations).await
+        }),
+        Box::pin(table.drop_columns(&["a"])),
+        // TODO: other endpoints.
     ];
 
     for result in results {
@@ -1799,4 +1869,114 @@ mod tests {
         .await;
         assert!(matches!(res, Err(Error::NotSupported { .. })));
     }
+
+    #[tokio::test]
+    async fn test_add_columns() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/v1/table/my_table/add_columns/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body = std::str::from_utf8(body).unwrap();
+            let value: serde_json::Value = serde_json::from_str(body).unwrap();
+            let new_columns = value.get("new_columns").unwrap().as_array().unwrap();
+            assert!(new_columns.len() == 2);
+
+            let col_name = new_columns[0]["name"].as_str().unwrap();
+            let expression = new_columns[0]["expression"].as_str().unwrap();
+            assert_eq!(col_name, "b");
+            assert_eq!(expression, "a + 1");
+
+            let col_name = new_columns[1]["name"].as_str().unwrap();
+            let expression = new_columns[1]["expression"].as_str().unwrap();
+            assert_eq!(col_name, "x");
+            assert_eq!(expression, "cast(NULL as int32)");
+
+            http::Response::builder().status(200).body("{}").unwrap()
+        });
+
+        table
+            .add_columns(
+                NewColumnTransform::SqlExpressions(vec![
+                    ("b".into(), "a + 1".into()),
+                    ("x".into(), "cast(NULL as int32)".into()),
+                ]),
+                None,
+            )
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_alter_columns() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/v1/table/my_table/alter_columns/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body = std::str::from_utf8(body).unwrap();
+            let value: serde_json::Value = serde_json::from_str(body).unwrap();
+            let alterations = value.get("alterations").unwrap().as_array().unwrap();
+            assert!(alterations.len() == 2);
+
+            let path = alterations[0]["path"].as_str().unwrap();
+            let data_type = alterations[0]["data_type"]["type"].as_str().unwrap();
+            assert_eq!(path, "b.c");
+            assert_eq!(data_type, "int32");
+
+            let path = alterations[1]["path"].as_str().unwrap();
+            let nullable = alterations[1]["nullable"].as_bool().unwrap();
+            let rename = alterations[1]["rename"].as_str().unwrap();
+            assert_eq!(path, "x");
+            assert!(nullable);
+            assert_eq!(rename, "y");
+
+            http::Response::builder().status(200).body("{}").unwrap()
+        });
+
+        table
+            .alter_columns(&[
+                ColumnAlteration::new("b.c".into()).cast_to(DataType::Int32),
+                ColumnAlteration::new("x".into())
+                    .rename("y".into())
+                    .set_nullable(true),
+            ])
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_drop_columns() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/v1/table/my_table/drop_columns/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body = std::str::from_utf8(body).unwrap();
+            let value: serde_json::Value = serde_json::from_str(body).unwrap();
+            let columns = value.get("columns").unwrap().as_array().unwrap();
+            assert!(columns.len() == 2);
+
+            let col1 = columns[0].as_str().unwrap();
+            let col2 = columns[1].as_str().unwrap();
+            assert_eq!(col1, "a");
+            assert_eq!(col2, "b");
+
+            http::Response::builder().status(200).body("{}").unwrap()
+        });
+
+        table.drop_columns(&["a", "b"]).await.unwrap();
+    }
 }
diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs
index f165a367..3a4c3b91 100644
--- a/rust/lancedb/src/utils.rs
+++ b/rust/lancedb/src/utils.rs
@@ -15,6 +15,7 @@ use std::sync::Arc;
 
 use arrow_schema::{DataType, Schema};
+use lance::arrow::json::JsonSchema;
 use lance::dataset::{ReadParams, WriteParams};
 use lance::io::{ObjectStoreParams, WrappingObjectStore};
 use lazy_static::lazy_static;
@@ -175,6 +176,25 @@ pub fn supported_vector_data_type(dtype: &DataType) -> bool {
     }
 }
 
+/// Note: this is temporary until we get a proper datatype conversion in Lance.
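+///
+/// A sketch of the expected behavior (type names follow the Arrow JSON schema
+/// strings used elsewhere in this crate, e.g. "int32"):
+///
+/// ```ignore
+/// use arrow_schema::DataType;
+/// assert_eq!(string_to_datatype("int32"), Some(DataType::Int32));
+/// ```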
+pub fn string_to_datatype(s: &str) -> Option<DataType> {
+    // TODO: we can later simplify this substantially, after getting:
+    // https://github.com/lancedb/lance/pull/3161
+    let dummy_schema = format!(
+        "{{\"fields\": [\
+        {{ \"name\": \"n\", \
+        \"nullable\": true, \
+        \"type\": {{\
+        \"type\": \"{}\"\
+        }} }}] }}",
+        s
+    );
+    let json_schema: JsonSchema = serde_json::from_str(&dummy_schema).ok()?;
+    let schema = Schema::try_from(json_schema).ok()?;
+    let data_type = schema.field(0).data_type().clone();
+    Some(data_type)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;