From 7a8d2f37c4a171cc6b18ad09d03c5078cd431b21 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sat, 21 Sep 2024 21:26:19 -0700 Subject: [PATCH] feat(rust): add with_row_id to rust SDK (#1683) --- rust/lancedb/src/query.rs | 34 ++++++++++++++++++++++++++++++++++ rust/lancedb/src/table.rs | 4 ++++ 2 files changed, 38 insertions(+) diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index d28956686..6118e6b76 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -402,6 +402,9 @@ pub trait QueryBase { /// /// By default, it is false. fn fast_search(self) -> Self; + + /// Return the `_rowid` meta column from the Table. + fn with_row_id(self) -> Self; } pub trait HasQuery { @@ -438,6 +441,11 @@ impl QueryBase for T { self.mut_query().fast_search = true; self } + + fn with_row_id(mut self) -> Self { + self.mut_query().with_row_id = true; + self + } } /// Options for controlling the execution of a query @@ -548,6 +556,11 @@ pub struct Query { /// /// By default, this is false. pub(crate) fast_search: bool, + + /// If set to true, the query will return the `_rowid` meta column. + /// + /// By default, this is false. + pub(crate) with_row_id: bool, } impl Query { @@ -560,6 +573,7 @@ impl Query { full_text_search: None, select: Select::All, fast_search: false, + with_row_id: false, } } @@ -1160,4 +1174,24 @@ mod tests { .unwrap(); assert!(!plan.contains("Take")); } + + #[tokio::test] + async fn test_with_row_id() { + let tmp_dir = tempdir().unwrap(); + let table = make_test_table(&tmp_dir).await; + let results = table + .vector_search(&[0.1, 0.2, 0.3, 0.4]) + .unwrap() + .with_row_id() + .limit(10) + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + for batch in results { + assert!(batch.column_by_name("_rowid").is_some()); + } + } } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 8096a8daf..5e753e9f7 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1938,6 +1938,10 @@ impl TableInternal for NativeTable { Select::All => {} } + if query.base.with_row_id { + scanner.with_row_id(); + } + if let Some(opts) = options { scanner.batch_size(opts.max_batch_length as usize); }