diff --git a/node/src/index.ts b/node/src/index.ts index 50f48509..7c377d73 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -22,7 +22,7 @@ import { fromRecordsToBuffer } from './arrow' import type { EmbeddingFunction } from './embedding/embedding_function' // eslint-disable-next-line @typescript-eslint/no-var-requires -const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex } = require('../native.js') +const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex, tableCountRows } = require('../native.js') export type { EmbeddingFunction } export { OpenAIEmbeddingFunction } from './embedding/openai' @@ -178,6 +178,13 @@ export class Table { async create_index (indexParams: VectorIndexParams): Promise { return await this.createIndex(indexParams) } + + /** + * Returns the number of rows in this table. + */ + async countRows (): Promise { + return tableCountRows.call(this._tbl) + } } interface IvfPQIndexConfig { diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 06133369..3e0e8c60 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -110,9 +110,7 @@ describe('LanceDB client', function () { const tableName = `vectors_${Math.floor(Math.random() * 100)}` const table = await con.createTable(tableName, data) assert.equal(table.name, tableName) - - const results = await table.search([0.1, 0.3]).execute() - assert.equal(results.length, 2) + assert.equal(await table.countRows(), 2) }) it('appends records to an existing table ', async function () { @@ -125,16 +123,14 @@ describe('LanceDB client', function () { ] const table = await con.createTable('vectors', data) - const results = await table.search([0.1, 0.3]).execute() - assert.equal(results.length, 2) + assert.equal(await table.countRows(), 2) const dataAdd = [ { id: 3, vector: [2.1, 2.2], price: 10, name: 'c' }, { id: 4, vector: [3.1, 3.2], price: 50, name: 'd' } ] await table.add(dataAdd) - const resultsAdd = await table.search([0.1, 0.3]).execute() - assert.equal(resultsAdd.length, 4) + assert.equal(await table.countRows(), 4) }) it('overwrite all records in a table', async function () { @@ -142,16 +138,14 @@ describe('LanceDB client', function () { const con = await lancedb.connect(uri) const table = await con.openTable('vectors') - const results = await table.search([0.1, 0.3]).execute() - assert.equal(results.length, 2) + assert.equal(await table.countRows(), 2) const dataOver = [ { vector: [2.1, 2.2], price: 10, name: 'foo' }, { vector: [3.1, 3.2], price: 50, name: 'bar' } ] await table.overwrite(dataOver) - const resultsAdd = await table.search([0.1, 0.3]).execute() - assert.equal(resultsAdd.length, 2) + assert.equal(await table.countRows(), 2) }) }) diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index be4369c2..b13f9d53 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -264,6 +264,25 @@ fn table_add(mut cx: FunctionContext) -> JsResult { Ok(promise) } +fn table_count_rows(mut cx: FunctionContext) -> JsResult { + let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let rt = runtime(&mut cx)?; + let channel = cx.channel(); + + let (deferred, promise) = cx.promise(); + let table = js_table.table.clone(); + + rt.block_on(async move { + let num_rows_result = table.lock().unwrap().count_rows().await; + + deferred.settle_with(&channel, move |mut cx| { + let num_rows = num_rows_result.or_else(|err| cx.throw_error(err.to_string()))?; + Ok(cx.number(num_rows as f64)) + }); + }); + Ok(promise) +} + #[neon::main] fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("databaseNew", database_new)?; @@ -272,6 +291,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("tableSearch", table_search)?; cx.export_function("tableCreate", table_create)?; cx.export_function("tableAdd", table_add)?; + cx.export_function("tableCountRows", table_count_rows)?; cx.export_function( "tableCreateVectorIndex", index::vector::table_create_vector_index,