added projection api for nodejs (#140)

This commit is contained in:
gsilvestrin
2023-06-03 10:34:08 -07:00
committed by GitHub
parent 41cca31f48
commit d0c47e3838
5 changed files with 66 additions and 4 deletions

View File

@@ -129,4 +129,9 @@ You can select the columns returned by the query using a select clause.
```
=== "Javascript"
Projections are not currently supported in the Javascript SDK.
```javascript
const results = await table
.search(Array(1536).fill(1.2))
.select(["id"])
.execute()
```

View File

@@ -233,7 +233,7 @@ export class Query<T = number[]> {
private _limit: number
private _refineFactor?: number
private _nprobes: number
private readonly _columns?: string[]
private _select?: string[]
private _filter?: string
private _metricType?: MetricType
private readonly _embeddings?: EmbeddingFunction<T>
@@ -244,7 +244,7 @@ export class Query<T = number[]> {
this._limit = 10
this._nprobes = 20
this._refineFactor = undefined
this._columns = undefined
this._select = undefined
this._filter = undefined
this._metricType = undefined
this._embeddings = embeddings
@@ -286,6 +286,15 @@ export class Query<T = number[]> {
return this
}
/**
 * Return only the specified columns.
 *
 * Stores the requested column names on this query and returns the same
 * Query instance so that calls can be chained.
 *
 * @param value Only select the specified columns. If not specified, all columns will be returned.
 * @returns This query, to allow chaining further builder calls.
 */
select (value: string[]): Query<T> {
this._select = value
return this
}
/**
* The MetricType used for this Query.
* @param value The metric to use. @see MetricType for the different options

View File

@@ -72,6 +72,22 @@ describe('LanceDB client', function () {
assert.equal(results.length, 1)
assert.equal(results[0].id, 2)
})
it('select only a subset of columns', async function () {
  const dbUri = await createTestDB()
  const db = await lancedb.connect(dbUri)
  const vectors = await db.openTable('vectors')

  // Ask for a single column through the projection API.
  const rows = await vectors
    .search([0.1, 0.1])
    .select(['is_active'])
    .execute()
  assert.equal(rows.length, 2)

  const first = rows[0]
  // The vector and the score come back regardless of the projection.
  assert.isDefined(first.vector)
  assert.isDefined(first.score)
  // The projected column is present ...
  assert.isDefined(first.is_active)
  // ... while columns that were not selected are absent.
  assert.isUndefined(first.id)
  assert.isUndefined(first.name)
  assert.isUndefined(first.price)
})
})
describe('when creating a new dataset', function () {
@@ -181,11 +197,13 @@ describe('Query object', function () {
.limit(1)
.metricType(MetricType.Cosine)
.refineFactor(100)
.select(['a', 'b'])
.nprobes(20) as Record<string, any>
assert.equal(query._limit, 1)
assert.equal(query._metricType, MetricType.Cosine)
assert.equal(query._refineFactor, 100)
assert.equal(query._nprobes, 20)
assert.deepEqual(query._select, ['a', 'b'])
})
})

View File

@@ -129,6 +129,17 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
let limit = query_obj
.get::<JsNumber, _, _>(&mut cx, "_limit")?
.value(&mut cx);
let select = query_obj
.get_opt::<JsArray, _, _>(&mut cx, "_select")?
.map(|arr| {
let js_array = arr.deref();
let mut projection_vec: Vec<String> = Vec::new();
for i in 0..js_array.len(&mut cx) {
let entry: Handle<JsString> = js_array.get(&mut cx, i).unwrap();
projection_vec.push(entry.value(&mut cx));
}
projection_vec
});
let filter = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
.map(|s| s.value(&mut cx));
@@ -161,7 +172,8 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
.refine_factor(refine_factor)
.nprobes(nprobes)
.filter(filter)
.metric_type(metric_type);
.metric_type(metric_type)
.select(select);
let record_batch_stream = builder.execute();
let results = record_batch_stream
.and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))

View File

@@ -27,6 +27,7 @@ pub struct Query {
pub query_vector: Float32Array,
pub limit: usize,
pub filter: Option<String>,
pub select: Option<Vec<String>>,
pub nprobes: usize,
pub refine_factor: Option<u32>,
pub metric_type: Option<MetricType>,
@@ -54,6 +55,7 @@ impl Query {
metric_type: None,
use_index: false,
filter: None,
select: None,
}
}
@@ -72,6 +74,9 @@ impl Query {
)?;
scanner.nprobs(self.nprobes);
scanner.use_index(self.use_index);
self.select
.as_ref()
.map(|p| scanner.project(p.as_slice()));
self.filter.as_ref().map(|f| scanner.filter(f));
self.refine_factor.map(|rf| scanner.refine(rf));
self.metric_type.map(|mt| scanner.distance_metric(mt));
@@ -138,10 +143,23 @@ impl Query {
self
}
/// A filter statement to be applied to this query.
///
/// # Arguments
///
/// * `filter` - value A filter in the same format used by a sql WHERE clause.
pub fn filter(mut self, filter: Option<String>) -> Query {
self.filter = filter;
self
}
/// Return only the specified columns.
///
/// Only select the specified columns. If not specified, all columns will be returned.
pub fn select(mut self, columns: Option<Vec<String>>) -> Query {
self.select = columns;
self
}
}
#[cfg(test)]