diff --git a/rust/lancedb/src/index/scalar.rs b/rust/lancedb/src/index/scalar.rs index e7066548..8003688a 100644 --- a/rust/lancedb/src/index/scalar.rs +++ b/rust/lancedb/src/index/scalar.rs @@ -53,7 +53,10 @@ pub struct LabelListIndexBuilder {} /// A full text search index is an index on a string column that allows for full text search #[derive(Debug, Clone)] pub struct FtsIndexBuilder { - pub(crate) with_position: bool, + /// Whether to store the position of the tokens + /// This is used for phrase queries + pub with_position: bool, + pub tokenizer_configs: TokenizerConfig, } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 486912d8..90fd0510 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -570,7 +570,19 @@ impl TableInternal for RemoteTable { Index::BTree(_) => ("BTREE", None), Index::Bitmap(_) => ("BITMAP", None), Index::LabelList(_) => ("LABEL_LIST", None), - Index::FTS(_) => ("FTS", None), + Index::FTS(fts) => { + let with_position = fts.with_position; + let configs = serde_json::to_value(fts.tokenizer_configs).map_err(|e| { + Error::InvalidInput { + message: format!("failed to serialize FTS index params {:?}", e), + } + })?; + for (key, value) in configs.as_object().unwrap() { + body[key] = value.clone(); + } + body["with_position"] = serde_json::Value::Bool(with_position); + ("FTS", None) + } Index::Auto => { let schema = self.schema().await?; let field = schema @@ -1496,6 +1508,7 @@ mod tests { ]; for (index_type, distance_type, index) in cases { + let params = index.clone(); let table = Table::new_with_handler("my_table", move |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/v1/table/my_table/create_index/"); @@ -1512,6 +1525,17 @@ mod tests { if let Some(distance_type) = distance_type { expected_body["metric_type"] = distance_type.to_lowercase().into(); } + if let Index::FTS(fts) = ¶ms { + expected_body["with_position"] = fts.with_position.into(); + expected_body["base_tokenizer"] = "simple".into(); + expected_body["language"] = "English".into(); + expected_body["max_token_length"] = 40.into(); + expected_body["lower_case"] = true.into(); + expected_body["stem"] = false.into(); + expected_body["remove_stop_words"] = false.into(); + expected_body["ascii_folding"] = false.into(); + } + assert_eq!(body, expected_body); http::Response::builder().status(200).body("{}").unwrap()