From 138fc3f66bb23c2e20435c4b4aba9a76324d2ccd Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 8 Feb 2024 09:40:29 -0800 Subject: [PATCH] feat: add a filterable count_rows to all the lancedb APIs (#913) A `count_rows` method that takes a filter was recently added to `LanceTable`. This PR adds it everywhere else except `RemoteTable` (that will come soon). --- node/src/index.ts | 6 ++-- node/src/test/test.ts | 2 ++ nodejs/src/table.rs | 4 +-- nodejs/vectordb/native.d.ts | 2 +- nodejs/vectordb/table.ts | 4 +-- python/lancedb/remote/table.py | 10 ++++++ python/lancedb/table.py | 20 +++++++----- rust/ffi/node/src/table.rs | 15 ++++++++- rust/vectordb/src/index.rs | 2 +- rust/vectordb/src/io/object_store.rs | 2 +- rust/vectordb/src/table.rs | 49 +++++++++++++++++++++------- 11 files changed, 86 insertions(+), 30 deletions(-) diff --git a/node/src/index.ts b/node/src/index.ts index 9b7831b8..2607cc2d 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -372,7 +372,7 @@ export interface Table { /** * Returns the number of rows in this table. */ - countRows: () => Promise + countRows: (filter?: string) => Promise /** * Delete rows from this table. @@ -840,8 +840,8 @@ export class LocalTable implements Table { /** * Returns the number of rows in this table. */ - async countRows (): Promise { - return tableCountRows.call(this._tbl) + async countRows (filter?: string): Promise { + return tableCountRows.call(this._tbl, filter) } /** diff --git a/node/src/test/test.ts b/node/src/test/test.ts index db87cbc6..cda140fd 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -294,6 +294,7 @@ describe('LanceDB client', function () { }) assert.equal(table.name, 'vectors') assert.equal(await table.countRows(), 10) + assert.equal(await table.countRows('vector IS NULL'), 0) assert.deepEqual(await con.tableNames(), ['vectors']) }) @@ -369,6 +370,7 @@ describe('LanceDB client', function () { const table = await con.createTable('f16', data) assert.equal(table.name, 'f16') assert.equal(await table.countRows(), total) + assert.equal(await table.countRows('id < 5'), 5) assert.deepEqual(await con.tableNames(), ['f16']) assert.deepEqual(await table.schema, schema) diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index e99760cb..fdb16b26 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -57,8 +57,8 @@ impl Table { } #[napi] - pub async fn count_rows(&self) -> napi::Result { - self.table.count_rows().await.map_err(|e| { + pub async fn count_rows(&self, filter: Option) -> napi::Result { + self.table.count_rows(filter).await.map_err(|e| { napi::Error::from_reason(format!( "Failed to count rows in table {}: {}", self.table, e diff --git a/nodejs/vectordb/native.d.ts b/nodejs/vectordb/native.d.ts index 192f39a6..5c65d2e9 100644 --- a/nodejs/vectordb/native.d.ts +++ b/nodejs/vectordb/native.d.ts @@ -73,7 +73,7 @@ export class Table { /** Return Schema as empty Arrow IPC file. */ schema(): Buffer add(buf: Buffer): Promise - countRows(): Promise + countRows(filter?: string): Promise delete(predicate: string): Promise createIndex(): IndexBuilder query(): Query diff --git a/nodejs/vectordb/table.ts b/nodejs/vectordb/table.ts index ec3d31b9..2dd3bec1 100644 --- a/nodejs/vectordb/table.ts +++ b/nodejs/vectordb/table.ts @@ -50,8 +50,8 @@ export class Table { } /** Count the total number of rows in the dataset. */ - async countRows(): Promise { - return await this.inner.countRows(); + async countRows(filter?: string): Promise { + return await this.inner.countRows(filter); } /** Delete the rows that satisfy the predicate. */ diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 9815dda6..d298ec1b 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -37,6 +37,9 @@ class RemoteTable(Table): def __repr__(self) -> str: return f"RemoteTable({self._conn.db_name}.{self._name})" + def __len__(self) -> int: + self.count_rows(None) + @cached_property def schema(self) -> pa.Schema: """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) @@ -409,6 +412,13 @@ class RemoteTable(Table): "compact_files() is not supported on the LanceDB cloud" ) + def count_rows(self, filter: Optional[str] = None) -> int: + # payload = {"filter": filter} + # self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload) + return NotImplementedError( + "count_rows() is not yet supported on the LanceDB cloud" + ) + def add_index(tbl: pa.Table, i: int) -> pa.Table: return tbl.add_column( diff --git a/python/lancedb/table.py b/python/lancedb/table.py index cffc280d..e702e7f3 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -176,6 +176,18 @@ class Table(ABC): """ raise NotImplementedError + @abstractmethod + def count_rows(self, filter: Optional[str] = None) -> int: + """ + Count the number of rows in the table. + + Parameters + ---------- + filter: str, optional + A SQL where clause to filter the rows to count. + """ + raise NotImplementedError + def to_pandas(self) -> "pd.DataFrame": """Return the table as a pandas DataFrame. @@ -925,14 +937,6 @@ class LanceTable(Table): self._ref.dataset = ds def count_rows(self, filter: Optional[str] = None) -> int: - """ - Count the number of rows in the table. - - Parameters - ---------- - filter: str, optional - A SQL where clause to filter the rows to count. - """ return self._dataset.count_rows(filter) def __len__(self): diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 8139cfda..ac5690f5 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -133,13 +133,26 @@ impl JsTable { pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let filter = cx + .argument_opt(0) + .and_then(|filt| { + if filt.is_a::(&mut cx) || filt.is_a::(&mut cx) { + None + } else { + Some( + filt.downcast_or_throw::(&mut cx) + .map(|js_filt| js_filt.deref().value(&mut cx)), + ) + } + }) + .transpose()?; let rt = runtime(&mut cx)?; let (deferred, promise) = cx.promise(); let channel = cx.channel(); let table = js_table.table.clone(); rt.spawn(async move { - let num_rows_result = table.count_rows().await; + let num_rows_result = table.count_rows(filter).await; deferred.settle_with(&channel, move |mut cx| { let num_rows = num_rows_result.or_throw(&mut cx)?; diff --git a/rust/vectordb/src/index.rs b/rust/vectordb/src/index.rs index e953f6e6..9131f2da 100644 --- a/rust/vectordb/src/index.rs +++ b/rust/vectordb/src/index.rs @@ -197,7 +197,7 @@ impl IndexBuilder { let num_partitions = if let Some(n) = self.num_partitions { n } else { - suggested_num_partitions(self.table.count_rows().await?) + suggested_num_partitions(self.table.count_rows(None).await?) }; let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors { n diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs index 8b6e4bbf..66efefb4 100644 --- a/rust/vectordb/src/io/object_store.rs +++ b/rust/vectordb/src/io/object_store.rs @@ -372,7 +372,7 @@ mod test { // leave this here for easy debugging let t = res.unwrap(); - assert_eq!(t.count_rows().await.unwrap(), 100); + assert_eq!(t.count_rows(None).await.unwrap(), 100); let q = t .search(&[0.1, 0.1, 0.1, 0.1]) diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index 97eaecd0..5f889060 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -102,7 +102,11 @@ pub trait Table: std::fmt::Display + Send + Sync { fn schema(&self) -> SchemaRef; /// Count the number of rows in this dataset. - async fn count_rows(&self) -> Result; + /// + /// # Arguments + /// + /// * `filter` if present, only count rows matching the filter + async fn count_rows(&self, filter: Option) -> Result; /// Insert new records into this Table /// @@ -719,9 +723,15 @@ impl Table for NativeTable { Arc::new(Schema::from(&lance_schema)) } - async fn count_rows(&self) -> Result { + async fn count_rows(&self, filter: Option) -> Result { let dataset = { self.dataset.lock().expect("lock poison").clone() }; - Ok(dataset.count_rows().await?) + if let Some(filter) = filter { + let mut scanner = dataset.scan(); + scanner.filter(&filter)?; + Ok(scanner.count_rows().await? as usize) + } else { + Ok(dataset.count_rows().await?) + } } async fn add( @@ -886,6 +896,23 @@ mod tests { )); } + #[tokio::test] + async fn test_count_rows() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + let batches = make_test_batches(); + let table = NativeTable::create(&uri, "test", batches, None, None) + .await + .unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 10); + assert_eq!( + table.count_rows(Some("i >= 5".to_string())).await.unwrap(), + 5 + ); + } + #[tokio::test] async fn test_add() { let tmp_dir = tempdir().unwrap(); @@ -896,7 +923,7 @@ mod tests { let table = NativeTable::create(&uri, "test", batches, None, None) .await .unwrap(); - assert_eq!(table.count_rows().await.unwrap(), 10); + assert_eq!(table.count_rows(None).await.unwrap(), 10); let new_batches = RecordBatchIterator::new( vec![RecordBatch::try_new( @@ -910,7 +937,7 @@ mod tests { ); table.add(Box::new(new_batches), None).await.unwrap(); - assert_eq!(table.count_rows().await.unwrap(), 20); + assert_eq!(table.count_rows(None).await.unwrap(), 20); assert_eq!(table.name, "test"); } @@ -924,7 +951,7 @@ mod tests { let table = NativeTable::create(&uri, "test", batches, None, None) .await .unwrap(); - assert_eq!(table.count_rows().await.unwrap(), 10); + assert_eq!(table.count_rows(None).await.unwrap(), 10); // Create new data with i=5..15 let new_batches = Box::new(make_test_batches_with_offset(5)); @@ -934,7 +961,7 @@ mod tests { merge_insert_builder.when_not_matched_insert_all(); merge_insert_builder.execute(new_batches).await.unwrap(); // Only 5 rows should actually be inserted - assert_eq!(table.count_rows().await.unwrap(), 15); + assert_eq!(table.count_rows(None).await.unwrap(), 15); // Create new data with i=15..25 (no id matches) let new_batches = Box::new(make_test_batches_with_offset(15)); @@ -943,7 +970,7 @@ mod tests { merge_insert_builder.when_matched_update_all(); merge_insert_builder.execute(new_batches).await.unwrap(); // No new rows should have been inserted - assert_eq!(table.count_rows().await.unwrap(), 15); + assert_eq!(table.count_rows(None).await.unwrap(), 15); } #[tokio::test] @@ -956,7 +983,7 @@ mod tests { let table = NativeTable::create(uri, "test", batches, None, None) .await .unwrap(); - assert_eq!(table.count_rows().await.unwrap(), 10); + assert_eq!(table.count_rows(None).await.unwrap(), 10); let new_batches = RecordBatchIterator::new( vec![RecordBatch::try_new( @@ -975,7 +1002,7 @@ mod tests { }; table.add(Box::new(new_batches), Some(param)).await.unwrap(); - assert_eq!(table.count_rows().await.unwrap(), 10); + assert_eq!(table.count_rows(None).await.unwrap(), 10); assert_eq!(table.name, "test"); } @@ -1365,7 +1392,7 @@ mod tests { .unwrap(); assert_eq!(table.load_indices().await.unwrap().len(), 1); - assert_eq!(table.count_rows().await.unwrap(), 512); + assert_eq!(table.count_rows(None).await.unwrap(), 512); assert_eq!(table.name, "test"); let indices = table.load_indices().await.unwrap();