feat: add a filterable count_rows to all the lancedb APIs (#913)

A `count_rows` method that takes a filter was recently added to
`LanceTable`. This PR adds it everywhere else except `RemoteTable` (that
will come soon).
This commit is contained in:
Weston Pace
2024-02-08 09:40:29 -08:00
committed by GitHub
parent f53aace89c
commit d2e71c8b08
11 changed files with 86 additions and 30 deletions

View File

@@ -372,7 +372,7 @@ export interface Table<T = number[]> {
/**
* Returns the number of rows in this table.
*/
countRows: () => Promise<number>
countRows: (filter?: string) => Promise<number>
/**
* Delete rows from this table.
@@ -840,8 +840,8 @@ export class LocalTable<T = number[]> implements Table<T> {
/**
* Returns the number of rows in this table.
*/
async countRows (): Promise<number> {
return tableCountRows.call(this._tbl)
async countRows (filter?: string): Promise<number> {
return tableCountRows.call(this._tbl, filter)
}
/**

View File

@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
})
assert.equal(table.name, 'vectors')
assert.equal(await table.countRows(), 10)
assert.equal(await table.countRows('vector IS NULL'), 0)
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
const table = await con.createTable('f16', data)
assert.equal(table.name, 'f16')
assert.equal(await table.countRows(), total)
assert.equal(await table.countRows('id < 5'), 5)
assert.deepEqual(await con.tableNames(), ['f16'])
assert.deepEqual(await table.schema, schema)

View File

@@ -57,8 +57,8 @@ impl Table {
}
#[napi]
pub async fn count_rows(&self) -> napi::Result<usize> {
self.table.count_rows().await.map_err(|e| {
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
self.table.count_rows(filter).await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to count rows in table {}: {}",
self.table, e

View File

@@ -73,7 +73,7 @@ export class Table {
/** Return Schema as empty Arrow IPC file. */
schema(): Buffer
add(buf: Buffer): Promise<void>
countRows(): Promise<bigint>
countRows(filter?: string): Promise<bigint>
delete(predicate: string): Promise<void>
createIndex(): IndexBuilder
query(): Query

View File

@@ -50,8 +50,8 @@ export class Table {
}
/** Count the total number of rows in the dataset. */
async countRows(): Promise<bigint> {
return await this.inner.countRows();
async countRows(filter?: string): Promise<bigint> {
return await this.inner.countRows(filter);
}
/** Delete the rows that satisfy the predicate. */

View File

@@ -37,6 +37,9 @@ class RemoteTable(Table):
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self._name})"
def __len__(self) -> int:
self.count_rows(None)
@cached_property
def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
@@ -409,6 +412,13 @@ class RemoteTable(Table):
"compact_files() is not supported on the LanceDB cloud"
)
def count_rows(self, filter: Optional[str] = None) -> int:
# payload = {"filter": filter}
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
return NotImplementedError(
"count_rows() is not yet supported on the LanceDB cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column(

View File

@@ -177,6 +177,18 @@ class Table(ABC):
"""
raise NotImplementedError
@abstractmethod
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -926,14 +938,6 @@ class LanceTable(Table):
self._ref.dataset = ds
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
return self._dataset.count_rows(filter)
def __len__(self):

View File

@@ -133,13 +133,26 @@ impl JsTable {
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let filter = cx
.argument_opt(0)
.and_then(|filt| {
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
None
} else {
Some(
filt.downcast_or_throw::<JsString, _>(&mut cx)
.map(|js_filt| js_filt.deref().value(&mut cx)),
)
}
})
.transpose()?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let table = js_table.table.clone();
rt.spawn(async move {
let num_rows_result = table.count_rows().await;
let num_rows_result = table.count_rows(filter).await;
deferred.settle_with(&channel, move |mut cx| {
let num_rows = num_rows_result.or_throw(&mut cx)?;

View File

@@ -197,7 +197,7 @@ impl IndexBuilder {
let num_partitions = if let Some(n) = self.num_partitions {
n
} else {
suggested_num_partitions(self.table.count_rows().await?)
suggested_num_partitions(self.table.count_rows(None).await?)
};
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
n

View File

@@ -372,7 +372,7 @@ mod test {
// leave this here for easy debugging
let t = res.unwrap();
assert_eq!(t.count_rows().await.unwrap(), 100);
assert_eq!(t.count_rows(None).await.unwrap(), 100);
let q = t
.search(&[0.1, 0.1, 0.1, 0.1])

View File

@@ -102,7 +102,11 @@ pub trait Table: std::fmt::Display + Send + Sync {
fn schema(&self) -> SchemaRef;
/// Count the number of rows in this dataset.
async fn count_rows(&self) -> Result<usize>;
///
/// # Arguments
///
/// * `filter` if present, only count rows matching the filter
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
/// Insert new records into this Table
///
@@ -719,9 +723,15 @@ impl Table for NativeTable {
Arc::new(Schema::from(&lance_schema))
}
async fn count_rows(&self) -> Result<usize> {
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
let dataset = { self.dataset.lock().expect("lock poison").clone() };
Ok(dataset.count_rows().await?)
if let Some(filter) = filter {
let mut scanner = dataset.scan();
scanner.filter(&filter)?;
Ok(scanner.count_rows().await? as usize)
} else {
Ok(dataset.count_rows().await?)
}
}
async fn add(
@@ -886,6 +896,23 @@ mod tests {
));
}
#[tokio::test]
async fn test_count_rows() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches = make_test_batches();
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(
table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
5
);
}
#[tokio::test]
async fn test_add() {
let tmp_dir = tempdir().unwrap();
@@ -896,7 +923,7 @@ mod tests {
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
@@ -910,7 +937,7 @@ mod tests {
);
table.add(Box::new(new_batches), None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
assert_eq!(table.count_rows(None).await.unwrap(), 20);
assert_eq!(table.name, "test");
}
@@ -924,7 +951,7 @@ mod tests {
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(make_test_batches_with_offset(5));
@@ -934,7 +961,7 @@ mod tests {
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
assert_eq!(table.count_rows(None).await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(make_test_batches_with_offset(15));
@@ -943,7 +970,7 @@ mod tests {
merge_insert_builder.when_matched_update_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
assert_eq!(table.count_rows(None).await.unwrap(), 15);
}
#[tokio::test]
@@ -956,7 +983,7 @@ mod tests {
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
@@ -975,7 +1002,7 @@ mod tests {
};
table.add(Box::new(new_batches), Some(param)).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
}
@@ -1365,7 +1392,7 @@ mod tests {
.unwrap();
assert_eq!(table.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows().await.unwrap(), 512);
assert_eq!(table.count_rows(None).await.unwrap(), 512);
assert_eq!(table.name, "test");
let indices = table.load_indices().await.unwrap();