mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 13:59:58 +00:00
Compare commits
3 Commits
python-v0.
...
wjones127/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61de91f08d | ||
|
|
b12eb7a6bf | ||
|
|
f9ccefb032 |
@@ -27,6 +27,8 @@ Currently, we support the following metrics:
|
|||||||
|
|
||||||
If you do not create a vector index, LanceDB would need to exhaustively scan the entire vector column (via `Flat Search`)
|
If you do not create a vector index, LanceDB would need to exhaustively scan the entire vector column (via `Flat Search`)
|
||||||
and compute the distance for *every* vector in order to find the closest matches. This is effectively a KNN search.
|
and compute the distance for *every* vector in order to find the closest matches. This is effectively a KNN search.
|
||||||
|
(Even if you create a vector index, you can force this behavior in LanceDB OSS by calling `.use_index(False)` in Python or `.useIndex(false)` in JavaScript,
|
||||||
|
as shown below.)
|
||||||
|
|
||||||
|
|
||||||
<!-- Setup Code
|
<!-- Setup Code
|
||||||
@@ -67,6 +69,7 @@ await db_setup.createTable('my_vectors', data)
|
|||||||
|
|
||||||
df = tbl.search(np.random.random((1536))) \
|
df = tbl.search(np.random.random((1536))) \
|
||||||
.limit(10) \
|
.limit(10) \
|
||||||
|
.use_index(False) \
|
||||||
.to_list()
|
.to_list()
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -80,6 +83,7 @@ await db_setup.createTable('my_vectors', data)
|
|||||||
|
|
||||||
const results_1 = await tbl.search(Array(1536).fill(1.2))
|
const results_1 = await tbl.search(Array(1536).fill(1.2))
|
||||||
.limit(10)
|
.limit(10)
|
||||||
|
.useIndex(false)
|
||||||
.execute()
|
.execute()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ export class Query<T = number[]> {
|
|||||||
private _filter?: string
|
private _filter?: string
|
||||||
private _metricType?: MetricType
|
private _metricType?: MetricType
|
||||||
protected readonly _embeddings?: EmbeddingFunction<T>
|
protected readonly _embeddings?: EmbeddingFunction<T>
|
||||||
|
private _useIndex: boolean
|
||||||
|
|
||||||
constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
||||||
this._tbl = tbl
|
this._tbl = tbl
|
||||||
@@ -44,6 +45,7 @@ export class Query<T = number[]> {
|
|||||||
this._filter = undefined
|
this._filter = undefined
|
||||||
this._metricType = undefined
|
this._metricType = undefined
|
||||||
this._embeddings = embeddings
|
this._embeddings = embeddings
|
||||||
|
this._useIndex = true
|
||||||
}
|
}
|
||||||
|
|
||||||
/***
|
/***
|
||||||
@@ -102,6 +104,20 @@ export class Query<T = number[]> {
|
|||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Enable or disable use of the ANN index for this query.
 *
 * Passing `false` makes the query run exact KNN over a full scan of the
 * table instead of consulting the index. Queries use the index by default.
 *
 * Note: LanceDB Cloud does not currently support passing `false` here.
 *
 * @param value `true` to use the index, `false` to bypass it.
 * @returns This query, so that further builder calls can be chained.
 */
