mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 05:42:58 +00:00
feat: add a filterable count_rows to all the lancedb APIs (#913)
A `count_rows` method that takes a filter was recently added to `LanceTable`. This PR adds it everywhere else except `RemoteTable` (that will come soon).
This commit is contained in:
@@ -372,7 +372,7 @@ export interface Table<T = number[]> {
|
||||
/**
|
||||
* Returns the number of rows in this table.
|
||||
*/
|
||||
countRows: () => Promise<number>
|
||||
countRows: (filter?: string) => Promise<number>
|
||||
|
||||
/**
|
||||
* Delete rows from this table.
|
||||
@@ -840,8 +840,8 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
/**
|
||||
* Returns the number of rows in this table.
|
||||
*/
|
||||
async countRows (): Promise<number> {
|
||||
return tableCountRows.call(this._tbl)
|
||||
async countRows (filter?: string): Promise<number> {
|
||||
return tableCountRows.call(this._tbl, filter)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.equal(await table.countRows(), 10)
|
||||
assert.equal(await table.countRows('vector IS NULL'), 0)
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
|
||||
const table = await con.createTable('f16', data)
|
||||
assert.equal(table.name, 'f16')
|
||||
assert.equal(await table.countRows(), total)
|
||||
assert.equal(await table.countRows('id < 5'), 5)
|
||||
assert.deepEqual(await con.tableNames(), ['f16'])
|
||||
assert.deepEqual(await table.schema, schema)
|
||||
|
||||
|
||||
@@ -57,8 +57,8 @@ impl Table {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn count_rows(&self) -> napi::Result<usize> {
|
||||
self.table.count_rows().await.map_err(|e| {
|
||||
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
|
||||
self.table.count_rows(filter).await.map_err(|e| {
|
||||
napi::Error::from_reason(format!(
|
||||
"Failed to count rows in table {}: {}",
|
||||
self.table, e
|
||||
|
||||
2
nodejs/vectordb/native.d.ts
vendored
2
nodejs/vectordb/native.d.ts
vendored
@@ -73,7 +73,7 @@ export class Table {
|
||||
/** Return Schema as empty Arrow IPC file. */
|
||||
schema(): Buffer
|
||||
add(buf: Buffer): Promise<void>
|
||||
countRows(): Promise<bigint>
|
||||
countRows(filter?: string): Promise<bigint>
|
||||
delete(predicate: string): Promise<void>
|
||||
createIndex(): IndexBuilder
|
||||
query(): Query
|
||||
|
||||
@@ -50,8 +50,8 @@ export class Table {
|
||||
}
|
||||
|
||||
/** Count the total number of rows in the dataset. */
|
||||
async countRows(): Promise<bigint> {
|
||||
return await this.inner.countRows();
|
||||
async countRows(filter?: string): Promise<bigint> {
|
||||
return await this.inner.countRows(filter);
|
||||
}
|
||||
|
||||
/** Delete the rows that satisfy the predicate. */
|
||||
|
||||
@@ -37,6 +37,9 @@ class RemoteTable(Table):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self._name})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.count_rows(None)
|
||||
|
||||
@cached_property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
@@ -409,6 +412,13 @@ class RemoteTable(Table):
|
||||
"compact_files() is not supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
# payload = {"filter": filter}
|
||||
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
|
||||
return NotImplementedError(
|
||||
"count_rows() is not yet supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
|
||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||
return tbl.add_column(
|
||||
|
||||
@@ -177,6 +177,18 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
@@ -926,14 +938,6 @@ class LanceTable(Table):
|
||||
self._ref.dataset = ds
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
return self._dataset.count_rows(filter)
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -133,13 +133,26 @@ impl JsTable {
|
||||
|
||||
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let filter = cx
|
||||
.argument_opt(0)
|
||||
.and_then(|filt| {
|
||||
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
filt.downcast_or_throw::<JsString, _>(&mut cx)
|
||||
.map(|js_filt| js_filt.deref().value(&mut cx)),
|
||||
)
|
||||
}
|
||||
})
|
||||
.transpose()?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
rt.spawn(async move {
|
||||
let num_rows_result = table.count_rows().await;
|
||||
let num_rows_result = table.count_rows(filter).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let num_rows = num_rows_result.or_throw(&mut cx)?;
|
||||
|
||||
@@ -197,7 +197,7 @@ impl IndexBuilder {
|
||||
let num_partitions = if let Some(n) = self.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.table.count_rows().await?)
|
||||
suggested_num_partitions(self.table.count_rows(None).await?)
|
||||
};
|
||||
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
|
||||
n
|
||||
|
||||
@@ -372,7 +372,7 @@ mod test {
|
||||
// leave this here for easy debugging
|
||||
let t = res.unwrap();
|
||||
|
||||
assert_eq!(t.count_rows().await.unwrap(), 100);
|
||||
assert_eq!(t.count_rows(None).await.unwrap(), 100);
|
||||
|
||||
let q = t
|
||||
.search(&[0.1, 0.1, 0.1, 0.1])
|
||||
|
||||
@@ -102,7 +102,11 @@ pub trait Table: std::fmt::Display + Send + Sync {
|
||||
fn schema(&self) -> SchemaRef;
|
||||
|
||||
/// Count the number of rows in this dataset.
|
||||
async fn count_rows(&self) -> Result<usize>;
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `filter` if present, only count rows matching the filter
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
|
||||
|
||||
/// Insert new records into this Table
|
||||
///
|
||||
@@ -719,9 +723,15 @@ impl Table for NativeTable {
|
||||
Arc::new(Schema::from(&lance_schema))
|
||||
}
|
||||
|
||||
async fn count_rows(&self) -> Result<usize> {
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
|
||||
let dataset = { self.dataset.lock().expect("lock poison").clone() };
|
||||
Ok(dataset.count_rows().await?)
|
||||
if let Some(filter) = filter {
|
||||
let mut scanner = dataset.scan();
|
||||
scanner.filter(&filter)?;
|
||||
Ok(scanner.count_rows().await? as usize)
|
||||
} else {
|
||||
Ok(dataset.count_rows().await?)
|
||||
}
|
||||
}
|
||||
|
||||
async fn add(
|
||||
@@ -886,6 +896,23 @@ mod tests {
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_count_rows() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let batches = make_test_batches();
|
||||
let table = NativeTable::create(&uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
assert_eq!(
|
||||
table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
|
||||
5
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
@@ -896,7 +923,7 @@ mod tests {
|
||||
let table = NativeTable::create(&uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
let new_batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
@@ -910,7 +937,7 @@ mod tests {
|
||||
);
|
||||
|
||||
table.add(Box::new(new_batches), None).await.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 20);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 20);
|
||||
assert_eq!(table.name, "test");
|
||||
}
|
||||
|
||||
@@ -924,7 +951,7 @@ mod tests {
|
||||
let table = NativeTable::create(&uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
// Create new data with i=5..15
|
||||
let new_batches = Box::new(make_test_batches_with_offset(5));
|
||||
@@ -934,7 +961,7 @@ mod tests {
|
||||
merge_insert_builder.when_not_matched_insert_all();
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
// Only 5 rows should actually be inserted
|
||||
assert_eq!(table.count_rows().await.unwrap(), 15);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 15);
|
||||
|
||||
// Create new data with i=15..25 (no id matches)
|
||||
let new_batches = Box::new(make_test_batches_with_offset(15));
|
||||
@@ -943,7 +970,7 @@ mod tests {
|
||||
merge_insert_builder.when_matched_update_all();
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
// No new rows should have been inserted
|
||||
assert_eq!(table.count_rows().await.unwrap(), 15);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 15);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -956,7 +983,7 @@ mod tests {
|
||||
let table = NativeTable::create(uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
let new_batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
@@ -975,7 +1002,7 @@ mod tests {
|
||||
};
|
||||
|
||||
table.add(Box::new(new_batches), Some(param)).await.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
assert_eq!(table.name, "test");
|
||||
}
|
||||
|
||||
@@ -1365,7 +1392,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(table.load_indices().await.unwrap().len(), 1);
|
||||
assert_eq!(table.count_rows().await.unwrap(), 512);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 512);
|
||||
assert_eq!(table.name, "test");
|
||||
|
||||
let indices = table.load_indices().await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user