Compare commits

...

3 Commits

Author SHA1 Message Date
Will Jones
61de91f08d relax to warning 2023-10-13 12:44:23 -07:00
Will Jones
b12eb7a6bf be consistent about cloud behavior 2023-10-13 11:30:37 -07:00
Will Jones
f9ccefb032 feat: expose use_index in LanceDB OSS 2023-10-13 11:17:10 -07:00
9 changed files with 74 additions and 3 deletions

View File

@@ -27,6 +27,8 @@ Currently, we support the following metrics:
If you do not create a vector index, LanceDB would need to exhaustively scan the entire vector column (via `Flat Search`)
and compute the distance for *every* vector in order to find the closest matches. This is effectively a KNN search.
(Even if you create a vector index, you can force this behavior in LanceDB OSS by setting `use_index=False`,
as shown below.)
<!-- Setup Code
@@ -67,6 +69,7 @@ await db_setup.createTable('my_vectors', data)
df = tbl.search(np.random.random((1536))) \
.limit(10) \
.use_index(False) \
.to_list()
```
@@ -80,6 +83,7 @@ await db_setup.createTable('my_vectors', data)
const results_1 = await tbl.search(Array(1536).fill(1.2))
.limit(10)
.useIndex(false)
.execute()
```

View File

@@ -33,6 +33,7 @@ export class Query<T = number[]> {
private _filter?: string
private _metricType?: MetricType
protected readonly _embeddings?: EmbeddingFunction<T>
private _useIndex: boolean
constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
@@ -44,6 +45,7 @@ export class Query<T = number[]> {
this._filter = undefined
this._metricType = undefined
this._embeddings = embeddings
this._useIndex = true
}
/***
@@ -102,6 +104,20 @@ export class Query<T = number[]> {
return this
}
/**
* Whether or not to use the ANN index for this query. If set to false,
* the query will run exact KNN on a full scan of the table. The default is
* true (use the index).
*
* Setting this option to false is not currently supported by LanceDB Cloud.
*
* @param value Whether or not to use the index for this query.
*/
useIndex (value: boolean): Query<T> {
this._useIndex = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/

View File

@@ -93,6 +93,10 @@ export class RemoteQuery<T = number[]> extends Query<T> {
queryVector = query as number[]
}
if ((this as any)._useIndex === false) {
console.warn('LanceDB Cloud does not yet support useIndex=false, ignoring option.')
}
const data = await this._client.search(
this._name,
queryVector,

View File

@@ -117,6 +117,19 @@ describe('LanceDB client', function () {
assert.isUndefined(results[0].name)
assert.isUndefined(results[0].price)
})
it('can choose to do flat search', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.3]).useIndex(false).execute()
assert.equal(results.length, 2)
assert.equal(results[0].price, 10)
const vector = results[0].vector as Float32Array
assert.approximately(vector[0], 0.0, 0.2)
assert.approximately(vector[0], 0.1, 0.3)
})
})
describe('when creating a new dataset', function () {
@@ -385,11 +398,13 @@ describe('Query object', function () {
.metricType(MetricType.Cosine)
.refineFactor(100)
.select(['a', 'b'])
.useIndex(false)
.nprobes(20) as Record<string, any>
assert.equal(query._limit, 1)
assert.equal(query._metricType, MetricType.Cosine)
assert.equal(query._refineFactor, 100)
assert.equal(query._nprobes, 20)
assert.equal(query._useIndex, false)
assert.deepEqual(query._select, ['a', 'b'])
})
})

View File

@@ -59,6 +59,8 @@ class Query(pydantic.BaseModel):
# Refine factor.
refine_factor: Optional[int] = None
use_index: bool = True
class LanceQueryBuilder(ABC):
@classmethod
@@ -279,6 +281,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
self._use_index = True
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
@@ -340,6 +343,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor
return self
def use_index(self, use_index: bool) -> LanceVectorQueryBuilder:
"""
Choose whether to use an ANN index or not. Default is True.
Setting this to False is not yet supported on LanceDB Cloud.
Parameters
----------
use_index: bool
If True, use an ANN index if one exists, otherwise perform exact KNN
on a full table scan.
"""
self._use_index = use_index
return self
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
@@ -360,6 +378,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
nprobes=self._nprobes,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
use_index=self._use_index,
)
return self._table._execute_query(query)

View File

@@ -12,6 +12,7 @@
# limitations under the License.
import uuid
import warnings
from functools import cached_property
from typing import Optional, Union
@@ -101,6 +102,8 @@ class RemoteTable(Table):
def _execute_query(self, query: Query) -> pa.Table:
if query.prefilter:
raise NotImplementedError("Cloud support for prefiltering is coming soon")
if not query.use_index:
warnings.warn("LanceDB Cloud does not yet support use_index=False")
result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow()

View File

@@ -559,7 +559,8 @@ class LanceTable(Table):
The data to insert into the table.
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
"append", which inserts new rows, and "overwrite", which replaces
the entire content of the table with the new rows.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
@@ -887,6 +888,7 @@ class LanceTable(Table):
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
"use_index": query.use_index,
},
)

View File

@@ -46,6 +46,7 @@ class MockTable:
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
"use_index": query.use_index,
},
)
@@ -84,10 +85,12 @@ def test_cast(table):
assert r0.float_field == 1.0
def test_query_builder(table):
@pytest.mark.parametrize("use_index", [True, False])
def test_query_builder(table, use_index: bool):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1)
.use_index(use_index)
.select(["id"])
.to_list()
)

View File

@@ -47,6 +47,10 @@ impl JsQuery {
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap());
let use_index = query_obj
.get_opt::<JsBoolean, _, _>(&mut cx, "_useIndex")?
.map(|val| val.value(&mut cx))
.unwrap_or(true);
let is_electron = cx
.argument::<JsBoolean>(1)
@@ -69,7 +73,8 @@ impl JsQuery {
.nprobes(nprobes)
.filter(filter)
.metric_type(metric_type)
.select(select);
.select(select)
.use_index(use_index);
let record_batch_stream = builder.execute();
let results = record_batch_stream
.and_then(|stream| {