useIndex (value: boolean): Query<T> {
  this._useIndex = value
  return this
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Execute the query and return the results as an Array of Objects
|
* Execute the query and return the results as an Array of Objects
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -93,6 +93,10 @@ export class RemoteQuery<T = number[]> extends Query<T> {
|
|||||||
queryVector = query as number[]
|
queryVector = query as number[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((this as any)._useIndex === false) {
|
||||||
|
console.warn('LanceDB Cloud does not yet support useIndex=false, ignoring option.')
|
||||||
|
}
|
||||||
|
|
||||||
const data = await this._client.search(
|
const data = await this._client.search(
|
||||||
this._name,
|
this._name,
|
||||||
queryVector,
|
queryVector,
|
||||||
|
|||||||
@@ -117,6 +117,19 @@ describe('LanceDB client', function () {
|
|||||||
assert.isUndefined(results[0].name)
|
assert.isUndefined(results[0].name)
|
||||||
assert.isUndefined(results[0].price)
|
assert.isUndefined(results[0].price)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Runs a search with the ANN index disabled (`useIndex(false)`) against the
// 'vectors' table of the database produced by createTestDB, and checks the
// number of rows returned plus the price and vector of the first row.
it('can choose to do flat search', async function () {
  const uri = await createTestDB()
  const con = await lancedb.connect(uri)
  const table = await con.openTable('vectors')
  const results = await table.search([0.1, 0.3]).useIndex(false).execute()

  assert.equal(results.length, 2)
  assert.equal(results[0].price, 10)
  const vector = results[0].vector as Float32Array
  assert.approximately(vector[0], 0.0, 0.2)
  // Previously this re-checked vector[0] with a looser tolerance, which the
  // assertion above already implies; the second component is what was
  // left unverified.
  // NOTE(review): assumes the fixture's first row has vector[1] within
  // 0.3 of 0.1 - confirm against createTestDB.
  assert.approximately(vector[1], 0.1, 0.3)
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('when creating a new dataset', function () {
|
describe('when creating a new dataset', function () {
|
||||||
@@ -385,11 +398,13 @@ describe('Query object', function () {
|
|||||||
.metricType(MetricType.Cosine)
|
.metricType(MetricType.Cosine)
|
||||||
.refineFactor(100)
|
.refineFactor(100)
|
||||||
.select(['a', 'b'])
|
.select(['a', 'b'])
|
||||||
|
.useIndex(false)
|
||||||
.nprobes(20) as Record<string, any>
|
.nprobes(20) as Record<string, any>
|
||||||
assert.equal(query._limit, 1)
|
assert.equal(query._limit, 1)
|
||||||
assert.equal(query._metricType, MetricType.Cosine)
|
assert.equal(query._metricType, MetricType.Cosine)
|
||||||
assert.equal(query._refineFactor, 100)
|
assert.equal(query._refineFactor, 100)
|
||||||
assert.equal(query._nprobes, 20)
|
assert.equal(query._nprobes, 20)
|
||||||
|
assert.equal(query._useIndex, false)
|
||||||
assert.deepEqual(query._select, ['a', 'b'])
|
assert.deepEqual(query._select, ['a', 'b'])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ class Query(pydantic.BaseModel):
|
|||||||
# Refine factor.
|
# Refine factor.
|
||||||
refine_factor: Optional[int] = None
|
refine_factor: Optional[int] = None
|
||||||
|
|
||||||
|
use_index: bool = True
|
||||||
|
|
||||||
|
|
||||||
class LanceQueryBuilder(ABC):
|
class LanceQueryBuilder(ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -279,6 +281,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
self._refine_factor = None
|
self._refine_factor = None
|
||||||
self._vector_column = vector_column
|
self._vector_column = vector_column
|
||||||
self._prefilter = False
|
self._prefilter = False
|
||||||
|
self._use_index = True
|
||||||
|
|
||||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||||
"""Set the distance metric to use.
|
"""Set the distance metric to use.
|
||||||
@@ -340,6 +343,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
self._refine_factor = refine_factor
|
self._refine_factor = refine_factor
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def use_index(self, use_index: bool) -> LanceVectorQueryBuilder:
    """
    Set whether this query should use an ANN index. Default is True.

    Passing False is not yet supported on LanceDB Cloud.

    Parameters
    ----------
    use_index: bool
        True: use an ANN index if one exists.
        False: perform exact KNN on a full table scan.

    Returns
    -------
    LanceVectorQueryBuilder
        This builder, so that calls can be chained.
    """
    self._use_index = use_index
    return self
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
"""
|
"""
|
||||||
Execute the query and return the results as an
|
Execute the query and return the results as an
|
||||||
@@ -360,6 +378,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
nprobes=self._nprobes,
|
nprobes=self._nprobes,
|
||||||
refine_factor=self._refine_factor,
|
refine_factor=self._refine_factor,
|
||||||
vector_column=self._vector_column,
|
vector_column=self._vector_column,
|
||||||
|
use_index=self._use_index,
|
||||||
)
|
)
|
||||||
return self._table._execute_query(query)
|
return self._table._execute_query(query)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
import warnings
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@@ -101,6 +102,8 @@ class RemoteTable(Table):
|
|||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
if query.prefilter:
|
if query.prefilter:
|
||||||
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
||||||
|
if not query.use_index:
|
||||||
|
warnings.warn("LanceDB Cloud does not yet support use_index=False")
|
||||||
result = self._conn._client.query(self._name, query)
|
result = self._conn._client.query(self._name, query)
|
||||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||||
|
|
||||||
|
|||||||
@@ -559,7 +559,8 @@ class LanceTable(Table):
|
|||||||
The data to insert into the table.
|
The data to insert into the table.
|
||||||
mode: str
|
mode: str
|
||||||
The mode to use when writing the data. Valid values are
|
The mode to use when writing the data. Valid values are
|
||||||
"append" and "overwrite".
|
"append", which inserts new rows, and "overwrite", which replaces
|
||||||
|
the entire content of the table with the new rows.
|
||||||
on_bad_vectors: str, default "error"
|
on_bad_vectors: str, default "error"
|
||||||
What to do if any of the vectors are not the same size or contains NaNs.
|
What to do if any of the vectors are not the same size or contains NaNs.
|
||||||
One of "error", "drop", "fill".
|
One of "error", "drop", "fill".
|
||||||
@@ -887,6 +888,7 @@ class LanceTable(Table):
|
|||||||
"metric": query.metric,
|
"metric": query.metric,
|
||||||
"nprobes": query.nprobes,
|
"nprobes": query.nprobes,
|
||||||
"refine_factor": query.refine_factor,
|
"refine_factor": query.refine_factor,
|
||||||
|
"use_index": query.use_index,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class MockTable:
|
|||||||
"metric": query.metric,
|
"metric": query.metric,
|
||||||
"nprobes": query.nprobes,
|
"nprobes": query.nprobes,
|
||||||
"refine_factor": query.refine_factor,
|
"refine_factor": query.refine_factor,
|
||||||
|
"use_index": query.use_index,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,10 +85,12 @@ def test_cast(table):
|
|||||||
assert r0.float_field == 1.0
|
assert r0.float_field == 1.0
|
||||||
|
|
||||||
|
|
||||||
def test_query_builder(table):
|
@pytest.mark.parametrize("use_index", [True, False])
|
||||||
|
def test_query_builder(table, use_index: bool):
|
||||||
rs = (
|
rs = (
|
||||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
.limit(1)
|
.limit(1)
|
||||||
|
.use_index(use_index)
|
||||||
.select(["id"])
|
.select(["id"])
|
||||||
.to_list()
|
.to_list()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -47,6 +47,10 @@ impl JsQuery {
|
|||||||
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
|
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
|
||||||
.map(|s| s.value(&mut cx))
|
.map(|s| s.value(&mut cx))
|
||||||
.map(|s| MetricType::try_from(s.as_str()).unwrap());
|
.map(|s| MetricType::try_from(s.as_str()).unwrap());
|
||||||
|
let use_index = query_obj
|
||||||
|
.get_opt::<JsBoolean, _, _>(&mut cx, "_useIndex")?
|
||||||
|
.map(|val| val.value(&mut cx))
|
||||||
|
.unwrap_or(true);
|
||||||
|
|
||||||
let is_electron = cx
|
let is_electron = cx
|
||||||
.argument::<JsBoolean>(1)
|
.argument::<JsBoolean>(1)
|
||||||
@@ -69,7 +73,8 @@ impl JsQuery {
|
|||||||
.nprobes(nprobes)
|
.nprobes(nprobes)
|
||||||
.filter(filter)
|
.filter(filter)
|
||||||
.metric_type(metric_type)
|
.metric_type(metric_type)
|
||||||
.select(select);
|
.select(select)
|
||||||
|
.use_index(use_index);
|
||||||
let record_batch_stream = builder.execute();
|
let record_batch_stream = builder.execute();
|
||||||
let results = record_batch_stream
|
let results = record_batch_stream
|
||||||
.and_then(|stream| {
|
.and_then(|stream| {
|
||||||
|
|||||||
Reference in New Issue
Block a user