mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-25 22:29:58 +00:00

Compare commits: lancedb-cl… → python-v0.… (4 commits)

Commits: 4ca0b15354, d8c217b47d, b724b1a01f, abd75e0ead

Cargo.toml (14 changes)
@@ -23,13 +23,13 @@ rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
 [workspace.dependencies]
 lance = { "version" = "=0.19.2", "features" = [
     "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
-lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
-lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
-lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
-lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
-lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
-lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
+], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
+lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
+lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
+lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
+lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
+lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
+lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
 # Note that this one does not include pyarrow
 arrow = { version = "52.2", optional = false }
 arrow-array = "52.2"
@@ -222,12 +222,10 @@ nav:
- 🦀 Rust: https://docs.rs/lancedb/latest/lancedb/
- ☁️ LanceDB Cloud:
- Overview: cloud/index.md
- Quickstart: cloud/quickstart.md
- Best Practices: cloud/best_practices.md
# - API reference:
# - 🐍 Python: python/saas-python.md
# - 👾 JavaScript: javascript/modules.md
# - REST API: cloud/rest.md
- API reference:
- 🐍 Python: python/saas-python.md
- 👾 JavaScript: javascript/modules.md
- REST API: cloud/rest.md

- Quick start: basic.md
- Concepts:
@@ -350,17 +348,10 @@ nav:
- Rust: https://docs.rs/lancedb/latest/lancedb/index.html
- LanceDB Cloud:
- Overview: cloud/index.md
- Quickstart: cloud/quickstart.md
- Work with data:
- Ingest data: cloud/ingest_data.md
- Update data: cloud/update_data.md
- Build an index: cloud/build_index.md
- Vector search: cloud/vector_search.md
- Full-text search: cloud/full_text_search.md
- Hybrid search: cloud/hybrid_search.md
- Metadata Filtering: cloud/metadata_filtering.md
- Best Practices: cloud/best_practices.md
# - REST API: cloud/rest.md
- API reference:
- 🐍 Python: python/saas-python.md
- 👾 JavaScript: javascript/modules.md
- REST API: cloud/rest.md

extra_css:
- styles/global.css
@@ -1,20 +0,0 @@
This section provides a set of recommended best practices to help you get the most out of LanceDB Cloud. By following these guidelines, you can optimize your usage of LanceDB Cloud, improve performance, and ensure a smooth experience.

### Should the db connection be created once and kept open?
Yes! It is recommended to establish a single db connection and maintain it throughout your interaction with the tables within.

LanceDB uses `requests.Session()` for connection pooling, which automatically manages connection reuse and cleanup. This approach avoids the overhead of repeatedly establishing HTTP connections, significantly improving efficiency.
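A minimal sketch of that pattern (the project slug, API key, region, and table name below are placeholders):

```python
import lancedb

# Connect once at startup; the underlying HTTP session is pooled and reused.
db = lancedb.connect(
    uri="db://your-project-slug",
    api_key="your-api-key",
    region="us-east-1",
)

def handle_request(query_vector):
    # Reuse the same connection for every request instead of reconnecting.
    table = db.open_table("myTable")
    return table.search(query_vector).limit(10).to_list()
```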

### Should a single `open_table` call be made and maintained for subsequent table operations?
`table = db.open_table()` should be called once and used for all subsequent table operations. If there are changes to the opened table, `table` always reflects the latest version of the data.
### Row id

### What are the vector indexing types supported by LanceDB Cloud?
We support `IVF_PQ` and `IVF_HNSW_SQ` as the `index_type` which is passed to `create_index`. LanceDB Cloud tunes the indexing parameters automatically to achieve the best tradeoff between query latency and query quality.
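As a sketch, assuming an open table with a `vector` column (exact keyword names may vary slightly across SDK versions):

```python
# IVF_PQ is the default; IVF_HNSW_SQ is the other supported index type.
table.create_index(
    metric="cosine",              # or "l2" / "dot"
    vector_column_name="vector",
    index_type="IVF_PQ",
)
```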

### Do I need to do anything when there is new data added to a table with an existing index?
No! LanceDB Cloud triggers an asynchronous background job to index the new vectors. This process will either merge the new vectors into the existing index or initiate a complete re-indexing if needed.

There is a flag `fast_search` in `table.search()` that allows you to control whether the unindexed rows should be searched or not.
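For example, a sketch of that flag (assuming an SDK version where `table.search()` accepts `fast_search`):

```python
# Skip rows that have not been indexed yet, trading completeness for latency.
fast = table.search([0.1, 0.2], fast_search=True).limit(10).to_list()

# Default behavior also scans unindexed rows, so results are complete but
# queries can be slower right after a large ingest.
full = table.search([0.1, 0.2]).limit(10).to_list()
```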
@@ -1,64 +0,0 @@
LanceDB Cloud supports **vector index**, **scalar index** and **full-text search index**. Compared to the open-source version, LanceDB Cloud focuses on **automation**:

- If there is a single vector column in the table, the vector column can be inferred from the schema and the index will be automatically created.
- Indexing parameters will be automatically tuned for the customer's data.

## Vector index
LanceDB has implemented state-of-the-art indexing algorithms (more about [IVF-PQ](https://lancedb.github.io/lancedb/concepts/index_ivfpq/) and [HNSW](https://lancedb.github.io/lancedb/concepts/index_hnsw/)). We currently support _L2_, _Cosine_, and _Dot_ as distance metrics. You can create multiple vector indices within a table.

=== "Python"

    ```python
    --8<-- "python/python/tests/docs/test_cloud.py:create_index"
    ```

=== "Typescript"

    ```typescript
    --8<-- "nodejs/examples/cloud.test.ts:imports"

    --8<-- "nodejs/examples/cloud.test.ts:connect_db_and_open_table"
    --8<-- "nodejs/examples/cloud.test.ts:create_index"
    ```
## Scalar index
LanceDB Cloud and LanceDB Enterprise support several types of scalar indices to accelerate search over scalar columns.

- *BTREE*: The most common type is BTREE. This index is inspired by the btree data structure, although only the first few layers of the btree are cached in memory. It performs well on columns with a large number of unique values and few rows per value.
- *BITMAP*: this index stores a bitmap for each unique value in the column. This index is useful for columns with a finite number of unique values and many rows per value.
    - For example, columns that represent "categories", "labels", or "tags"
- *LABEL_LIST*: a special index that is used to index list columns whose values have a finite set of possibilities.
    - For example, a column that contains lists of tags (e.g. ["tag1", "tag2", "tag3"]) can be indexed with a LABEL_LIST index.

You can create multiple scalar indices within a table.
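For example, a sketch of creating two of these (the column names are illustrative):

```python
# BITMAP suits low-cardinality columns such as categories, labels, or tags.
table.create_scalar_index("item", index_type="BITMAP")

# BTREE (the default) suits high-cardinality columns.
table.create_scalar_index("id")
```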
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:create_scalar_index"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:connect_db_and_open_table"
|
||||
--8<-- "nodejs/examples/cloud.test.ts:create_scalar_index"
|
||||
```
|
||||
|
||||
## Full-text search index
We provide performant full-text search on LanceDB Cloud, allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions.

!!! note ""

    `use_tantivy` is not available with `create_fts_index` on LanceDB Cloud, as we use our native implementation, which has better performance compared to Tantivy.
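A minimal sketch (the table and the `text` column are placeholders):

```python
# Build the BM25-backed index, then run a keyword query against it.
table.create_fts_index("text")
hits = table.search("puppy", query_type="fts").select(["text"]).limit(10).to_list()
```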
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:create_fts_index"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:create_fts_index"
|
||||
```
|
||||
@@ -1,14 +0,0 @@
Full-text search allows you to incorporate keyword-based search (based on BM25) in your retrieval solutions.

=== "Python"

    ```python
    --8<-- "python/python/tests/docs/test_cloud.py:full_text_search"
    ```

=== "Typescript"

    ```typescript
    --8<-- "nodejs/examples/cloud.test.ts:imports"

    --8<-- "nodejs/examples/cloud.test.ts:full_text_search"
    ```
@@ -1,10 +0,0 @@
We support hybrid search that combines semantic and full-text search via a reranking algorithm of your choice, to get the best of both worlds. LanceDB comes with [built-in rerankers](https://lancedb.github.io/lancedb/reranking/) and you can implement your own _customized reranker_ as well.
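As a sketch, assuming a table that already has both a vector index and an FTS index, plus an embedding function configured so the text query can be vectorized:

```python
from lancedb.rerankers import RRFReranker

results = (
    table.search("flower moon", query_type="hybrid")
    .rerank(RRFReranker())  # reciprocal rank fusion of vector and BM25 hits
    .limit(10)
    .to_pandas()
)
```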
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:hybrid_search"
|
||||
```
|
||||
@@ -1,31 +0,0 @@
## Insert data
The LanceDB Cloud SDK for data ingestion remains consistent with our open-source version, ensuring a seamless transition for existing OSS users.

!!! note "unsupported parameters in create_table"

    The following two parameters, `mode="overwrite"` and `exist_ok`, are expected to be added by Nov 2024.

=== "Python"

    ```python
    --8<-- "python/python/tests/docs/test_cloud.py:import-ingest-data"

    --8<-- "python/python/tests/docs/test_cloud.py:ingest_data"
    ```

=== "Typescript"

    ```typescript
    --8<-- "nodejs/examples/cloud.test.ts:imports"
    --8<-- "nodejs/examples/cloud.test.ts:ingest_data"
    ```

## Insert large datasets
It is recommended to use iterators to add large datasets in batches when creating your table in one go. Data will be automatically compacted for the best query performance.

!!! info "batch size"

    The batch size .

=== "Python"

    ```python
    --8<-- "python/python/tests/docs/test_cloud.py:ingest_data_in_batch"
    ```
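A sketch of the iterator pattern (assuming `db` is an open connection; the schema and values are illustrative):

```python
import pyarrow as pa

schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), 2)),
        pa.field("item", pa.utf8()),
    ]
)

def make_batches():
    # Yield RecordBatch chunks instead of materializing all rows at once.
    for _ in range(5):
        yield pa.RecordBatch.from_arrays(
            [
                pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
                pa.array(["foo", "bar"]),
            ],
            ["vector", "item"],
        )

db.create_table("table2", make_batches(), schema=schema)
```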
@@ -1,33 +0,0 @@
LanceDB Cloud supports rich filtering of query results based on metadata fields.

By default, _post-filtering_ is performed on the top-k results returned by the vector search. However, _pre-filtering_ is also an option that performs the filter prior to vector search. This can be useful to narrow the search space on a very large dataset and reduce query latency.
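For example, a sketch of the two modes (the predicate and query vector are illustrative):

```python
query = [0.4, 1.4]

# Post-filtering (default): search first, then filter the top-k hits.
post = table.search(query).where("price > 10.0").limit(5).to_pandas()

# Pre-filtering: apply the predicate before the vector search runs.
pre = table.search(query).where("price > 10.0", prefilter=True).limit(5).to_pandas()
```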
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:filtering"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:filtering"
|
||||
```
|
||||
We also support standard SQL expressions as predicates for filtering operations.
|
||||
It can be used during vector search, update, and deletion operations.
|
||||
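For instance, a sketch of one predicate reused across operations (the `"sale"` value is illustrative):

```python
predicate = "(item IN ('foo', 'bar')) AND (price > 10.0)"

table.search([100, 102]).where(predicate).to_arrow()    # filtered vector search
table.update(where=predicate, values={"item": "sale"})  # conditional update
table.delete(predicate)                                 # conditional delete
```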
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:sql_filtering"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:sql_filtering"
|
||||
```
|
||||
@@ -1,49 +0,0 @@
LanceDB Cloud efficiently manages updates across many tables. Currently, we offer _update_, _merge_insert_, and _delete_.

## update
=== "Python"

    ```python
    --8<-- "python/python/tests/docs/test_cloud.py:update_data"
    ```

=== "Typescript"

    ```typescript
    --8<-- "nodejs/examples/cloud.test.ts:imports"

    --8<-- "nodejs/examples/cloud.test.ts:connect_db_and_open_table"
    --8<-- "nodejs/examples/cloud.test.ts:update_data"
    ```

## merge insert
Merge insert can add rows, update rows, and remove rows all in a single transaction. It combines new data from a source table with existing data in a target table by using a join.
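A common use is an upsert keyed on a single column; a sketch (the column and values are illustrative):

```python
new_data = [{"vector": [1, 1], "item": "foo-updated", "price": 50.0}]

# Rows whose `item` matches are updated; unmatched rows are inserted.
table.merge_insert("item") \
    .when_matched_update_all() \
    .when_not_matched_insert_all() \
    .execute(new_data)
```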
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:merge_insert"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:connect_db_and_open_table"
|
||||
--8<-- "nodejs/examples/cloud.test.ts:merge_insert"
|
||||
```
|
||||
|
||||
## delete
|
||||
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:delete_data"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:connect_db_and_open_table"
|
||||
--8<-- "nodejs/examples/cloud.test.ts:delete_data"
|
||||
```
|
||||
@@ -1,21 +0,0 @@
Users can also tune the following parameters for better search quality.

- [nprobes](https://lancedb.github.io/lancedb/js/classes/VectorQuery/#nprobes): the number of partitions to search (probe).
- [refine factor](https://lancedb.github.io/lancedb/js/classes/VectorQuery/#refinefactor): a multiplier to control how many additional rows are taken during the refine step.

[Metadata filtering](filtering) combined with the vector search is also supported.
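For example, a sketch of tuning both knobs on one query (the values are illustrative):

```python
results = (
    table.search([0.4, 1.4])
    .nprobes(30)        # probe more IVF partitions for higher recall
    .refine_factor(5)   # re-rank 5x the requested rows with exact distances
    .where("price > 10.0")
    .limit(10)
    .to_pandas()
)
```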
=== "Python"
|
||||
|
||||
```python
|
||||
--8<-- "python/python/tests/docs/test_cloud.py:vector_search"
|
||||
```
|
||||
=== "Typescript"
|
||||
|
||||
```typescript
|
||||
--8<-- "nodejs/examples/cloud.test.ts:imports"
|
||||
|
||||
--8<-- "nodejs/examples/cloud.test.ts:vector_search"
|
||||
```
|
||||
@@ -22,8 +22,7 @@ excluded_globs = [
 "../src/embeddings/available_embedding_models/text_embedding_functions/*.md",
 "../src/embeddings/available_embedding_models/multimodal_embedding_functions/*.md",
 "../src/rag/*.md",
-"../src/rag/advanced_techniques/*.md",
-"../src/cloud/*.md"
+"../src/rag/advanced_techniques/*.md"
 ]
@@ -998,4 +998,18 @@ describe("column name options", () => {
    const results = await table.query().where("`camelCase` = 1").toArray();
    expect(results[0].camelCase).toBe(1);
  });

  test("can make multiple vector queries in one go", async () => {
    const results = await table
      .query()
      .nearestTo([0.1, 0.2])
      .addQueryVector([0.1, 0.2])
      .limit(1)
      .toArray();
    console.log(results);
    expect(results.length).toBe(2);
    results.sort((a, b) => a.query_index - b.query_index);
    expect(results[0].query_index).toBe(0);
    expect(results[1].query_index).toBe(1);
  });
});
@@ -1,230 +0,0 @@
// --8<-- [start:imports]
import * as lancedb from "@lancedb/lancedb";
// --8<-- [end:imports]
// used by the ingest example; imports must live at the top level in TypeScript
import { Schema, Field, Float32, FixedSizeList, Utf8 } from "apache-arrow";

// --8<-- [start:generate_data]
function genData(numRows: number, numVectorDim: number): any[] {
  const data = [];
  for (let i = 0; i < numRows; i++) {
    const vector = [];
    for (let j = 0; j < numVectorDim; j++) {
      vector.push(i + j * 0.1);
    }
    data.push({
      id: i,
      name: `name_${i}`,
      vector,
    });
  }
  return data;
}
// --8<-- [end:generate_data]

test("cloud quickstart", async () => {
  {
    // --8<-- [start:connect]
    const db = await lancedb.connect({
      uri: "db://your-project-slug",
      apiKey: "your-api-key",
      region: "your-cloud-region",
    });
    // --8<-- [end:connect]
    // --8<-- [start:create_table]
    const tableName = "myTable";
    const data = genData(5000, 1536);
    const table = await db.createTable(tableName, data);
    // --8<-- [end:create_table]
    // --8<-- [start:create_index_search]
    // create a vector index
    await table.createIndex({
      column: "vector",
      metric_type: lancedb.MetricType.Cosine,
      type: "ivf_pq",
    });
    // the query vector must match the table's dimension (1536 here)
    const result = await table.search(Array(1536).fill(0.01))
      .select(["vector", "name"])
      .limit(1)
      .execute();
    // --8<-- [end:create_index_search]
    // --8<-- [start:drop_table]
    await db.dropTable(tableName);
    // --8<-- [end:drop_table]
  }
});

test("ingest data", async () => {
  // --8<-- [start:ingest_data]
  const db = await lancedb.connect({
    uri: "db://your-project-slug",
    apiKey: "your-api-key",
    region: "us-east-1",
  });

  const data = [
    { vector: [3.1, 4.1], item: "foo", price: 10.0 },
    { vector: [5.9, 26.5], item: "bar", price: 20.0 },
    { vector: [10.2, 100.8], item: "baz", price: 30.0 },
    { vector: [1.4, 9.5], item: "fred", price: 40.0 },
  ];
  // create an empty table with schema
  const schema = new Schema([
    new Field(
      "vector",
      new FixedSizeList(2, new Field("float32", new Float32())),
    ),
    new Field("item", new Utf8()),
    new Field("price", new Float32()),
  ]);
  const tableName = "myTable";
  const table = await db.createTable({
    name: tableName,
    schema,
  });
  await table.add(data);
  // --8<-- [end:ingest_data]
});

test("update data", async () => {
  // --8<-- [start:connect_db_and_open_table]
  const db = await lancedb.connect({
    uri: "db://your-project-slug",
    apiKey: "your-api-key",
    region: "us-east-1",
  });
  const tableName = "myTable";
  const table = await db.openTable(tableName);
  // --8<-- [end:connect_db_and_open_table]
  // --8<-- [start:update_data]
  await table.update({
    where: "price < 20.0",
    values: { vector: [2, 2], item: "foo-updated" },
  });
  // --8<-- [end:update_data]
  // --8<-- [start:merge_insert]
  const newData = [{ vector: [1, 1], item: "foo-updated", price: 50.0 }];
  // upsert
  await table.mergeInsert("item", newData, {
    whenMatchedUpdateAll: true,
    whenNotMatchedInsertAll: true,
  });
  // --8<-- [end:merge_insert]
  // --8<-- [start:delete_data]
  // delete data
  const predicate = "price = 30.0";
  await table.delete(predicate);
  // --8<-- [end:delete_data]
});

test("create index", async () => {
  const db = await lancedb.connect({
    uri: "db://your-project-slug",
    apiKey: "your-api-key",
    region: "us-east-1",
  });

  const tableName = "myTable";
  const table = await db.openTable(tableName);
  // --8<-- [start:create_index]
  // the vector column only needs to be specified when there are
  // multiple vector columns or the column is not named "vector";
  // L2 is used as the default distance metric
  await table.createIndex({
    column: "vector",
    metric_type: lancedb.MetricType.Cosine,
  });
  // --8<-- [end:create_index]
  // --8<-- [start:create_scalar_index]
  await table.createScalarIndex("item");
  // --8<-- [end:create_scalar_index]
  // --8<-- [start:create_fts_index]
  const data = [
    { vector: [3.1, 4.1], text: "Frodo was a happy puppy" },
    { vector: [5.9, 26.5], text: "There are several kittens playing" },
  ];
  const ftsTable = await db.createTable("myFtsTable", data);
  await ftsTable.createIndex("text", {
    config: lancedb.Index.fts(),
  });
  // --8<-- [end:create_fts_index]
});

test("vector search", async () => {
  // --8<-- [start:vector_search]
  const db = await lancedb.connect({
    uri: "db://your-project-slug",
    apiKey: "your-api-key",
    region: "us-east-1",
  });

  const tableName = "myTable";
  const table = await db.openTable(tableName);
  const result = await table.search([0.4, 1.4])
    .where("price > 10.0")
    .prefilter(true)
    .select(["item", "vector"])
    .limit(2)
    .execute();
  // --8<-- [end:vector_search]
});

test("full-text search", async () => {
  // --8<-- [start:full_text_search]
  const db = await lancedb.connect({
    uri: "db://your-project-slug",
    apiKey: "your-api-key",
    region: "us-east-1",
  });

  const data = [
    { vector: [3.1, 4.1], text: "Frodo was a happy puppy" },
    { vector: [5.9, 26.5], text: "There are several kittens playing" },
  ];
  const tableName = "myTable";
  const table = await db.createTable(tableName, data);
  await table.createIndex("text", {
    config: lancedb.Index.fts(),
  });

  await table
    .search("puppy", "fts")
    .select(["text"])
    .limit(10)
    .toArray();
  // --8<-- [end:full_text_search]
});

test("metadata filtering", async () => {
  // --8<-- [start:filtering]
  const db = await lancedb.connect({
    uri: "db://your-project-slug",
    apiKey: "your-api-key",
    region: "us-east-1",
  });
  const tableName = "myTable";
  const table = await db.openTable(tableName);
  await table
    .search(Array(2).fill(0.1))
    .where("(item IN ('foo', 'bar')) AND (price > 10.0)")
    .postfilter()
    .toArray();
  // --8<-- [end:filtering]
  // --8<-- [start:sql_filtering]
  await table
    .search(Array(2).fill(0.1))
    .where("(item IN ('foo', 'bar')) AND (price > 10.0)")
    .postfilter()
    .toArray();
  // --8<-- [end:sql_filtering]
});
@@ -492,6 +492,42 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
    super.doCall((inner) => inner.bypassVectorIndex());
    return this;
  }

  /**
   * Add a query vector to the search
   *
   * This method can be called multiple times to add multiple query vectors
   * to the search. If multiple query vectors are added, then they will be searched
   * in parallel, and the results will be concatenated. A column called `query_index`
   * will be added to indicate the index of the query vector that produced the result.
   *
   * Performance-wise, this is equivalent to running multiple queries concurrently.
   */
  addQueryVector(vector: IntoVector): VectorQuery {
    if (vector instanceof Promise) {
      const res = (async () => {
        try {
          const v = await vector;
          const arr = Float32Array.from(v);
          // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
          const value: any = this.addQueryVector(arr);
          const inner = value.inner as
            | NativeVectorQuery
            | Promise<NativeVectorQuery>;
          return inner;
        } catch (e) {
          return Promise.reject(e);
        }
      })();
      return new VectorQuery(res);
    } else {
      super.doCall((inner) => {
        inner.addQueryVector(Float32Array.from(vector));
      });
      return this;
    }
  }
}

/** A builder for LanceDB queries. */
@@ -135,6 +135,16 @@ impl VectorQuery {
        self.inner = self.inner.clone().column(&column);
    }

    #[napi]
    pub fn add_query_vector(&mut self, vector: Float32Array) -> Result<()> {
        self.inner = self
            .inner
            .clone()
            .add_query_vector(vector.as_ref())
            .default_error()?;
        Ok(())
    }

    #[napi]
    pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
        let distance_type = parse_distance_type(distance_type)?;
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.16.0-beta.0"
+current_version = "0.16.0-beta.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.16.0-beta.0"
+version = "0.16.0-beta.1"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -4,7 +4,7 @@ name = "lancedb"
 dependencies = [
     "deprecation",
     "nest-asyncio~=1.0",
-    "pylance==0.19.2-beta.3",
+    "pylance==0.19.2",
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "packaging",
@@ -943,12 +943,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):

 class LanceEmptyQueryBuilder(LanceQueryBuilder):
     def to_arrow(self) -> pa.Table:
-        ds = self._table.to_lance()
-        return ds.to_table(
+        query = Query(
             columns=self._columns,
             filter=self._where,
-            limit=self._limit,
+            k=self._limit or 10,
+            with_row_id=self._with_row_id,
+            vector=[],
+            # not actually respected in remote query
+            offset=self._offset or 0,
         )
+        return self._table._execute_query(query).read_all()

     def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
         """Rerank the results using the specified reranker.
@@ -1491,7 +1495,7 @@ class AsyncQuery(AsyncQueryBase):
         return pa.array(vec)

     def nearest_to(
-        self, query_vector: Optional[Union[VEC, Tuple]] = None
+        self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None
     ) -> AsyncVectorQuery:
         """
         Find the nearest vectors to the given query vector.

@@ -1529,10 +1533,30 @@ class AsyncQuery(AsyncQueryBase):

         Vector searches always have a [limit][]. If `limit` has not been called then
         a default `limit` of 10 will be used.

+        Typically, a single vector is passed in as the query. However, you can also
+        pass in multiple vectors. This can be useful if you want to find the nearest
+        vectors to multiple query vectors. This is not expected to be faster than
+        making multiple queries concurrently; it is just a convenience method.
+        If multiple vectors are passed in then an additional column `query_index`
+        will be added to the results. This column will contain the index of the
+        query vector that the result is nearest to.
         """
-        return AsyncVectorQuery(
-            self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
-        )
+        if (
+            isinstance(query_vector, list)
+            and len(query_vector) > 0
+            and not isinstance(query_vector[0], (float, int))
+        ):
+            # multiple have been passed
+            query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
+            new_self = self._inner.nearest_to(query_vectors[0])
+            for v in query_vectors[1:]:
+                new_self.add_query_vector(v)
+            return AsyncVectorQuery(new_self)
+        else:
+            return AsyncVectorQuery(
+                self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
+            )

     def nearest_to_text(
         self, query: str, columns: Union[str, List[str]] = []
@@ -327,10 +327,6 @@ class RemoteTable(Table):
         - and also the "_distance" column which is the distance between the query
           vector and the returned vector.
         """
-        # empty query builder is not supported in saas, raise error
-        if query is None and query_type != "hybrid":
-            raise ValueError("Empty query is not supported")
-
         return LanceQueryBuilder.create(
             self,
             query,
@@ -1,293 +0,0 @@
# --8<-- [start:imports]
# --8<-- [start:import-lancedb]
# --8<-- [start:import-ingest-data]
import lancedb
import pyarrow as pa

# --8<-- [end:import-ingest-data]
import numpy as np

# --8<-- [end:import-lancedb]
# --8<-- [end:imports]


# --8<-- [start:gen_data]
def gen_data(total_rows: int, ndims: int = 1536):
    return pa.RecordBatch.from_pylist(
        [
            {
                "vector": np.random.rand(ndims).astype(np.float32).tolist(),
                "id": i,
                "name": "name_" + str(i),
            }
            for i in range(total_rows)
        ],
    ).to_pandas()


# --8<-- [end:gen_data]


def test_cloud_quickstart():
    # --8<-- [start:connect]
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="your-cloud-region"
    )
    # --8<-- [end:connect]
    # --8<-- [start:create_table]
    table_name = "myTable"
    table = db.create_table(table_name, data=gen_data(5000))
    # --8<-- [end:create_table]
    # --8<-- [start:create_index_search]
    # create a vector index
    table.create_index("cosine", vector_column_name="vector")
    # the query vector must match the table's vector dimension (1536 here)
    result = (
        table.search(np.random.rand(1536).tolist())
        .select(["vector", "name"])
        .limit(1)
        .to_pandas()
    )
    print(result)
    # --8<-- [end:create_index_search]
    # --8<-- [start:drop_table]
    db.drop_table(table_name)
    # --8<-- [end:drop_table]


def test_ingest_data():
    # --8<-- [start:ingest_data]
    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )

    # create an empty table with schema
    table_name = "myTable"
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        {"vector": [10.2, 100.8], "item": "baz", "price": 30.0},
        {"vector": [1.4, 9.5], "item": "fred", "price": 40.0},
    ]
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float32()),
        ]
    )
    table = db.create_table(table_name, schema=schema)
    table.add(data)
    # --8<-- [end:ingest_data]
    # --8<-- [start:ingest_data_in_batch]
    def make_batches():
        for i in range(5):
            yield pa.RecordBatch.from_arrays(
                [
                    pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
                    pa.array(["foo", "bar"]),
                    pa.array([10.0, 20.0]),
                ],
                ["vector", "item", "price"],
            )

    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float32()),
        ]
    )
    db.create_table("table2", make_batches(), schema=schema)
    # --8<-- [end:ingest_data_in_batch]


def test_updates():
    # --8<-- [start:update_data]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )
    table_name = "myTable"
    table = db.open_table(table_name)
    table.update(where="price < 20.0", values={"vector": [2, 2], "item": "foo-updated"})
    # --8<-- [end:update_data]
    # --8<-- [start:merge_insert]
    table = db.open_table(table_name)
    # upsert
    new_data = [{"vector": [1, 1], "item": "foo-updated", "price": 50.0}]
    table.merge_insert(
        "item"
    ).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
    # --8<-- [end:merge_insert]
    # --8<-- [start:delete_data]
    table_name = "myTable"
    table = db.open_table(table_name)
    # delete data
    predicate = "price = 30.0"
    table.delete(predicate)
    # --8<-- [end:delete_data]


def test_create_index():
    # --8<-- [start:create_index]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )

    table_name = "myTable"
    table = db.open_table(table_name)
    # the vector column only needs to be specified when there are
    # multiple vector columns or the column is not named "vector";
    # L2 is used as the default distance metric
    table.create_index(metric="cosine", vector_column_name="vector")
    # --8<-- [end:create_index]


def test_create_scalar_index():
    # --8<-- [start:create_scalar_index]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )

    table_name = "myTable"
    table = db.open_table(table_name)
    # default is BTREE
    table.create_scalar_index("item", index_type="BITMAP")
    # --8<-- [end:create_scalar_index]


def test_create_fts_index():
    # --8<-- [start:create_fts_index]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )

    table_name = "myTable"
    data = [
        {"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"},
        {"vector": [5.9, 26.5], "text": "There are several kittens playing"},
    ]
    table = db.create_table(table_name, data=data)
    table.create_fts_index("text")
    # --8<-- [end:create_fts_index]


def test_search():
    # --8<-- [start:vector_search]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )

    table_name = "myTable"
    table = db.open_table(table_name)
    query = [0.4, 1.4]
    result = (
        table.search(query)
        .where("price > 10.0", prefilter=True)
        .select(["item", "vector"])
        .limit(2)
        .to_pandas()
    )
    print(result)
    # --8<-- [end:vector_search]
    # --8<-- [start:full_text_search]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )
    table_name = "myTable"
    table = db.create_table(
        table_name,
        data=[
            {"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"},
            {"vector": [5.9, 26.5], "text": "There are several kittens playing"},
        ],
    )

    table.create_fts_index("text")
    table.search("puppy", query_type="fts").limit(10).select(["text"]).to_list()
    # --8<-- [end:full_text_search]
    # --8<-- [start:hybrid_search]
    import os

    import lancedb
    import openai
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.rerankers import RRFReranker

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )

    # configure the environment variable OPENAI_API_KEY
    if "OPENAI_API_KEY" not in os.environ:
        # OR set the key here as a variable
        openai.api_key = "sk-..."
    embeddings = get_registry().get("openai").create()

    class Documents(LanceModel):
        text: str = embeddings.SourceField()
        vector: Vector(embeddings.ndims()) = embeddings.VectorField()

    table_name = "myTable"
    table = db.create_table(table_name, schema=Documents)
    data = [
        {"text": "rebel spaceships striking from a hidden base"},
        {"text": "have won their first victory against the evil Galactic Empire"},
        {"text": "during the battle rebel spies managed to steal secret plans"},
        {"text": "to the Empire's ultimate weapon the Death Star"},
    ]
    table.add(data=data)
    table.create_index("L2", "vector")
    table.create_fts_index("text")

    # you can use table.list_indices() to make sure indices have been created
    reranker = RRFReranker()
    result = (
        table.search(
            "flower moon",
            query_type="hybrid",
            vector_column_name="vector",
            fts_columns="text",
        )
        .rerank(reranker)
        .limit(10)
        .to_pandas()
    )
    print(result)
    # --8<-- [end:hybrid_search]


def test_filtering():
    # --8<-- [start:filtering]
    import lancedb

    # connect to LanceDB
    db = lancedb.connect(
        uri="db://your-project-slug", api_key="your-api-key", region="us-east-1"
    )
    table_name = "myTable"
    table = db.open_table(table_name)
    result = (
        table.search([100, 102])
        .where("(item IN ('foo', 'bar')) AND (price > 10.0)")
        .to_arrow()
    )
    print(result)
    # --8<-- [end:filtering]
    # --8<-- [start:sql_filtering]
    table.search([100, 102]).where(
        "(item IN ('foo', 'bar')) AND (price > 10.0)"
    ).to_arrow()
    # --8<-- [end:sql_filtering]
@@ -197,6 +197,23 @@ def test_query_sync_minimal():
    assert data == expected


def test_query_sync_empty_query():
    def handler(body):
        assert body == {
            "k": 10,
            "filter": "true",
            "vector": [],
            "columns": ["id"],
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        data = table.search(None).where("true").select(["id"]).limit(10).to_list()
        expected = [{"id": 1}, {"id": 2}, {"id": 3}]
        assert data == expected


def test_query_sync_maximal():
    def handler(body):
        assert body == {

@@ -229,6 +246,17 @@ def test_query_sync_maximal():
    )


def test_query_sync_multiple_vectors():
    def handler(_body):
        return pa.table({"id": [1]})

    with query_test_table(handler) as table:
        results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
        assert len(results) == 2
        results.sort(key=lambda x: x["query_index"])
        assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]


def test_query_sync_fts():
    def handler(body):
        assert body == {
@@ -892,10 +892,15 @@ def test_empty_query(db):
     table = LanceTable.create(db, "my_table2", data=[{"id": i} for i in range(100)])
     df = table.search().select(["id"]).to_pandas()
     assert len(df) == 10
     # None is the same as default
     df = table.search().select(["id"]).limit(None).to_pandas()
-    assert len(df) == 100
+    assert len(df) == 10
+    # invalid limit is the same as None, which is the same as default
     df = table.search().select(["id"]).limit(-1).to_pandas()
-    assert len(df) == 100
+    assert len(df) == 10
+    # valid limit should work
+    df = table.search().select(["id"]).limit(42).to_pandas()
+    assert len(df) == 42


 def test_search_with_schema_inf_single_vector(db):
@@ -142,6 +142,13 @@ impl VectorQuery {
        self.inner = self.inner.clone().only_if(predicate);
    }

    pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
        let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
        let array = make_array(data);
        self.inner = self.inner.clone().add_query_vector(array).infer_error()?;
        Ok(())
    }

    pub fn select(&mut self, columns: Vec<(String, String)>) {
        self.inner = self.inner.clone().select(Select::dynamic(&columns));
    }
@@ -475,6 +475,7 @@ impl<T: HasQuery> QueryBase for T {

 /// Options for controlling the execution of a query
 #[non_exhaustive]
+#[derive(Debug, Clone)]
 pub struct QueryExecutionOptions {
     /// The maximum number of rows that will be contained in a single
     /// `RecordBatch` delivered by the query.

@@ -650,7 +651,7 @@ impl Query {
     pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
         let mut vector_query = self.into_vector();
         let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
-        vector_query.query_vector = Some(query_vector);
+        vector_query.query_vector.push(query_vector);
         Ok(vector_query)
     }
 }
@@ -701,7 +702,7 @@ pub struct VectorQuery {
     // the column based on the dataset's schema.
     pub(crate) column: Option<String>,
     // IVF PQ - ANN search.
-    pub(crate) query_vector: Option<Arc<dyn Array>>,
+    pub(crate) query_vector: Vec<Arc<dyn Array>>,
     pub(crate) nprobes: usize,
     pub(crate) refine_factor: Option<u32>,
     pub(crate) distance_type: Option<DistanceType>,

@@ -714,7 +715,7 @@ impl VectorQuery {
         Self {
             base,
             column: None,
-            query_vector: None,
+            query_vector: Vec::new(),
             nprobes: 20,
             refine_factor: None,
             distance_type: None,
@@ -734,6 +735,22 @@ impl VectorQuery {
         self
     }

+    /// Add another query vector to the search.
+    ///
+    /// Multiple searches will be dispatched as part of the query.
+    /// This is a convenience method for adding multiple query vectors
+    /// to the search. It is not expected to be faster than issuing
+    /// multiple queries concurrently.
+    ///
+    /// The output data will contain an additional column `query_index` which
+    /// will contain the index of the query vector that was used to generate the
+    /// result.
+    pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
+        let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
+        self.query_vector.push(query_vector);
+        Ok(self)
+    }

     /// Set the number of partitions to search (probe)
     ///
     /// This argument is only used when the vector column has an IVF PQ index.
@@ -854,6 +871,7 @@ mod tests {
     use std::sync::Arc;

     use super::*;
+    use arrow::{compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{
         cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
         RecordBatchReader,

@@ -883,7 +901,10 @@ mod tests {

     let vector = Float32Array::from_iter_values([0.1, 0.2]);
     let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
-    assert_eq!(*query.query_vector.unwrap().as_ref().as_primitive(), vector);
+    assert_eq!(
+        *query.query_vector.first().unwrap().as_ref().as_primitive(),
+        vector
+    );

     let new_vector = Float32Array::from_iter_values([9.8, 8.7]);

@@ -899,7 +920,7 @@ mod tests {
         .refine_factor(999);

     assert_eq!(
-        *query.query_vector.unwrap().as_ref().as_primitive(),
+        *query.query_vector.first().unwrap().as_ref().as_primitive(),
         new_vector
     );
     assert_eq!(query.base.limit.unwrap(), 100);
@@ -1197,4 +1218,34 @@ mod tests {
            assert!(batch.column_by_name("_rowid").is_some());
        }
    }

    #[tokio::test]
    async fn test_multiple_query_vectors() {
        let tmp_dir = tempdir().unwrap();
        let table = make_test_table(&tmp_dir).await;
        let query = table
            .query()
            .nearest_to(&[0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .add_query_vector(&[0.5, 0.6, 0.7, 0.8])
            .unwrap()
            .limit(1);

        let plan = query.explain_plan(true).await.unwrap();
        assert!(plan.contains("UnionExec"));

        let results = query
            .execute()
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let results = concat_batches(&results[0].schema(), &results).unwrap();
        assert_eq!(results.num_rows(), 2); // One result for each query vector.
        let query_index = results["query_index"].as_primitive::<Int32Type>();
        // We don't guarantee order.
        assert!(query_index.values().contains(&0));
        assert!(query_index.values().contains(&1));
    }
}
@@ -6,7 +6,7 @@ use crate::index::IndexStatistics;
 use crate::query::Select;
 use crate::table::AddDataMode;
 use crate::utils::{supported_btree_data_type, supported_vector_data_type};
-use crate::Error;
+use crate::{Error, Table};
 use arrow_array::RecordBatchReader;
 use arrow_ipc::reader::FileReader;
 use arrow_schema::{DataType, SchemaRef};
@@ -185,6 +185,71 @@ impl<S: HttpSend> RemoteTable<S> {

        Ok(())
    }

    fn apply_vector_query_params(
        mut body: serde_json::Value,
        query: &VectorQuery,
    ) -> Result<Vec<serde_json::Value>> {
        Self::apply_query_params(&mut body, &query.base)?;

        // Apply general parameters, before we dispatch based on number of query vectors.
        body["prefilter"] = query.base.prefilter.into();
        body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
        body["nprobes"] = query.nprobes.into();
        body["refine_factor"] = query.refine_factor.into();
        if let Some(vector_column) = query.column.as_ref() {
            body["vector_column"] = serde_json::Value::String(vector_column.clone());
        }
        if !query.use_index {
            body["bypass_vector_index"] = serde_json::Value::Bool(true);
        }

        fn vector_to_json(vector: &arrow_array::ArrayRef) -> Result<serde_json::Value> {
            match vector.data_type() {
                DataType::Float32 => {
                    let array = vector
                        .as_any()
                        .downcast_ref::<arrow_array::Float32Array>()
                        .unwrap();
                    Ok(serde_json::Value::Array(
                        array
                            .values()
                            .iter()
                            .map(|v| {
                                serde_json::Value::Number(
                                    serde_json::Number::from_f64(*v as f64).unwrap(),
                                )
                            })
                            .collect(),
                    ))
                }
                _ => Err(Error::InvalidInput {
                    message: "VectorQuery vector must be of type Float32".into(),
                }),
            }
        }

        match query.query_vector.len() {
            0 => {
                // Server takes empty vector, not null or undefined.
                body["vector"] = serde_json::Value::Array(Vec::new());
                Ok(vec![body])
            }
            1 => {
                body["vector"] = vector_to_json(&query.query_vector[0])?;
                Ok(vec![body])
            }
            _ => {
                let mut bodies = Vec::with_capacity(query.query_vector.len());
                for vector in &query.query_vector {
                    let mut body = body.clone();
                    body["vector"] = vector_to_json(vector)?;
                    bodies.push(body);
                }
                Ok(bodies)
            }
        }
    }
}

#[derive(Deserialize)]
@@ -306,51 +371,29 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let request = self.client.post(&format!("/v1/table/{}/query/", self.name));

-        let mut body = serde_json::Value::Object(Default::default());
-        Self::apply_query_params(&mut body, &query.base)?;
-
-        body["prefilter"] = query.base.prefilter.into();
-        body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
-        body["nprobes"] = query.nprobes.into();
-        body["refine_factor"] = query.refine_factor.into();
-
-        let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
-            match vector.data_type() {
-                DataType::Float32 => vector
-                    .as_any()
-                    .downcast_ref::<arrow_array::Float32Array>()
-                    .unwrap()
-                    .values()
-                    .iter()
-                    .cloned()
-                    .collect(),
-                _ => {
-                    return Err(Error::InvalidInput {
-                        message: "VectorQuery vector must be of type Float32".into(),
-                    })
-                }
-            }
-        } else {
-            // Server takes empty vector, not null or undefined.
-            Vec::new()
-        };
-        body["vector"] = serde_json::json!(vector);
-
-        if let Some(vector_column) = query.column.as_ref() {
-            body["vector_column"] = serde_json::Value::String(vector_column.clone());
-        }
-
-        if !query.use_index {
-            body["bypass_vector_index"] = serde_json::Value::Bool(true);
-        }
-
-        let request = request.json(&body);
-
-        let (request_id, response) = self.client.send(request, true).await?;
-
-        let stream = self.read_arrow_stream(&request_id, response).await?;
-
-        Ok(Arc::new(OneShotExec::new(stream)))
+        let body = serde_json::Value::Object(Default::default());
+        let bodies = Self::apply_vector_query_params(body, query)?;
+
+        let mut futures = Vec::with_capacity(bodies.len());
+        for body in bodies {
+            let request = request.try_clone().unwrap().json(&body);
+            let future = async move {
+                let (request_id, response) = self.client.send(request, true).await?;
+                self.read_arrow_stream(&request_id, response).await
+            };
+            futures.push(future);
+        }
+        let streams = futures::future::try_join_all(futures).await?;
+        if streams.len() == 1 {
+            let stream = streams.into_iter().next().unwrap();
+            Ok(Arc::new(OneShotExec::new(stream)))
+        } else {
+            let stream_execs = streams
+                .into_iter()
+                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
+                .collect();
+            Table::multi_vector_plan(stream_execs)
+        }
     }

     async fn plain_query(
@@ -655,6 +698,7 @@ mod tests {

     use super::*;

+    use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
     use futures::{future::BoxFuture, StreamExt, TryFutureExt};
@@ -1207,6 +1251,52 @@ mod tests {
            .unwrap();
    }

    #[tokio::test]
    async fn test_query_multiple_vectors() {
        let table = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/v1/table/my_table/query/");
            assert_eq!(
                request.headers().get("Content-Type").unwrap(),
                JSON_CONTENT_TYPE
            );
            let data = RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            )
            .unwrap();
            let response_body = write_ipc_file(&data);
            http::Response::builder()
                .status(200)
                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                .body(response_body)
                .unwrap()
        });

        let query = table
            .query()
            .nearest_to(vec![0.1, 0.2, 0.3])
            .unwrap()
            .add_query_vector(vec![0.4, 0.5, 0.6])
            .unwrap();
        let plan = query.explain_plan(true).await.unwrap();
        assert!(plan.contains("UnionExec"), "Plan: {}", plan);

        let results = query
            .execute()
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let results = concat_batches(&results[0].schema(), &results).unwrap();

        let query_index = results["query_index"].as_primitive::<Int32Type>();
        // We don't guarantee order.
        assert!(query_index.values().contains(&0));
        assert!(query_index.values().contains(&1));
    }

    #[tokio::test]
    async fn test_create_index() {
        let cases = [
@@ -24,6 +24,9 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{Field, Schema, SchemaRef};
 use async_trait::async_trait;
 use datafusion_physical_plan::display::DisplayableExecutionPlan;
+use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::repartition::RepartitionExec;
+use datafusion_physical_plan::union::UnionExec;
 use datafusion_physical_plan::ExecutionPlan;
 use futures::{StreamExt, TryStreamExt};
 use lance::dataset::builder::DatasetBuilder;
@@ -972,6 +975,57 @@ impl Table {
    ) -> Result<Option<IndexStatistics>> {
        self.inner.index_stats(index_name.as_ref()).await
    }

    // Take many execution plans and map them into a single plan that adds
    // a query_index column and unions them.
    pub(crate) fn multi_vector_plan(
        plans: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if plans.is_empty() {
            return Err(Error::InvalidInput {
                message: "No plans provided".to_string(),
            });
        }
        // Projection keeping all existing columns
        let first_plan = plans[0].clone();
        let project_all_columns = first_plan
            .schema()
            .fields()
            .iter()
            .enumerate()
            .map(|(i, field)| {
                let expr =
                    datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i);
                let expr = Arc::new(expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
                (expr, field.name().clone())
            })
            .collect::<Vec<_>>();

        let projected_plans = plans
            .into_iter()
            .enumerate()
            .map(|(plan_i, plan)| {
                let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32));
                let query_index_expr =
                    datafusion_physical_plan::expressions::Literal::new(query_index);
                let query_index_expr =
                    Arc::new(query_index_expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
                let mut projections = vec![(query_index_expr, "query_index".to_string())];
                projections.extend_from_slice(&project_all_columns);
                let projection = ProjectionExec::try_new(projections, plan).unwrap();
                Arc::new(projection) as Arc<dyn datafusion_physical_plan::ExecutionPlan>
            })
            .collect::<Vec<_>>();

        let unioned = Arc::new(UnionExec::new(projected_plans));
        // We require 1 partition in the final output
        let repartitioned = RepartitionExec::try_new(
            unioned,
            datafusion_physical_plan::Partitioning::RoundRobinBatch(1),
        )
        .unwrap();
        Ok(Arc::new(repartitioned))
    }
}

impl From<NativeTable> for Table {
@@ -1784,9 +1838,25 @@ impl TableInternal for NativeTable {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let ds_ref = self.dataset.get().await?;

+        if query.query_vector.len() > 1 {
+            // If there are multiple query vectors, create a plan for each of them and union them.
+            let query_vecs = query.query_vector.clone();
+            let plan_futures = query_vecs
+                .into_iter()
+                .map(|query_vector| {
+                    let mut sub_query = query.clone();
+                    sub_query.query_vector = vec![query_vector];
+                    let options_ref = options.clone();
+                    async move { self.create_plan(&sub_query, options_ref).await }
+                })
+                .collect::<Vec<_>>();
+            let plans = futures::future::try_join_all(plan_futures).await?;
+            return Table::multi_vector_plan(plans);
+        }
+
         let mut scanner: Scanner = ds_ref.scan();

-        if let Some(query_vector) = query.query_vector.as_ref() {
+        if let Some(query_vector) = query.query_vector.first() {
             // If there is a vector query, default to limit=10 if unspecified
             let column = if let Some(col) = query.column.as_ref() {
                 col.clone()

@@ -1828,18 +1898,11 @@ impl TableInternal for NativeTable {
                 query_vector,
                 query.base.limit.unwrap_or(DEFAULT_TOP_K),
             )?;
-            scanner.limit(
-                query.base.limit.map(|limit| limit as i64),
-                query.base.offset.map(|offset| offset as i64),
-            )?;
-        } else {
-            // If there is no vector query, it's ok to not have a limit
-            scanner.limit(
-                query.base.limit.map(|limit| limit as i64),
-                query.base.offset.map(|offset| offset as i64),
-            )?;
-        }
+        }
+
+        scanner.limit(
+            query.base.limit.map(|limit| limit as i64),
+            query.base.offset.map(|offset| offset as i64),
+        )?;
         scanner.nprobs(query.nprobes);
         scanner.use_index(query.use_index);
         scanner.prefilter(query.base.prefilter);