diff --git a/docs/src/ann_indexes.md b/docs/src/ann_indexes.md index bd25c789..e6ee253d 100644 --- a/docs/src/ann_indexes.md +++ b/docs/src/ann_indexes.md @@ -129,4 +129,9 @@ You can select the columns returned by the query using a select clause. ``` === "Javascript" - Projections are not currently supported in the Javascript SDK. + ```javascript + const results = await table + .search(Array(1536).fill(1.2)) + .select(["id"]) + .execute() + ``` diff --git a/node/src/index.ts b/node/src/index.ts index ff064f20..cd7d0d40 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -233,7 +233,7 @@ export class Query { private _limit: number private _refineFactor?: number private _nprobes: number - private readonly _columns?: string[] + private _select?: string[] private _filter?: string private _metricType?: MetricType private readonly _embeddings?: EmbeddingFunction @@ -244,7 +244,7 @@ export class Query { this._limit = 10 this._nprobes = 20 this._refineFactor = undefined - this._columns = undefined + this._select = undefined this._filter = undefined this._metricType = undefined this._embeddings = embeddings @@ -286,6 +286,15 @@ export class Query { return this } + /** Return only the specified columns. + * + * @param value Only select the specified columns. If not specified, all columns will be returned. + */ + select (value: string[]): Query { + this._select = value + return this + } + /** * The MetricType used for this Query. * @param value The metric to the. @see MetricType for the different options diff --git a/node/src/test/test.ts b/node/src/test/test.ts index f1476fd8..5988be6a 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -72,6 +72,22 @@ describe('LanceDB client', function () { assert.equal(results.length, 1) assert.equal(results[0].id, 2) }) + + it('select only a subset of columns', async function () { + const uri = await createTestDB() + const con = await lancedb.connect(uri) + const table = await con.openTable('vectors') + const results = await table.search([0.1, 0.1]).select(['is_active']).execute() + assert.equal(results.length, 2) + // vector and score are always returned + assert.isDefined(results[0].vector) + assert.isDefined(results[0].score) + assert.isDefined(results[0].is_active) + + assert.isUndefined(results[0].id) + assert.isUndefined(results[0].name) + assert.isUndefined(results[0].price) + }) }) describe('when creating a new dataset', function () { @@ -181,11 +197,13 @@ describe('Query object', function () { .limit(1) .metricType(MetricType.Cosine) .refineFactor(100) + .select(['a', 'b']) .nprobes(20) as Record assert.equal(query._limit, 1) assert.equal(query._metricType, MetricType.Cosine) assert.equal(query._refineFactor, 100) assert.equal(query._nprobes, 20) + assert.deepEqual(query._select, ['a', 'b']) }) }) diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index c19e4850..be4369c2 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -129,6 +129,17 @@ fn table_search(mut cx: FunctionContext) -> JsResult { let limit = query_obj .get::(&mut cx, "_limit")? .value(&mut cx); + let select = query_obj + .get_opt::(&mut cx, "_select")? + .map(|arr| { + let js_array = arr.deref(); + let mut projection_vec: Vec = Vec::new(); + for i in 0..js_array.len(&mut cx) { + let entry: Handle = js_array.get(&mut cx, i).unwrap(); + projection_vec.push(entry.value(&mut cx)); + } + projection_vec + }); let filter = query_obj .get_opt::(&mut cx, "_filter")? .map(|s| s.value(&mut cx)); @@ -161,7 +172,8 @@ fn table_search(mut cx: FunctionContext) -> JsResult { .refine_factor(refine_factor) .nprobes(nprobes) .filter(filter) - .metric_type(metric_type); + .metric_type(metric_type) + .select(select); let record_batch_stream = builder.execute(); let results = record_batch_stream .and_then(|stream| stream.try_collect::>().map_err(Error::from)) diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index aac6134d..27361f6b 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -27,6 +27,7 @@ pub struct Query { pub query_vector: Float32Array, pub limit: usize, pub filter: Option, + pub select: Option>, pub nprobes: usize, pub refine_factor: Option, pub metric_type: Option, @@ -54,6 +55,7 @@ impl Query { metric_type: None, use_index: false, filter: None, + select: None, } } @@ -72,6 +74,9 @@ impl Query { )?; scanner.nprobs(self.nprobes); scanner.use_index(self.use_index); + self.select + .as_ref() + .map(|p| scanner.project(p.as_slice())); self.filter.as_ref().map(|f| scanner.filter(f)); self.refine_factor.map(|rf| scanner.refine(rf)); self.metric_type.map(|mt| scanner.distance_metric(mt)); @@ -138,10 +143,23 @@ impl Query { self } + /// A filter statement to be applied to this query. + /// + /// # Arguments + /// + /// * `filter` - value A filter in the same format used by a sql WHERE clause. pub fn filter(mut self, filter: Option) -> Query { self.filter = filter; self } + + /// Return only the specified columns. + /// + /// Only select the specified columns. If not specified, all columns will be returned. + pub fn select(mut self, columns: Option>) -> Query { + self.select = columns; + self + } } #[cfg(test)]