mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
added projection api for nodejs (#140)
This commit is contained in:
@@ -129,4 +129,9 @@ You can select the columns returned by the query using a select clause.
|
||||
```
|
||||
|
||||
=== "Javascript"
|
||||
Projections are not currently supported in the Javascript SDK.
|
||||
```javascript
|
||||
const results = await table
|
||||
.search(Array(1536).fill(1.2))
|
||||
.select(["id"])
|
||||
.execute()
|
||||
```
|
||||
|
||||
@@ -233,7 +233,7 @@ export class Query<T = number[]> {
|
||||
private _limit: number
|
||||
private _refineFactor?: number
|
||||
private _nprobes: number
|
||||
private readonly _columns?: string[]
|
||||
private _select?: string[]
|
||||
private _filter?: string
|
||||
private _metricType?: MetricType
|
||||
private readonly _embeddings?: EmbeddingFunction<T>
|
||||
@@ -244,7 +244,7 @@ export class Query<T = number[]> {
|
||||
this._limit = 10
|
||||
this._nprobes = 20
|
||||
this._refineFactor = undefined
|
||||
this._columns = undefined
|
||||
this._select = undefined
|
||||
this._filter = undefined
|
||||
this._metricType = undefined
|
||||
this._embeddings = embeddings
|
||||
@@ -286,6 +286,15 @@ export class Query<T = number[]> {
|
||||
return this
|
||||
}
|
||||
|
||||
/** Return only the specified columns.
|
||||
*
|
||||
* @param value Only select the specified columns. If not specified, all columns will be returned.
|
||||
*/
|
||||
select (value: string[]): Query<T> {
|
||||
this._select = value
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* The MetricType used for this Query.
|
||||
* @param value The metric to the. @see MetricType for the different options
|
||||
|
||||
@@ -72,6 +72,22 @@ describe('LanceDB client', function () {
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 2)
|
||||
})
|
||||
|
||||
it('select only a subset of columns', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.1]).select(['is_active']).execute()
|
||||
assert.equal(results.length, 2)
|
||||
// vector and score are always returned
|
||||
assert.isDefined(results[0].vector)
|
||||
assert.isDefined(results[0].score)
|
||||
assert.isDefined(results[0].is_active)
|
||||
|
||||
assert.isUndefined(results[0].id)
|
||||
assert.isUndefined(results[0].name)
|
||||
assert.isUndefined(results[0].price)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when creating a new dataset', function () {
|
||||
@@ -181,11 +197,13 @@ describe('Query object', function () {
|
||||
.limit(1)
|
||||
.metricType(MetricType.Cosine)
|
||||
.refineFactor(100)
|
||||
.select(['a', 'b'])
|
||||
.nprobes(20) as Record<string, any>
|
||||
assert.equal(query._limit, 1)
|
||||
assert.equal(query._metricType, MetricType.Cosine)
|
||||
assert.equal(query._refineFactor, 100)
|
||||
assert.equal(query._nprobes, 20)
|
||||
assert.deepEqual(query._select, ['a', 'b'])
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -129,6 +129,17 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let limit = query_obj
|
||||
.get::<JsNumber, _, _>(&mut cx, "_limit")?
|
||||
.value(&mut cx);
|
||||
let select = query_obj
|
||||
.get_opt::<JsArray, _, _>(&mut cx, "_select")?
|
||||
.map(|arr| {
|
||||
let js_array = arr.deref();
|
||||
let mut projection_vec: Vec<String> = Vec::new();
|
||||
for i in 0..js_array.len(&mut cx) {
|
||||
let entry: Handle<JsString> = js_array.get(&mut cx, i).unwrap();
|
||||
projection_vec.push(entry.value(&mut cx));
|
||||
}
|
||||
projection_vec
|
||||
});
|
||||
let filter = query_obj
|
||||
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
|
||||
.map(|s| s.value(&mut cx));
|
||||
@@ -161,7 +172,8 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
.refine_factor(refine_factor)
|
||||
.nprobes(nprobes)
|
||||
.filter(filter)
|
||||
.metric_type(metric_type);
|
||||
.metric_type(metric_type)
|
||||
.select(select);
|
||||
let record_batch_stream = builder.execute();
|
||||
let results = record_batch_stream
|
||||
.and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
|
||||
|
||||
@@ -27,6 +27,7 @@ pub struct Query {
|
||||
pub query_vector: Float32Array,
|
||||
pub limit: usize,
|
||||
pub filter: Option<String>,
|
||||
pub select: Option<Vec<String>>,
|
||||
pub nprobes: usize,
|
||||
pub refine_factor: Option<u32>,
|
||||
pub metric_type: Option<MetricType>,
|
||||
@@ -54,6 +55,7 @@ impl Query {
|
||||
metric_type: None,
|
||||
use_index: false,
|
||||
filter: None,
|
||||
select: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +74,9 @@ impl Query {
|
||||
)?;
|
||||
scanner.nprobs(self.nprobes);
|
||||
scanner.use_index(self.use_index);
|
||||
self.select
|
||||
.as_ref()
|
||||
.map(|p| scanner.project(p.as_slice()));
|
||||
self.filter.as_ref().map(|f| scanner.filter(f));
|
||||
self.refine_factor.map(|rf| scanner.refine(rf));
|
||||
self.metric_type.map(|mt| scanner.distance_metric(mt));
|
||||
@@ -138,10 +143,23 @@ impl Query {
|
||||
self
|
||||
}
|
||||
|
||||
/// A filter statement to be applied to this query.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `filter` - value A filter in the same format used by a sql WHERE clause.
|
||||
pub fn filter(mut self, filter: Option<String>) -> Query {
|
||||
self.filter = filter;
|
||||
self
|
||||
}
|
||||
|
||||
/// Return only the specified columns.
|
||||
///
|
||||
/// Only select the specified columns. If not specified, all columns will be returned.
|
||||
pub fn select(mut self, columns: Option<Vec<String>>) -> Query {
|
||||
self.select = columns;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user