Compare commits

...

3 Commits

Author SHA1 Message Date
Will Jones
61de91f08d relax to warning 2023-10-13 12:44:23 -07:00
Will Jones
b12eb7a6bf be consistent about cloud behavior 2023-10-13 11:30:37 -07:00
Will Jones
f9ccefb032 feat: expose use_index in LanceDB OSS 2023-10-13 11:17:10 -07:00
9 changed files with 74 additions and 3 deletions

View File

@@ -27,6 +27,8 @@ Currently, we support the following metrics:
If you do not create a vector index, LanceDB would need to exhaustively scan the entire vector column (via `Flat Search`) If you do not create a vector index, LanceDB would need to exhaustively scan the entire vector column (via `Flat Search`)
and compute the distance for *every* vector in order to find the closest matches. This is effectively a KNN search. and compute the distance for *every* vector in order to find the closest matches. This is effectively a KNN search.
(Even if you create a vector index, you can force this behavior in LanceDB OSS by setting `use_index=False` in Python
or `useIndex(false)` in JavaScript, as shown below.)
<!-- Setup Code <!-- Setup Code
@@ -67,6 +69,7 @@ await db_setup.createTable('my_vectors', data)
df = tbl.search(np.random.random((1536))) \ df = tbl.search(np.random.random((1536))) \
.limit(10) \ .limit(10) \
.use_index(False) \
.to_list() .to_list()
``` ```
@@ -80,6 +83,7 @@ await db_setup.createTable('my_vectors', data)
const results_1 = await tbl.search(Array(1536).fill(1.2)) const results_1 = await tbl.search(Array(1536).fill(1.2))
.limit(10) .limit(10)
.useIndex(false)
.execute() .execute()
``` ```

View File

@@ -33,6 +33,7 @@ export class Query<T = number[]> {
private _filter?: string private _filter?: string
private _metricType?: MetricType private _metricType?: MetricType
protected readonly _embeddings?: EmbeddingFunction<T> protected readonly _embeddings?: EmbeddingFunction<T>
private _useIndex: boolean
constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) { constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl this._tbl = tbl
@@ -44,6 +45,7 @@ export class Query<T = number[]> {
this._filter = undefined this._filter = undefined
this._metricType = undefined this._metricType = undefined
this._embeddings = embeddings this._embeddings = embeddings
this._useIndex = true
} }
/*** /***
@@ -102,6 +104,20 @@ export class Query<T = number[]> {
return this return this
} }
/**
* Choose whether this query uses the ANN index. When set to false, the
* query runs exact KNN on a full scan of the table instead. The default,
* assigned in the constructor, is true (use the index).
*
* Setting this option to false is not currently supported by LanceDB Cloud.
*
* @param value Whether or not to use the index for this query.
* @returns This query, so that further builder calls can be chained.
*/
useIndex (value: boolean): Query<T> {
this._useIndex = value
return this
}
/** /**
* Execute the query and return the results as an Array of Objects * Execute the query and return the results as an Array of Objects
*/ */

View File

@@ -93,6 +93,10 @@ export class RemoteQuery<T = number[]> extends Query<T> {
queryVector = query as number[] queryVector = query as number[]
} }
// Warn when the caller asked to skip the index; execution then continues
// with the search call below.
// NOTE(review): the `any` cast is presumably needed because `_useIndex` is
// declared `private` on Query; declaring it `protected` (as `_embeddings`
// is) would let this subclass read it without the cast - confirm.
if ((this as any)._useIndex === false) {
console.warn('LanceDB Cloud does not yet support useIndex=false, ignoring option.')
}
const data = await this._client.search( const data = await this._client.search(
this._name, this._name,
queryVector, queryVector,

View File

@@ -117,6 +117,19 @@ describe('LanceDB client', function () {
assert.isUndefined(results[0].name) assert.isUndefined(results[0].name)
assert.isUndefined(results[0].price) assert.isUndefined(results[0].price)
}) })
// Runs a search with useIndex(false) and checks the number of results,
// the price of the first result and its vector.
it('can choose to do flat search', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.3]).useIndex(false).execute()
assert.equal(results.length, 2)
assert.equal(results[0].price, 10)
const vector = results[0].vector as Float32Array
assert.approximately(vector[0], 0.0, 0.2)
// NOTE(review): this re-checks vector[0] with a range ([-0.2, 0.4]) that
// already contains the range asserted above ([-0.2, 0.2]), so it can never
// fail on its own. vector[1] may have been intended - confirm against the
// createTestDB fixture.
assert.approximately(vector[0], 0.1, 0.3)
})
}) })
describe('when creating a new dataset', function () { describe('when creating a new dataset', function () {
@@ -385,11 +398,13 @@ describe('Query object', function () {
.metricType(MetricType.Cosine) .metricType(MetricType.Cosine)
.refineFactor(100) .refineFactor(100)
.select(['a', 'b']) .select(['a', 'b'])
.useIndex(false)
.nprobes(20) as Record<string, any> .nprobes(20) as Record<string, any>
assert.equal(query._limit, 1) assert.equal(query._limit, 1)
assert.equal(query._metricType, MetricType.Cosine) assert.equal(query._metricType, MetricType.Cosine)
assert.equal(query._refineFactor, 100) assert.equal(query._refineFactor, 100)
assert.equal(query._nprobes, 20) assert.equal(query._nprobes, 20)
assert.equal(query._useIndex, false)
assert.deepEqual(query._select, ['a', 'b']) assert.deepEqual(query._select, ['a', 'b'])
}) })
}) })

