feat(node): add Table.countRows() (#185)

This commit is contained in:
gsilvestrin
2023-06-15 14:35:54 -07:00
committed by GitHub
parent a6544c2a31
commit 78de8f5782
3 changed files with 33 additions and 12 deletions

View File

@@ -22,7 +22,7 @@ import { fromRecordsToBuffer } from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex } = require('../native.js')
const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex, tableCountRows } = require('../native.js')
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
@@ -178,6 +178,13 @@ export class Table<T = number[]> {
async create_index (indexParams: VectorIndexParams): Promise<any> {
return await this.createIndex(indexParams)
}
/**
* Returns the number of rows in this table.
*/
async countRows (): Promise<number> {
return tableCountRows.call(this._tbl)
}
}
interface IvfPQIndexConfig {

View File

@@ -110,9 +110,7 @@ describe('LanceDB client', function () {
const tableName = `vectors_${Math.floor(Math.random() * 100)}`
const table = await con.createTable(tableName, data)
assert.equal(table.name, tableName)
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
assert.equal(await table.countRows(), 2)
})
it('appends records to an existing table ', async function () {
@@ -125,16 +123,14 @@ describe('LanceDB client', function () {
]
const table = await con.createTable('vectors', data)
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
assert.equal(await table.countRows(), 2)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], price: 10, name: 'c' },
{ id: 4, vector: [3.1, 3.2], price: 50, name: 'd' }
]
await table.add(dataAdd)
const resultsAdd = await table.search([0.1, 0.3]).execute()
assert.equal(resultsAdd.length, 4)
assert.equal(await table.countRows(), 4)
})
it('overwrite all records in a table', async function () {
@@ -142,16 +138,14 @@ describe('LanceDB client', function () {
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
assert.equal(await table.countRows(), 2)
const dataOver = [
{ vector: [2.1, 2.2], price: 10, name: 'foo' },
{ vector: [3.1, 3.2], price: 50, name: 'bar' }
]
await table.overwrite(dataOver)
const resultsAdd = await table.search([0.1, 0.3]).execute()
assert.equal(resultsAdd.length, 2)
assert.equal(await table.countRows(), 2)
})
})

View File

@@ -264,6 +264,25 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
Ok(promise)
}
fn table_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
rt.block_on(async move {
let num_rows_result = table.lock().unwrap().count_rows().await;
deferred.settle_with(&channel, move |mut cx| {
let num_rows = num_rows_result.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(cx.number(num_rows as f64))
});
});
Ok(promise)
}
#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("databaseNew", database_new)?;
@@ -272,6 +291,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableSearch", table_search)?;
cx.export_function("tableCreate", table_create)?;
cx.export_function("tableAdd", table_add)?;
cx.export_function("tableCountRows", table_count_rows)?;
cx.export_function(
"tableCreateVectorIndex",
index::vector::table_create_vector_index,