View File

@@ -59,6 +59,8 @@ class Query(pydantic.BaseModel):
# Refine factor. # Refine factor.
refine_factor: Optional[int] = None refine_factor: Optional[int] = None
use_index: bool = True
class LanceQueryBuilder(ABC): class LanceQueryBuilder(ABC):
@classmethod @classmethod
@@ -279,6 +281,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = None self._refine_factor = None
self._vector_column = vector_column self._vector_column = vector_column
self._prefilter = False self._prefilter = False
self._use_index = True
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
@@ -340,6 +343,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor self._refine_factor = refine_factor
return self return self
def use_index(self, use_index: bool) -> LanceVectorQueryBuilder:
"""
Choose whether to use an ANN index or not. Default is True.
Setting this to False is not yet supported on LanceDB Cloud.

Parameters
----------
use_index: bool
If True (the default), use an ANN index when one exists.
If False, perform exact KNN on a full table scan instead.

Returns
-------
LanceVectorQueryBuilder
This builder (self), so that calls can be chained.
"""
self._use_index = use_index
return self
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
""" """
Execute the query and return the results as an Execute the query and return the results as an
@@ -360,6 +378,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
nprobes=self._nprobes, nprobes=self._nprobes,
refine_factor=self._refine_factor, refine_factor=self._refine_factor,
vector_column=self._vector_column, vector_column=self._vector_column,
use_index=self._use_index,
) )
return self._table._execute_query(query) return self._table._execute_query(query)

View File

@@ -12,6 +12,7 @@
# limitations under the License. # limitations under the License.
import uuid import uuid
import warnings
from functools import cached_property from functools import cached_property
from typing import Optional, Union from typing import Optional, Union
@@ -101,6 +102,8 @@ class RemoteTable(Table):
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
if query.prefilter: if query.prefilter:
raise NotImplementedError("Cloud support for prefiltering is coming soon") raise NotImplementedError("Cloud support for prefiltering is coming soon")
if not query.use_index:
warnings.warn("LanceDB Cloud does not yet support use_index=False")
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow() return self._conn._loop.run_until_complete(result).to_arrow()

View File

@@ -559,7 +559,8 @@ class LanceTable(Table):
The data to insert into the table. The data to insert into the table.
mode: str mode: str
The mode to use when writing the data. Valid values are The mode to use when writing the data. Valid values are
"append" and "overwrite". "append", which inserts new rows, and "overwrite", which replaces
the entire content of the table with the new rows.
on_bad_vectors: str, default "error" on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs. What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill". One of "error", "drop", "fill".
@@ -887,6 +888,7 @@ class LanceTable(Table):
"metric": query.metric, "metric": query.metric,
"nprobes": query.nprobes, "nprobes": query.nprobes,
"refine_factor": query.refine_factor, "refine_factor": query.refine_factor,
"use_index": query.use_index,
}, },
) )

View File

@@ -46,6 +46,7 @@ class MockTable:
"metric": query.metric, "metric": query.metric,
"nprobes": query.nprobes, "nprobes": query.nprobes,
"refine_factor": query.refine_factor, "refine_factor": query.refine_factor,
"use_index": query.use_index,
}, },
) )
@@ -84,10 +85,12 @@ def test_cast(table):
assert r0.float_field == 1.0 assert r0.float_field == 1.0
def test_query_builder(table): @pytest.mark.parametrize("use_index", [True, False])
def test_query_builder(table, use_index: bool):
rs = ( rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector") LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1) .limit(1)
.use_index(use_index)
.select(["id"]) .select(["id"])
.to_list() .to_list()
) )

View File

@@ -47,6 +47,10 @@ impl JsQuery {
.get_opt::<JsString, _, _>(&mut cx, "_metricType")? .get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx)) .map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap()); .map(|s| MetricType::try_from(s.as_str()).unwrap());
let use_index = query_obj
.get_opt::<JsBoolean, _, _>(&mut cx, "_useIndex")?
.map(|val| val.value(&mut cx))
.unwrap_or(true);
let is_electron = cx let is_electron = cx
.argument::<JsBoolean>(1) .argument::<JsBoolean>(1)
@@ -69,7 +73,8 @@ impl JsQuery {
.nprobes(nprobes) .nprobes(nprobes)
.filter(filter) .filter(filter)
.metric_type(metric_type) .metric_type(metric_type)
.select(select); .select(select)
.use_index(use_index);
let record_batch_stream = builder.execute(); let record_batch_stream = builder.execute();
let results = record_batch_stream let results = record_batch_stream
.and_then(|stream| { .and_then(|stream| {