Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits (16 commits): python-v0. ... python-v0.
| Author | SHA1 | Date |
|---|---|---|
|  | c0dd98c798 |  |
|  | ee73a3bcb8 |  |
|  | c07989ac29 |  |
|  | 8f7ef26f5f |  |
|  | e14f079fe2 |  |
|  | 7d790bd9e7 |  |
|  | dbdd0a7b4b |  |
|  | befb79c5f9 |  |
|  | 0a387a5429 |  |
|  | 5a173e1d54 |  |
|  | 51bdbcad98 |  |
|  | 0c7809c7a0 |  |
|  | 2de226220b |  |
|  | bd5b6f21e2 |  |
|  | 6331807b95 |  |
|  | 83cb3f01a4 |  |
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]
 
 [workspace.dependencies]
-lance = { "version" = "=0.10.4", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.10.4" }
-lance-linalg = { "version" = "=0.10.4" }
-lance-testing = { "version" = "=0.10.4" }
+lance = { "version" = "=0.10.5", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.10.5" }
+lance-linalg = { "version" = "=0.10.5" }
+lance-testing = { "version" = "=0.10.5" }
 # Note that this one does not include pyarrow
 arrow = { version = "50.0", optional = false }
 arrow-array = "50.0"
@@ -27,7 +27,6 @@ theme:
   - content.tabs.link
   - content.action.edit
   - toc.follow
-  # - toc.integrate
   - navigation.top
   - navigation.tabs
   - navigation.tabs.sticky
@@ -64,7 +63,7 @@ plugins:
   add_image: True # Automatically add meta image
   add_keywords: True # Add page keywords in the header tag
   add_share_buttons: True # Add social share buttons
   add_authors: False # Display page authors
   add_desc: False
   add_dates: False
 
@@ -140,12 +139,14 @@ nav:
 - Serverless Website Chatbot: examples/serverless_website_chatbot.md
 - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
 - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
+- 🦀 Rust:
+  - Overview: examples/examples_rust.md
 - 🔧 CLI & Config: cli_config.md
 - 💭 FAQs: faq.md
 - ⚙️ API reference:
 - 🐍 Python: python/python.md
 - 👾 JavaScript: javascript/modules.md
-- 🦀 Rust: https://docs.rs/vectordb/latest/vectordb/
+- 🦀 Rust: https://docs.rs/lancedb/latest/lancedb/
 - ☁️ LanceDB Cloud:
 - Overview: cloud/index.md
 - API reference:
@@ -189,21 +190,21 @@ nav:
 - Pydantic: python/pydantic.md
 - Voxel51: integrations/voxel51.md
 - PromptTools: integrations/prompttools.md
-- Python examples:
+- Examples:
 - examples/index.md
 - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
 - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
 - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
 - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
 - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
-- Javascript examples:
-- Overview: examples/examples_js.md
-- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
+- YouTube Transcript Search (JS): examples/youtube_transcript_bot_with_nodejs.md
 - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
 - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
 - API reference:
+- Overview: api_reference.md
 - Python: python/python.md
 - Javascript: javascript/modules.md
+- Rust: https://docs.rs/lancedb/latest/lancedb/index.html
 - LanceDB Cloud:
 - Overview: cloud/index.md
 - API reference:
@@ -19,39 +19,61 @@ Lance supports `IVF_PQ` index type by default.
 
 === "Python"
 
     Creating indexes is done via the [create_index](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.create_index) method.
 
     ```python
     import lancedb
     import numpy as np
     uri = "data/sample-lancedb"
     db = lancedb.connect(uri)
 
     # Create 10,000 sample vectors
     data = [{"vector": row, "item": f"item {i}"}
             for i, row in enumerate(np.random.random((10_000, 1536)).astype('float32'))]
 
     # Add the vectors to a table
     tbl = db.create_table("my_vectors", data=data)
 
     # Create and train the index - you need to have enough data in the table for an effective training step
     tbl.create_index(num_partitions=256, num_sub_vectors=96)
     ```
 
 === "Typescript"
 
     ```typescript
     --8<--- "docs/src/ann_indexes.ts:import"
 
     --8<-- "docs/src/ann_indexes.ts:ingest"
     ```
 
-- **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`".
+=== "Rust"
+
+    ```rust
+    --8<-- "rust/lancedb/examples/ivf_pq.rs:create_index"
+    ```
+
+    IVF_PQ index parameters are more fully defined in the [crate docs](https://docs.rs/lancedb/latest/lancedb/index/vector/struct.IvfPqIndexBuilder.html).
+
+The following IVF_PQ parameters can be specified:
+
+- **distance_type**: The distance metric to use. By default it uses euclidean distance "`L2`".
   We also support "cosine" and "dot" distance as well.
-- **num_partitions** (default: 256): The number of partitions of the index.
-- **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ).
-  For a D-dimensional vector, it will be divided into `M` sub-vectors of `D/M` dimensions, each of which is represented by
-  a single PQ code.
+- **num_partitions**: The number of partitions in the index. The default is the square root
+  of the number of rows.
+
+  !!! note
+
+      In the synchronous python SDK and node's `vectordb` the default is 256. This default has
+      changed in the asynchronous python SDK and node's `lancedb`.
+
+- **num_sub_vectors**: The number of sub-vectors (M) that will be created during Product Quantization (PQ).
+  For a D-dimensional vector, it will be divided into `M` subvectors with dimension `D/M`, each of which is replaced by
+  a single PQ code. The default is the dimension of the vector divided by 16.
+
+  !!! note
+
+      In the synchronous python SDK and node's `vectordb` the default is currently 96. This default has
+      changed in the asynchronous python SDK and node's `lancedb`.
 
 <figure markdown>
 
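The new defaults described in this hunk are straightforward to reproduce explicitly in the synchronous Python SDK. A minimal sketch, assuming the `my_vectors` table from the example above (10,000 rows of 1536-dimensional vectors); the exact numbers are illustrative:

```python
import math

import lancedb

db = lancedb.connect("data/sample-lancedb")
tbl = db.open_table("my_vectors")  # 10,000 rows of 1536-dim vectors

num_rows, dim = 10_000, 1536
# The new defaults described above: sqrt(rows) partitions, and dim / 16
# sub-vectors (dim must be divisible by num_sub_vectors for PQ to work).
num_partitions = round(math.sqrt(num_rows))  # 100
num_sub_vectors = dim // 16                  # 96

# The synchronous SDK still defaults to 256 / 96, so pass the values
# explicitly to match the new behavior.
tbl.create_index(num_partitions=num_partitions, num_sub_vectors=num_sub_vectors)
```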
@@ -114,25 +136,33 @@ There are a couple of parameters that can be used to fine-tune the search:
 
 === "Python"
 
     ```python
     tbl.search(np.random.random((1536))) \
         .limit(2) \
         .nprobes(20) \
         .refine_factor(10) \
         .to_pandas()
     ```
 
     ```text
     vector item _distance
     0 [0.44949695, 0.8444449, 0.06281311, 0.23338133... item 1141 103.575333
     1 [0.48587373, 0.269207, 0.15095535, 0.65531915,... item 3953 108.393867
     ```
 
 === "Typescript"
 
     ```typescript
     --8<-- "docs/src/ann_indexes.ts:search1"
     ```
 
+=== "Rust"
+
+    ```rust
+    --8<-- "rust/lancedb/examples/ivf_pq.rs:search1"
+    ```
+
+    Vector search options are more fully defined in the [crate docs](https://docs.rs/lancedb/latest/lancedb/query/struct.Query.html#method.nearest_to).
+
 The search will return the data requested in addition to the distance of each item.
 
@@ -181,7 +211,7 @@ You can select the columns returned by the query using a select clause.
 ### Why do I need to manually create an index?
 
 Currently, LanceDB does _not_ automatically create the ANN index.
 LanceDB is well-optimized for kNN (exhaustive search) via a disk-based index. For many use-cases,
 datasets of the order of ~100K vectors don't require index creation. If you can live with up to
 100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall.
 
docs/src/api_reference.md (new file, +7 lines)

@@ -0,0 +1,7 @@
+# API Reference
+
+The API references for the LanceDB client SDKs are available at the following locations:
+
+- [Python](python/python.md)
+- [JavaScript](javascript/modules.md)
+- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)
@@ -3,7 +3,7 @@
 !!! info "LanceDB can be run in a number of ways:"
 
     * Embedded within an existing backend (like your Django, Flask, Node.js or FastAPI application)
-    * Connected to directly from a client application like a Jupyter notebook for analytical workloads
+    * Directly from a client application like a Jupyter notebook for analytical workloads
     * Deployed as a remote serverless database
 
 ![](assets/lancedb_embedded_explanation.png)
@@ -24,13 +24,11 @@
 
 === "Rust"
 
-    !!! warning "Rust SDK is experimental, might introduce breaking changes in the near future"
-
     ```shell
-    cargo add vectordb
+    cargo add lancedb
    ```
 
-!!! info "To use the vectordb crate, you first need to install protobuf."
+!!! info "To use the lancedb crate, you first need to install protobuf."
 
 === "macOS"
 
@@ -44,7 +42,7 @@
     sudo apt install -y protobuf-compiler libssl-dev
     ```
 
-!!! info "Please also make sure you're using the same version of Arrow as in the [vectordb crate](https://github.com/lancedb/lancedb/blob/main/Cargo.toml)"
+!!! info "Please also make sure you're using the same version of Arrow as in the [lancedb crate](https://github.com/lancedb/lancedb/blob/main/Cargo.toml)"
 
 ## Connect to a database
 
@@ -81,10 +79,11 @@ If you need a reminder of the uri, you can call `db.uri()`.
 
 ## Create a table
 
-### Directly insert data to a new table
+### Create a table from initial data
 
 If you have data to insert into the table at creation time, you can simultaneously create a
-table and insert the data to it.
+table and insert the data into it. The schema of the data will be used as the schema of the
+table.
 
 === "Python"
 
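As a concrete illustration of the renamed section: in Python the schema really is inferred from the initial records. A minimal sketch (the table name and fields are made up for this example):

```python
import lancedb

db = lancedb.connect("data/sample-lancedb")

# The schema (vector: list<float>, item: string) is inferred from the data.
tbl = db.create_table(
    "my_table",
    data=[
        {"vector": [3.1, 4.1], "item": "foo"},
        {"vector": [5.9, 26.5], "item": "bar"},
    ],
)
```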
@@ -120,21 +119,27 @@ table and insert the data to it.
 === "Rust"
 
     ```rust
-    use arrow_schema::{DataType, Schema, Field};
-    use arrow_array::{RecordBatch, RecordBatchIterator};
 
     --8<-- "rust/lancedb/examples/simple.rs:create_table"
     ```
 
-If the table already exists, LanceDB will raise an error by default.
+If the table already exists, LanceDB will raise an error by default. See
+[the mode option](https://docs.rs/lancedb/latest/lancedb/connection/struct.CreateTableBuilder.html#method.mode)
+for details on how to overwrite (or open) existing tables instead.
 
-!!! info "Under the hood, LanceDB converts the input data into an Apache Arrow table and persists it to disk using the [Lance format](https://www.github.com/lancedb/lance)."
+!!! Providing table records in Rust
+
+    The Rust SDK currently expects data to be provided as an Arrow
+    [RecordBatchReader](https://docs.rs/arrow-array/latest/arrow_array/trait.RecordBatchReader.html).
+    Support for additional formats (such as serde or polars) is on the roadmap.
+
+!!! info "Under the hood, LanceDB reads in the Apache Arrow data and persists it to disk using the [Lance format](https://www.github.com/lancedb/lance)."
 
 ### Create an empty table
 
 Sometimes you may not have the data to insert into the table at creation time.
 In this case, you can create an empty table and specify the schema, so that you can add
-data to the table at a later time (such that it conforms to the schema).
+data to the table at a later time (as long as it conforms to the schema). This is
+similar to a `CREATE TABLE` statement in SQL.
 
 === "Python"
 
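A minimal Python sketch of the two behaviors this hunk documents: overwriting an existing table instead of raising, and creating an empty table from an explicit schema. Names are illustrative; the `mode` parameter is the synchronous SDK's counterpart of the Rust `mode` option linked above:

```python
import lancedb
import pyarrow as pa

db = lancedb.connect("data/sample-lancedb")

schema = pa.schema([
    pa.field("vector", pa.list_(pa.float32(), 2)),
    pa.field("item", pa.utf8()),
])

# Like `CREATE TABLE` in SQL: an empty table with an explicit schema.
# mode="overwrite" replaces an existing table instead of raising an error.
tbl = db.create_table("empty_table", schema=schema, mode="overwrite")
```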
@@ -175,7 +180,7 @@ Once created, you can open a table as follows:
 === "Rust"
 
     ```rust
-    --8<-- "rust/lancedb/examples/simple.rs:open_with_existing_file"
+    --8<-- "rust/lancedb/examples/simple.rs:open_existing_tbl"
     ```
 
 If you forget the name of your table, you can always get a listing of all table names:
@@ -254,6 +259,14 @@ Once you've embedded the query, you can find its nearest neighbors as follows:
     --8<-- "rust/lancedb/examples/simple.rs:search"
     ```
 
+!!! Query vectors in Rust
+
+    Rust does not yet support automatic execution of embedding functions. You will need to
+    calculate embeddings yourself. Support for this is on the roadmap and can be tracked at
+    https://github.com/lancedb/lancedb/issues/994
+
+    Query vectors can be provided as Arrow arrays or a Vec/slice of Rust floats.
+    Support for additional formats (e.g. `polars::series::Series`) is on the roadmap.
+
 By default, LanceDB runs a brute-force scan over the dataset to find the K nearest neighbours (KNN).
 For tables with more than 50K vectors, creating an ANN index is recommended to speed up search performance.
 LanceDB allows you to create an ANN index on a table as follows:
@@ -277,7 +290,7 @@ LanceDB allows you to create an ANN index on a table as follows:
     ```
 
 !!! note "Why do I need to create an index manually?"
-    LanceDB does not automatically create the ANN index, for two reasons. The first is that it's optimized
+    LanceDB does not automatically create the ANN index for two reasons. The first is that it's optimized
     for really fast retrievals via a disk-based index, and the second is that data and query workloads can
     be very diverse, so there's no one-size-fits-all index configuration. LanceDB provides many parameters
     to fine-tune index size, query latency and accuracy. See the section on
@@ -308,8 +321,9 @@ This can delete any number of rows that match the filter.
     ```
 
 The deletion predicate is a SQL expression that supports the same expressions
-as the `where()` clause on a search. They can be as simple or complex as needed.
-To see what expressions are supported, see the [SQL filters](sql.md) section.
+as the `where()` clause (`only_if()` in Rust) on a search. They can be as
+simple or complex as needed. To see what expressions are supported, see the
+[SQL filters](sql.md) section.
 
 === "Python"
 
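To make the predicate equivalence concrete, here is a minimal Python sketch; the table and column names are illustrative:

```python
import lancedb

db = lancedb.connect("data/sample-lancedb")
tbl = db.open_table("my_table")

# The same SQL expression language works for filtering a search...
hits = tbl.search([5.9, 26.5]).where("item = 'foo'").limit(1).to_pandas()

# ...and for deleting rows that match the predicate.
tbl.delete("item = 'foo'")
```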
@@ -319,6 +333,10 @@ To see what expressions are supported, see the [SQL filters](sql.md) section.
 
     Read more: [vectordb.Table.delete](javascript/interfaces/Table.md#delete)
 
+=== "Rust"
+
+    Read more: [lancedb::Table::delete](https://docs.rs/lancedb/latest/lancedb/table/struct.Table.html#method.delete)
+
 ## Drop a table
 
 Use the `drop_table()` method on the database to remove a table.
docs/src/examples/examples_rust.md (new file, +3 lines)

@@ -0,0 +1,3 @@
+# Examples: Rust
+
+Our Rust SDK is now stable. Examples are coming soon.
@@ -2,10 +2,11 @@
 
 ## Recipes and example code
 
-LanceDB provides language APIs, allowing you to embed a database in your language of choice. We currently provide Python and Javascript APIs, with the Rust API and examples actively being worked on and available soon.
+LanceDB provides language APIs, allowing you to embed a database in your language of choice.
 
 * 🐍 [Python](examples_python.md) examples
-* 👾 [JavaScript](exampled_js.md) examples
+* 👾 [JavaScript](examples_js.md) examples
+* 🦀 Rust examples (coming soon)
 
 ## Applications powered by LanceDB
 
@@ -16,7 +16,7 @@ As we mention in our talk titled “[Lance, a modern columnar data format](https
 
 ### Why build in Rust? 🦀
 
-We believe that the Rust ecosystem has attained mainstream maturity and that Rust will form the underpinnings of large parts of the data and ML landscape in a few years. Performance, latency and reliability are paramount to a vector DB, and building in Rust allows us to iterate and release updates more rapidly due to Rust’s safety guarantees. Both Lance (the data format) and LanceDB (the database) are written entirely in Rust. We also provide Python and JavaScript client libraries to interact with the database. Our Rust API is a little rough around the edges right now, but is fast becoming on par with the Python and JS APIs.
+We believe that the Rust ecosystem has attained mainstream maturity and that Rust will form the underpinnings of large parts of the data and ML landscape in a few years. Performance, latency and reliability are paramount to a vector DB, and building in Rust allows us to iterate and release updates more rapidly due to Rust’s safety guarantees. Both Lance (the data format) and LanceDB (the database) are written entirely in Rust. We also provide Python, JavaScript, and Rust client libraries to interact with the database.
 
 ### What is the difference between LanceDB OSS and LanceDB Cloud?
 
@@ -44,7 +44,7 @@ For large-scale (>1M) or higher dimension vectors, it is beneficial to create an
 
 ### Does LanceDB support full-text search?
 
-Yes, LanceDB supports full-text search (FTS) via [Tantivy](https://github.com/quickwit-oss/tantivy). Our current FTS integration is Python-only, and our goal is to push it down to the Rust level in future versions to enable much more powerful search capabilities available to our Python, JavaScript and Rust clients.
+Yes, LanceDB supports full-text search (FTS) via [Tantivy](https://github.com/quickwit-oss/tantivy). Our current FTS integration is Python-only, and our goal is to push it down to the Rust level in future versions to enable much more powerful search capabilities available to our Python, JavaScript and Rust clients. Follow along in the [Github issue](https://github.com/lancedb/lance/issues/1195).
 
 ### How can I speed up data inserts?
 
@@ -1,6 +1,6 @@
 # Full-text search
 
-LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for JavaScript users as well.
+LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for Rust and JavaScript users as well. Follow along at [this Github issue](https://github.com/lancedb/lance/issues/1195).
 
 A hybrid search solution combining vector and full-text search is also on the way.
 
@@ -75,9 +75,39 @@ applied on top of the full text search results. This can be invoked via the fami
 table.search("puppy").limit(10).where("meta='foo'").to_list()
 ```
 
+## Sorting
+
+You can pre-sort the documents by specifying `ordering_field_names` when
+creating the full-text search index. Once pre-sorted, you can then specify
+`ordering_field_name` while searching to return results sorted by the given
+field. For example,
+
+```
+table.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"])
+
+(table.search("terms", ordering_field_name="sort_by_field")
+ .limit(20)
+ .to_list())
+```
+
+!!! note
+    If you wish to specify an ordering field at query time, you must also
+    have specified it during indexing time. Otherwise at query time, an
+    error will be raised that looks like `ValueError: The field does not exist: xxx`.
+
+!!! note
+    The fields to sort on must be of unsigned integer type, or else you will see
+    an error during indexing that looks like
+    `TypeError: argument 'value': 'float' object cannot be interpreted as an integer`.
+
+!!! note
+    You can specify multiple fields for ordering at indexing time,
+    but at query time only one ordering field is supported.
+
+
 ## Phrase queries vs. terms queries
 
 For full-text search you can specify either a **phrase** query like `"the old man and the sea"`,
 or a **terms** search query like `"(Old AND Man) AND Sea"`. For more details on the terms
 query syntax, see Tantivy's [query parser rules](https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html).
 
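The unsigned-integer requirement in the notes above is easiest to satisfy by declaring the ordering column up front. A hedged end-to-end sketch (field and table names are illustrative):

```python
import lancedb
import pyarrow as pa

db = lancedb.connect("data/sample-lancedb")

schema = pa.schema([
    pa.field("text_field", pa.utf8()),
    pa.field("sort_by_field", pa.uint64()),  # must be an unsigned integer type
])
tbl = db.create_table("docs", schema=schema, mode="overwrite")
tbl.add([{"text_field": "the old man and the sea", "sort_by_field": 1}])

# Declare the ordering field at index time so it can be used at query time.
tbl.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"])
results = (tbl.search("sea", ordering_field_name="sort_by_field")
           .limit(20)
           .to_list())
```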
@@ -112,7 +142,7 @@ double quotes replaced by single quotes.
 
 ## Configurations
 
 By default, LanceDB configures a 1GB heap size limit for creating the index. You can
 reduce this if running on a smaller node, or increase this for faster performance while
 indexing a larger corpus.
 
@@ -128,7 +158,6 @@ table.create_fts_index(["text1", "text2"], writer_heap_size=heap, replace=True)
    If you add data after FTS index creation, it won't be reflected
    in search results until you do a full reindex.
 
 2. We currently only support local filesystem paths for the FTS index.
    This is a tantivy limitation. We've implemented an object store plugin
    but there's no way in tantivy-py to specify to use it.
 
@@ -168,151 +168,151 @@ This guide will show how to create tables, insert data into them, and update the
     --8<-- "docs/src/basic_legacy.ts:create_f16_table"
     ```
 
 ### From Pydantic Models
 
 When you create an empty table without data, you must specify the table schema.
 LanceDB supports creating tables by specifying a PyArrow schema or a specialized
 Pydantic model called `LanceModel`.
 
 For example, the following Content model specifies a table with 5 columns:
 `movie_id`, `vector`, `genres`, `title`, and `imdb_id`. When you create a table, you can
 pass the class as the value of the `schema` parameter to `create_table`.
 The `vector` column is a `Vector` type, which is a specialized Pydantic type that
 can be configured with the vector dimensions. It is also important to note that
 LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
 (which itself derives from `pydantic.BaseModel`).
 
 ```python
 import lancedb
 import pyarrow as pa
 
 from lancedb.pydantic import Vector, LanceModel
 
 class Content(LanceModel):
     movie_id: int
     vector: Vector(128)
     genres: str
     title: str
     imdb_id: int
 
     @property
     def imdb_url(self) -> str:
         return f"https://www.imdb.com/title/tt{self.imdb_id}"
 
 db = lancedb.connect("~/.lancedb")
 table_name = "movielens_small"
 table = db.create_table(table_name, schema=Content)
 ```
 
 #### Nested schemas
 
 Sometimes your data model may contain nested objects.
 For example, you may want to store the document string
 and the document source name as a nested Document object:
 
 ```python
 from pydantic import BaseModel
 
 class Document(BaseModel):
     content: str
     source: str
 ```
 
 This can be used as the type of a LanceDB table column:
 
 ```python
 class NestedSchema(LanceModel):
     id: str
     vector: Vector(1536)
     document: Document
 
 tbl = db.create_table("nested_table", schema=NestedSchema, mode="overwrite")
 ```
 
 This creates a struct column called "document" that has two subfields
 called "content" and "source":
 
 ```
 In [28]: tbl.schema
 Out[28]:
 id: string not null
 vector: fixed_size_list<item: float>[1536] not null
     child 0, item: float
 document: struct<content: string not null, source: string not null> not null
     child 0, content: string not null
     child 1, source: string not null
 ```
 
 #### Validators
 
 Note that neither Pydantic nor PyArrow automatically validates that input data
 is of the correct timezone, but this is easy to add as a custom field validator:
 
 ```python
 from datetime import datetime
 from zoneinfo import ZoneInfo
 
 from lancedb.pydantic import LanceModel
 from pydantic import Field, field_validator, ValidationError, ValidationInfo
 
 tzname = "America/New_York"
 tz = ZoneInfo(tzname)
 
 class TestModel(LanceModel):
     dt_with_tz: datetime = Field(json_schema_extra={"tz": tzname})
 
     @field_validator('dt_with_tz')
     @classmethod
     def tz_must_match(cls, dt: datetime) -> datetime:
         assert dt.tzinfo == tz
         return dt
 
 ok = TestModel(dt_with_tz=datetime.now(tz))
 
 try:
     TestModel(dt_with_tz=datetime.now(ZoneInfo("Asia/Shanghai")))
     assert 0 == 1, "this should raise ValidationError"
 except ValidationError:
     print("A ValidationError was raised.")
     pass
 ```
 
 When you run this code it should print "A ValidationError was raised."
 
 #### Pydantic custom types
 
 LanceDB does NOT yet support converting pydantic custom types. If this is something you need,
 please file a feature request on the [LanceDB Github repo](https://github.com/lancedb/lancedb/issues/new).
 
 ### Using Iterators / Writing Large Datasets
 
 It is recommended to use iterators to add large datasets in batches when creating your table in one go. Unlike manually adding batches using `table.add()`, this does not create multiple versions of your dataset.
 
 LanceDB additionally supports PyArrow's `RecordBatch` iterators or other generators producing supported data types.
 
 Here's an example using a `RecordBatch` iterator for creating tables.
 
 ```python
 import pyarrow as pa
 
 def make_batches():
     for i in range(5):
         yield pa.RecordBatch.from_arrays(
             [
                 pa.array([[3.1, 4.1, 5.1, 6.1], [5.9, 26.5, 4.7, 32.8]],
                          pa.list_(pa.float32(), 4)),
                 pa.array(["foo", "bar"]),
                 pa.array([10.0, 20.0]),
             ],
             ["vector", "item", "price"],
         )
 
 schema = pa.schema([
     pa.field("vector", pa.list_(pa.float32(), 4)),
     pa.field("item", pa.utf8()),
     pa.field("price", pa.float32()),
 ])
 
 db.create_table("batched_table", make_batches(), schema=schema)
 ```
 
 You can also use iterators of other types like Pandas DataFrames or Pylists directly in the above example.
 
 ## Open existing tables
 
@@ -28,7 +28,7 @@ LanceDB **Cloud** is a SaaS (software-as-a-service) solution that runs serverles
 
 * Fast production-scale vector similarity, full-text & hybrid search and a SQL query interface (via [DataFusion](https://github.com/apache/arrow-datafusion))
 
-* Native Python and Javascript/Typescript support
+* Python, Javascript/Typescript, and Rust support
 
 * Store, query & manage multi-modal data (text, images, videos, point clouds, etc.), not just the embeddings and metadata
 
@@ -54,3 +54,4 @@ The following pages go deeper into the internal of LanceDB and how to use it.
 * [Ecosystem Integrations](integrations/index.md): Integrate LanceDB with other tools in the data ecosystem
 * [Python API Reference](python/python.md): Python OSS and Cloud API references
 * [JavaScript API Reference](javascript/modules.md): JavaScript OSS and Cloud API references
+* [Rust API Reference](https://docs.rs/lancedb/latest/lancedb/index.html): Rust API reference
@@ -22,7 +22,7 @@ Currently, LanceDB supports the following metrics:
 ## Exhaustive search (kNN)
 
 If you do not create a vector index, LanceDB exhaustively scans the _entire_ vector space
-and compute the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search.
+and computes the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search.
 
 <!-- Setup Code
 ```python
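Exhaustive kNN needs no index setup; it is just a scan with a chosen metric. A minimal Python sketch, assuming the synchronous SDK's `metric()` setter and the 1536-dimensional table from earlier:

```python
import lancedb

db = lancedb.connect("data/sample-lancedb")
tbl = db.open_table("my_vectors")

# With no vector index, every row is scanned and the exact distances are
# computed: a kNN search with 100% recall.
df = (tbl.search([0.1] * 1536)
      .metric("cosine")  # defaults to "l2"
      .limit(10)
      .to_pandas())
```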
@@ -85,7 +85,7 @@ To perform scalable vector retrieval with acceptable latencies, it's common to b
 While the exhaustive search is guaranteed to always return 100% recall, the approximate nature of
 an ANN search means that using an index often involves a trade-off between recall and latency.
 
-See the [IVF_PQ index](./concepts/index_ivfpq.md.md) for a deeper description of how `IVF_PQ`
+See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
 indexes work in LanceDB.
 
 ## Output search results
@@ -184,4 +184,3 @@ Let's create a LanceDB table with a nested schema:
 
 Note that in this case the extra `_distance` field is discarded since
 it's not part of the LanceSchema.
 
node/package-lock.json (generated, +60 lines)

@@ -333,6 +333,66 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
+    "node_modules/@lancedb/vectordb-darwin-arm64": {
+      "version": "0.4.13",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.13.tgz",
+      "integrity": "sha512-JfroNCG8yKIU931Y+x8d0Fp8C9DHUSC5j+CjI+e5err7rTWtie4j3JbsXlWAnPFaFEOg0Xk3BWkSikCvhPGJGg==",
+      "cpu": [
+        "arm64"
+      ],
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-darwin-x64": {
+      "version": "0.4.13",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.13.tgz",
+      "integrity": "sha512-dG6IMvfpHpnHdbJ0UffzJ7cZfMiC02MjIi6YJzgx+hKz2UNXWNBIfTvvhqli85mZsGRXL1OYDdYv0K1YzNjXlA==",
+      "cpu": [
+        "x64"
+      ],
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+      "version": "0.4.13",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.13.tgz",
+      "integrity": "sha512-BRR1VzaMviXby7qmLm0axNZM8eUZF3ZqfvnDKdVRpC3LaRueD6pMXHuC2IUKaFkn7xktf+8BlDZb6foFNEj8bQ==",
+      "cpu": [
+        "arm64"
+      ],
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+      "version": "0.4.13",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.13.tgz",
+      "integrity": "sha512-WnekZ7ZMlria+NODZ6aBCljCFQSe2bBNUS9ZpyFl/Y1vHduSQPuBxM6V7vp2QubC0daq/rifgjDob89DF+x3xw==",
+      "cpu": [
+        "x64"
+      ],
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+      "version": "0.4.13",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.13.tgz",
+      "integrity": "sha512-3NDpMWBL2ksDHXAraXhowiLqQcNWM5bdbeHwze4+InYMD54hyQ2ODNc+4usxp63Nya9biVnFS27yXULqkzIEqQ==",
+      "cpu": [
+        "x64"
+      ],
+      "optional": true,
+      "os": [
+        "win32"
+      ]
+    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
@@ -79,7 +79,7 @@ import {
 import type { IntBitWidth, TimeBitWidth } from "apache-arrow/type";
 
 function sanitizeMetadata(
-  metadataLike?: unknown
+  metadataLike?: unknown,
 ): Map<string, string> | undefined {
   if (metadataLike === undefined || metadataLike === null) {
     return undefined;
@@ -90,7 +90,7 @@ function sanitizeMetadata(
   for (const item of metadataLike) {
     if (!(typeof item[0] === "string" || !(typeof item[1] === "string"))) {
       throw Error(
-        "Expected metadata, if present, to be a Map<string, string> but it had non-string keys or values"
+        "Expected metadata, if present, to be a Map<string, string> but it had non-string keys or values",
       );
     }
   }
@@ -105,7 +105,7 @@ function sanitizeInt(typeLike: object) {
     typeof typeLike.isSigned !== "boolean"
   ) {
     throw Error(
-      "Expected an Int Type to have a `bitWidth` and `isSigned` property"
+      "Expected an Int Type to have a `bitWidth` and `isSigned` property",
     );
   }
   return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
@@ -128,7 +128,7 @@ function sanitizeDecimal(typeLike: object) {
     typeof typeLike.bitWidth !== "number"
   ) {
     throw Error(
-      "Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties"
+      "Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties",
     );
   }
   return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
@@ -149,7 +149,7 @@ function sanitizeTime(typeLike: object) {
     typeof typeLike.bitWidth !== "number"
   ) {
     throw Error(
-      "Expected a Time type to have `unit` and `bitWidth` properties"
+      "Expected a Time type to have `unit` and `bitWidth` properties",
     );
   }
   return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
@@ -172,7 +172,7 @@ function sanitizeTypedTimestamp(
     | typeof TimestampNanosecond
     | typeof TimestampMicrosecond
     | typeof TimestampMillisecond
-    | typeof TimestampSecond
+    | typeof TimestampSecond,
 ) {
   let timezone = null;
   if ("timezone" in typeLike && typeof typeLike.timezone === "string") {
@@ -191,7 +191,7 @@ function sanitizeInterval(typeLike: object) {
 function sanitizeList(typeLike: object) {
   if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
     throw Error(
-      "Expected a List type to have an array-like `children` property"
+      "Expected a List type to have an array-like `children` property",
     );
   }
   if (typeLike.children.length !== 1) {
@@ -203,7 +203,7 @@ function sanitizeList(typeLike: object) {
 function sanitizeStruct(typeLike: object) {
   if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
     throw Error(
-      "Expected a Struct type to have an array-like `children` property"
+      "Expected a Struct type to have an array-like `children` property",
     );
   }
   return new Struct(typeLike.children.map((child) => sanitizeField(child)));
@@ -216,47 +216,47 @@ function sanitizeUnion(typeLike: object) {
     typeof typeLike.mode !== "number"
   ) {
     throw Error(
-      "Expected a Union type to have `typeIds` and `mode` properties"
+      "Expected a Union type to have `typeIds` and `mode` properties",
     );
   }
   if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
     throw Error(
-      "Expected a Union type to have an array-like `children` property"
+      "Expected a Union type to have an array-like `children` property",
     );
   }
 
   return new Union(
     typeLike.mode,
     typeLike.typeIds as any,
-    typeLike.children.map((child) => sanitizeField(child))
+    typeLike.children.map((child) => sanitizeField(child)),
   );
 }
 
 function sanitizeTypedUnion(
   typeLike: object,
-  UnionType: typeof DenseUnion | typeof SparseUnion
+  UnionType: typeof DenseUnion | typeof SparseUnion,
 ) {
   if (!("typeIds" in typeLike)) {
     throw Error(
-      "Expected a DenseUnion/SparseUnion type to have a `typeIds` property"
+      "Expected a DenseUnion/SparseUnion type to have a `typeIds` property",
     );
   }
   if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
     throw Error(
-      "Expected a DenseUnion/SparseUnion type to have an array-like `children` property"
+      "Expected a DenseUnion/SparseUnion type to have an array-like `children` property",
     );
   }
 
   return new UnionType(
     typeLike.typeIds as any,
-    typeLike.children.map((child) => sanitizeField(child))
+    typeLike.children.map((child) => sanitizeField(child)),
   );
 }
 
 function sanitizeFixedSizeBinary(typeLike: object) {
   if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
     throw Error(
-      "Expected a FixedSizeBinary type to have a `byteWidth` property"
+      "Expected a FixedSizeBinary type to have a `byteWidth` property",
     );
   }
   return new FixedSizeBinary(typeLike.byteWidth);
@@ -268,7 +268,7 @@ function sanitizeFixedSizeList(typeLike: object) {
   }
   if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
     throw Error(
-      "Expected a FixedSizeList type to have an array-like `children` property"
+      "Expected a FixedSizeList type to have an array-like `children` property",
     );
   }
   if (typeLike.children.length !== 1) {
@@ -276,14 +276,14 @@ function sanitizeFixedSizeList(typeLike: object) {
   }
   return new FixedSizeList(
     typeLike.listSize,
-    sanitizeField(typeLike.children[0])
+    sanitizeField(typeLike.children[0]),
   );
 }
 
 function sanitizeMap(typeLike: object) {
   if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
     throw Error(
-      "Expected a Map type to have an array-like `children` property"
+      "Expected a Map type to have an array-like `children` property",
     );
   }
   if (!("keysSorted" in typeLike) || typeof typeLike.keysSorted !== "boolean") {
@@ -291,7 +291,7 @@ function sanitizeMap(typeLike: object) {
   }
   return new Map_(
     typeLike.children.map((field) => sanitizeField(field)) as any,
-    typeLike.keysSorted
+    typeLike.keysSorted,
   );
 }
 
@@ -319,7 +319,7 @@ function sanitizeDictionary(typeLike: object) {
     sanitizeType(typeLike.dictionary),
     sanitizeType(typeLike.indices) as any,
     typeLike.id,
-    typeLike.isOrdered
+    typeLike.isOrdered,
   );
 }
 
@@ -454,7 +454,7 @@ function sanitizeField(fieldLike: unknown): Field {
     !("nullable" in fieldLike)
   ) {
     throw Error(
-      "The field passed in is missing a `type`/`name`/`nullable` property"
+      "The field passed in is missing a `type`/`name`/`nullable` property",
     );
   }
   const type = sanitizeType(fieldLike.type);
@@ -473,6 +473,13 @@ function sanitizeField(fieldLike: unknown): Field {
|
|||||||
return new Field(name, type, nullable, metadata);
|
return new Field(name, type, nullable, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert something schemaLike into a Schema instance
|
||||||
|
*
|
||||||
|
* This method is often needed even when the caller is using a Schema
|
||||||
|
* instance because they might be using a different instance of apache-arrow
|
||||||
|
* than lancedb is using.
|
||||||
|
*/
|
||||||
export function sanitizeSchema(schemaLike: unknown): Schema {
|
export function sanitizeSchema(schemaLike: unknown): Schema {
|
||||||
if (schemaLike instanceof Schema) {
|
if (schemaLike instanceof Schema) {
|
||||||
return schemaLike;
|
return schemaLike;
|
||||||
@@ -482,7 +489,7 @@ export function sanitizeSchema(schemaLike: unknown): Schema {
|
|||||||
}
|
}
|
||||||
if (!("fields" in schemaLike)) {
|
if (!("fields" in schemaLike)) {
|
||||||
throw Error(
|
throw Error(
|
||||||
"The schema passed in does not appear to be a schema (no 'fields' property)"
|
"The schema passed in does not appear to be a schema (no 'fields' property)",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let metadata;
|
let metadata;
|
||||||
@@ -491,11 +498,11 @@ export function sanitizeSchema(schemaLike: unknown): Schema {
|
|||||||
}
|
}
|
||||||
if (!Array.isArray(schemaLike.fields)) {
|
if (!Array.isArray(schemaLike.fields)) {
|
||||||
throw Error(
|
throw Error(
|
||||||
"The schema passed in had a 'fields' property but it was not an array"
|
"The schema passed in had a 'fields' property but it was not an array",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
const sanitizedFields = schemaLike.fields.map((field) =>
|
const sanitizedFields = schemaLike.fields.map((field) =>
|
||||||
sanitizeField(field)
|
sanitizeField(field),
|
||||||
);
|
);
|
||||||
return new Schema(sanitizedFields, metadata);
|
return new Schema(sanitizedFields, metadata);
|
||||||
}
|
}
|
||||||
|
|||||||
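The new `sanitizeSchema` doc comment above describes a real pitfall. A minimal sketch of the situation, assuming an app that bundles its own copy of apache-arrow next to lancedb's (the package import path is assumed):

    import { Schema, Field, Float32 } from "apache-arrow"; // the app's own arrow copy
    import { connect } from "@lancedb/lancedb"; // package name assumed

    // This Schema comes from the app's apache-arrow instance; lancedb may have
    // linked a different copy, so cross-instance `instanceof Schema` checks fail.
    // sanitizeSchema rebuilds the schema field-by-field to avoid that.
    const schema = new Schema([new Field("id", new Float32(), true)]);
    const db = await connect("/tmp/example-db");
    await db.createEmptyTable("my_table", schema); // calls sanitizeSchema internally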
@@ -129,11 +129,25 @@ describe("When creating an index", () => {
   });

   // Search without specifying the column
-  const rst = await tbl.query().nearestTo(queryVec).limit(2).toArrow();
+  let rst = await tbl
+    .query()
+    .limit(2)
+    .nearestTo(queryVec)
+    .distanceType("DoT")
+    .toArrow();
+  expect(rst.numRows).toBe(2);
+
+  // Search using `vectorSearch`
+  rst = await tbl.vectorSearch(queryVec).limit(2).toArrow();
   expect(rst.numRows).toBe(2);

   // Search with specifying the column
-  const rst2 = await tbl.search(queryVec, "vec").limit(2).toArrow();
+  const rst2 = await tbl
+    .query()
+    .limit(2)
+    .nearestTo(queryVec)
+    .column("vec")
+    .toArrow();
   expect(rst2.numRows).toBe(2);
   expect(rst.toString()).toEqual(rst2.toString());
 });
@@ -163,7 +177,7 @@ describe("When creating an index", () => {
   const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
   expect(fs.readdirSync(indexDir)).toHaveLength(1);

-  for await (const r of tbl.query().filter("id > 1").select(["id"])) {
+  for await (const r of tbl.query().where("id > 1").select(["id"])) {
     expect(r.numRows).toBe(298);
   }
 });
@@ -205,33 +219,39 @@ describe("When creating an index", () => {

   const rst = await tbl
     .query()
+    .limit(2)
     .nearestTo(
       Array(32)
         .fill(1)
         .map(() => Math.random()),
     )
-    .limit(2)
     .toArrow();
   expect(rst.numRows).toBe(2);

   // Search with specifying the column
   await expect(
     tbl
-      .search(
+      .query()
+      .limit(2)
+      .nearestTo(
        Array(64)
           .fill(1)
           .map(() => Math.random()),
-        "vec",
       )
-      .limit(2)
+      .column("vec")
       .toArrow(),
-  ).rejects.toThrow(/.*does not match the dimension.*/);
+  ).rejects.toThrow(/.* query dim=64, expected vector dim=32.*/);

   const query64 = Array(64)
     .fill(1)
     .map(() => Math.random());
-  const rst64Query = await tbl.query().nearestTo(query64).limit(2).toArrow();
-  const rst64Search = await tbl.search(query64, "vec2").limit(2).toArrow();
+  const rst64Query = await tbl.query().limit(2).nearestTo(query64).toArrow();
+  const rst64Search = await tbl
+    .query()
+    .limit(2)
+    .nearestTo(query64)
+    .column("vec2")
+    .toArrow();
   expect(rst64Query.toString()).toEqual(rst64Search.toString());
   expect(rst64Query.numRows).toBe(2);
 });
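These test updates exercise the reworked builder: `nearestTo` now returns a vector-query builder, so vector-only knobs chain after it. A hedged sketch of the resulting call shape (table and column names illustrative):

    const results = await tbl
      .query()
      .where("id > 1") // replaces the old filter()
      .limit(10)
      .nearestTo(queryVec) // switches the builder into a vector search
      .column("vec")
      .distanceType("cosine") // should match the index's distance type
      .toArrow();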
@@ -4,14 +4,25 @@
 const eslint = require("@eslint/js");
 const tseslint = require("typescript-eslint");
 const eslintConfigPrettier = require("eslint-config-prettier");
+const jsdoc = require("eslint-plugin-jsdoc");

 module.exports = tseslint.config(
   eslint.configs.recommended,
+  jsdoc.configs["flat/recommended"],
   eslintConfigPrettier,
   ...tseslint.configs.recommended,
   {
     rules: {
       "@typescript-eslint/naming-convention": "error",
+      "jsdoc/require-returns": "off",
+      "jsdoc/require-param": "off",
+      "jsdoc/require-jsdoc": [
+        "error",
+        {
+          publicOnly: true,
+        },
+      ],
     },
+    plugins: jsdoc,
   },
 );
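With `publicOnly: true`, `jsdoc/require-jsdoc` only fires on exported declarations. A tiny illustration (the functions themselves are hypothetical):

    // Internal helpers need no JSDoc under this configuration.
    function clamp(x: number): number {
      return Math.max(0, Math.min(1, x));
    }

    /** Exported symbols must carry a JSDoc block. */
    export function normalize(x: number): number {
      return clamp(x);
    }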
@@ -31,6 +31,7 @@ import {
   DataType,
   Binary,
   Float32,
+  type makeTable,
 } from "apache-arrow";
 import { type EmbeddingFunction } from "./embedding/embedding_function";
 import { sanitizeSchema } from "./sanitize";
@@ -128,14 +129,7 @@ export class MakeArrowTableOptions {
  * - Buffer => Binary
  * - Record<String, any> => Struct
  * - Array<any> => List
- *
- * @param data input data
- * @param options options to control the makeArrowTable call.
- *
  * @example
- *
- * ```ts
- *
  * import { fromTableToBuffer, makeArrowTable } from "../arrow";
  * import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow";
  *
@@ -307,7 +301,9 @@ export function makeEmptyTable(schema: Schema): ArrowTable {
   return makeArrowTable([], { schema });
 }

-// Helper function to convert Array<Array<any>> to a variable sized list array
+/**
+ * Helper function to convert Array<Array<any>> to a variable sized list array
+ */
 // @ts-expect-error (Vector<unknown> is not assignable to Vector<any>)
 function makeListVector(lists: unknown[][]): Vector<unknown> {
   if (lists.length === 0 || lists[0].length === 0) {
@@ -333,7 +329,7 @@ function makeListVector(lists: unknown[][]): Vector<unknown> {
   return listBuilder.finish().toVector();
 }

-// Helper function to convert an Array of JS values to an Arrow Vector
+/** Helper function to convert an Array of JS values to an Arrow Vector */
 function makeVector(
   values: unknown[],
   type?: DataType,
@@ -374,6 +370,7 @@ function makeVector(
   }
 }

+/** Helper function to apply embeddings to an input table */
 async function applyEmbeddings<T>(
   table: ArrowTable,
   embeddings?: EmbeddingFunction<T>,
@@ -466,7 +463,7 @@ async function applyEmbeddings<T>(
   return newTable;
 }

-/*
+/**
  * Convert an Array of records into an Arrow Table, optionally applying an
  * embeddings function to it.
  *
@@ -493,7 +490,7 @@ export async function convertToTable<T>(
   return await applyEmbeddings(table, embeddings, makeTableOptions?.schema);
 }

-// Creates the Arrow Type for a Vector column with dimension `dim`
+/** Creates the Arrow Type for a Vector column with dimension `dim` */
 function newVectorType<T extends Float>(
   dim: number,
   innerType: T,
@@ -565,6 +562,14 @@ export async function fromTableToBuffer<T>(
   return Buffer.from(await writer.toUint8Array());
 }

+/**
+ * Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
+ *
+ * This function will apply `embeddings` to the table in a manner similar to
+ * `convertToTable`.
+ *
+ * `schema` is required if the table is empty
+ */
 export async function fromDataToBuffer<T>(
   data: Data,
   embeddings?: EmbeddingFunction<T>,
@@ -599,6 +604,9 @@ export async function fromTableToStreamBuffer<T>(
   return Buffer.from(await writer.toUint8Array());
 }

+/**
+ * Reorder the columns in `batch` so that they agree with the field order in `schema`
+ */
 function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
   const alignedChildren = [];
   for (const field of schema.fields) {
@@ -621,6 +629,9 @@ function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
   return new RecordBatch(schema, newData);
 }

+/**
+ * Reorder the columns in `table` so that they agree with the field order in `schema`
+ */
 function alignTable(table: ArrowTable, schema: Schema): ArrowTable {
   const alignedBatches = table.batches.map((batch) =>
     alignBatch(batch, schema),
@@ -628,7 +639,9 @@ function alignTable(table: ArrowTable, schema: Schema): ArrowTable {
   return new ArrowTable(schema, alignedBatches);
 }

-// Creates an empty Arrow Table
+/**
+ * Create an empty table with the given schema
+ */
 export function createEmptyTable(schema: Schema): ArrowTable {
   return new ArrowTable(sanitizeSchema(schema));
 }
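The trimmed `makeArrowTable` @example above still shows the core flow; a compact sketch of it (field names illustrative):

    import { makeArrowTable } from "./arrow";
    import { Field, FixedSizeList, Float32, Schema } from "apache-arrow";

    // Pin the vector column to FixedSizeList<Float32> so the data round-trips
    // through the IPC buffer with exactly the type LanceDB expects.
    const schema = new Schema([
      new Field(
        "vector",
        new FixedSizeList(2, new Field("item", new Float32(), true)),
        false,
      ),
    ]);
    const table = makeArrowTable(
      [{ vector: [0.1, 0.2] }, { vector: [1.1, 1.2] }],
      { schema },
    );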
@@ -78,7 +78,8 @@ export class Connection {
     return this.inner.isOpen();
   }

-  /** Close the connection, releasing any underlying resources.
+  /**
+   * Close the connection, releasing any underlying resources.
    *
    * It is safe to call this method multiple times.
    *
@@ -93,11 +94,12 @@ export class Connection {
     return this.inner.display();
   }

-  /** List all the table names in this database.
+  /**
+   * List all the table names in this database.
    *
    * Tables will be returned in lexicographical order.
-   *
-   * @param options Optional parameters to control the listing.
+   * @param {Partial<TableNamesOptions>} options - options to control the
+   * paging / start point
    */
   async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
     return this.inner.tableNames(options?.startAfter, options?.limit);
@@ -105,9 +107,7 @@ export class Connection {

   /**
    * Open a table in the database.
-   *
-   * @param name The name of the table.
-   * @param embeddings An embedding function to use on this table
+   * @param {string} name - The name of the table
    */
   async openTable(name: string): Promise<Table> {
     const innerTable = await this.inner.openTable(name);
@@ -116,9 +116,9 @@ export class Connection {

   /**
    * Creates a new Table and initialize it with new data.
-   *
    * @param {string} name - The name of the table.
-   * @param data - Non-empty Array of Records to be inserted into the table
+   * @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
+   * to be inserted into the table
    */
   async createTable(
     name: string,
@@ -145,9 +145,8 @@ export class Connection {

   /**
    * Creates a new empty Table
-   *
    * @param {string} name - The name of the table.
-   * @param schema - The schema of the table
+   * @param {Schema} schema - The schema of the table
    */
   async createEmptyTable(
     name: string,
@@ -169,7 +168,7 @@ export class Connection {

   /**
    * Drop an existing table.
-   * @param name The name of the table to drop.
+   * @param {string} name The name of the table to drop.
    */
   async dropTable(name: string): Promise<void> {
     return this.inner.dropTable(name);
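The paging contract spelled out for `tableNames`, in one sketch (paths illustrative):

    const db = await connect("/tmp/sample-db");
    // First page: up to 10 names, lexicographical order.
    const firstPage = await db.tableNames({ limit: 10 });
    // Next page: resume after the last name from the previous page.
    const nextPage = await db.tableNames({
      startAfter: firstPage[firstPage.length - 1],
      limit: 10,
    });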
@@ -62,6 +62,7 @@ export interface EmbeddingFunction<T> {
   embed: (data: T[]) => Promise<number[][]>;
 }

+/** Test if the input seems to be an embedding function */
 export function isEmbeddingFunction<T>(
   value: unknown,
 ): value is EmbeddingFunction<T> {
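The `embed` signature above is essentially the whole contract that `isEmbeddingFunction` probes for. A toy implementation for illustration only; the length-based "model" is obviously not a real embedding, and the `sourceColumn` field is assumed from the wider interface rather than shown in this hunk:

    const toyEmbeddings = {
      sourceColumn: "text", // assumed field, not visible in this diff
      embed: async (data: string[]): Promise<number[][]> =>
        // Map each string to a fixed 2-d vector derived from its length.
        data.map((s) => [s.length, s.length % 7]),
    };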
@@ -30,9 +30,8 @@ export { Table, AddDataOptions } from "./table";
 * - `/path/to/database` - local database
 * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
 * - `db://host:port` - remote database (LanceDB cloud)
- *
- * @param uri The uri of the database. If the database uri starts with `db://` then it connects to a remote database.
- *
+ * @param {string} uri - The uri of the database. If the database uri starts
+ * with `db://` then it connects to a remote database.
 * @see {@link ConnectionOptions} for more details on the URI format.
 */
 export async function connect(
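The URI forms documented above, in one sketch (bucket and paths are placeholders; package import path assumed):

    import { connect } from "@lancedb/lancedb";

    const local = await connect("/path/to/database");
    const onS3 = await connect("s3://bucket/path/to/database");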
@@ -18,7 +18,8 @@ import { Index as LanceDbIndex } from "./native";
 * Options to create an `IVF_PQ` index
 */
 export interface IvfPqOptions {
-  /** The number of IVF partitions to create.
+  /**
+   * The number of IVF partitions to create.
    *
    * This value should generally scale with the number of rows in the dataset.
    * By default the number of partitions is the square root of the number of
@@ -30,7 +31,8 @@ export interface IvfPqOptions {
    */
   numPartitions?: number;

-  /** Number of sub-vectors of PQ.
+  /**
+   * Number of sub-vectors of PQ.
    *
    * This value controls how much the vector is compressed during the quantization step.
    * The more sub vectors there are the less the vector is compressed. The default is
@@ -45,9 +47,10 @@ export interface IvfPqOptions {
    */
   numSubVectors?: number;

-  /** [DistanceType] to use to build the index.
+  /**
+   * Distance type to use to build the index.
    *
-   * Default value is [DistanceType::L2].
+   * Default value is "l2".
    *
    * This is used when training the index to calculate the IVF partitions
    * (vectors are grouped in partitions with similar vectors according to this
@@ -79,7 +82,8 @@ export interface IvfPqOptions {
    */
   distanceType?: "l2" | "cosine" | "dot";

-  /** Max iteration to train IVF kmeans.
+  /**
+   * Max iteration to train IVF kmeans.
    *
    * When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
    * controls how many iterations of kmeans to run.
@@ -91,7 +95,8 @@ export interface IvfPqOptions {
    */
   maxIterations?: number;

-  /** The number of vectors, per partition, to sample when training IVF kmeans.
+  /**
+   * The number of vectors, per partition, to sample when training IVF kmeans.
    *
    * When an IVF PQ index is trained, we need to calculate partitions. These are groups
    * of vectors that are similar to each other. To do this we use an algorithm called kmeans.
@@ -148,7 +153,8 @@ export class Index {
     );
   }

-  /** Create a btree index
+  /**
+   * Create a btree index
    *
    * A btree index is an index on a scalar columns. The index stores a copy of the column
    * in sorted order. A header entry is created for each block of rows (currently the
@@ -172,7 +178,8 @@ export class Index {
 }

 export interface IndexOptions {
-  /** Advanced index configuration
+  /**
+   * Advanced index configuration
    *
    * This option allows you to specify a specfic index to create and also
    * allows you to pass in configuration for training the index.
@@ -183,7 +190,8 @@ export interface IndexOptions {
    * will be used to determine the most useful kind of index to create.
    */
   config?: Index;
-  /** Whether to replace the existing index
+  /**
+   * Whether to replace the existing index
    *
    * If this is false, and another index already exists on the same columns
    * and the same name, then an error will be returned. This is true even if
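A hedged sketch of how `IvfPqOptions` feeds index creation; the `Index.ivfPq` factory and the wrapper's `createIndex(column, options)` shape are inferred from the surrounding API rather than shown verbatim in this hunk:

    await tbl.createIndex("vec", {
      config: Index.ivfPq({
        numPartitions: 256, // the docs suggest ~sqrt(row count) as the default
        numSubVectors: 16,
        distanceType: "cosine", // queries must later use the same distance type
      }),
      replace: true,
    });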
nodejs/lancedb/native.d.ts (vendored, 21 changes)
@@ -105,15 +105,23 @@ export class RecordBatchIterator {
   next(): Promise<Buffer | null>
 }
 export class Query {
-  column(column: string): void
-  filter(filter: string): void
-  select(columns: Array<string>): void
+  onlyIf(predicate: string): void
+  select(columns: Array<[string, string]>): void
   limit(limit: number): void
-  prefilter(prefilter: boolean): void
-  nearestTo(vector: Float32Array): void
+  nearestTo(vector: Float32Array): VectorQuery
+  execute(): Promise<RecordBatchIterator>
+}
+export class VectorQuery {
+  column(column: string): void
+  distanceType(distanceType: string): void
+  postfilter(): void
   refineFactor(refineFactor: number): void
   nprobes(nprobe: number): void
-  executeStream(): Promise<RecordBatchIterator>
+  bypassVectorIndex(): void
+  onlyIf(predicate: string): void
+  select(columns: Array<[string, string]>): void
+  limit(limit: number): void
+  execute(): Promise<RecordBatchIterator>
 }
 export class Table {
   display(): string
@@ -127,6 +135,7 @@ export class Table {
   createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
   update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
   query(): Query
+  vectorSearch(vector: Float32Array): VectorQuery
   addColumns(transforms: Array<AddColumnsSql>): Promise<void>
   alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
   dropColumns(columns: Array<string>): Promise<void>
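The vendored declarations make the layering visible: the native `Query.nearestTo` hands back a `VectorQuery`, mirroring the TypeScript wrapper. A sketch at the native level (normally you would stay in the wrapper), where `tbl` is a native `Table`:

    const q = tbl.query(); // native Query; setters return void at this layer
    q.onlyIf("id > 1");
    const vq = q.nearestTo(Float32Array.from([0.1, 0.2])); // promotes to VectorQuery
    vq.nprobes(20);
    const iter = await vq.execute(); // RecordBatchIterator over IPC buffers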
@@ -5,302 +5,325 @@
 /* auto-generated by NAPI-RS */

 const { existsSync, readFileSync } = require('fs')
-const { join } = require('path')
+const { join } = require("path");

-const { platform, arch } = process
+const { platform, arch } = process;

-let nativeBinding = null
-let localFileExisted = false
-let loadError = null
+let nativeBinding = null;
+let localFileExisted = false;
+let loadError = null;

 function isMusl() {
   // For Node 10
-  if (!process.report || typeof process.report.getReport !== 'function') {
+  if (!process.report || typeof process.report.getReport !== "function") {
     try {
-      const lddPath = require('child_process').execSync('which ldd').toString().trim()
-      return readFileSync(lddPath, 'utf8').includes('musl')
+      const lddPath = require("child_process")
+        .execSync("which ldd")
+        .toString()
+        .trim();
+      return readFileSync(lddPath, "utf8").includes("musl");
     } catch (e) {
-      return true
+      return true;
     }
   } else {
-    const { glibcVersionRuntime } = process.report.getReport().header
-    return !glibcVersionRuntime
+    const { glibcVersionRuntime } = process.report.getReport().header;
+    return !glibcVersionRuntime;
   }
 }

 switch (platform) {
-  case 'android':
+  case "android":
     switch (arch) {
-      case 'arm64':
-        localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm64.node'))
+      case "arm64":
+        localFileExisted = existsSync(
+          join(__dirname, "lancedb-nodejs.android-arm64.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.android-arm64.node')
+            nativeBinding = require("./lancedb-nodejs.android-arm64.node");
           } else {
-            nativeBinding = require('lancedb-android-arm64')
+            nativeBinding = require("lancedb-android-arm64");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
-      case 'arm':
-        localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm-eabi.node'))
+        break;
+      case "arm":
+        localFileExisted = existsSync(
+          join(__dirname, "lancedb-nodejs.android-arm-eabi.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.android-arm-eabi.node')
+            nativeBinding = require("./lancedb-nodejs.android-arm-eabi.node");
           } else {
-            nativeBinding = require('lancedb-android-arm-eabi')
+            nativeBinding = require("lancedb-android-arm-eabi");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
+        break;
       default:
-        throw new Error(`Unsupported architecture on Android ${arch}`)
+        throw new Error(`Unsupported architecture on Android ${arch}`);
     }
-    break
-  case 'win32':
+    break;
+  case "win32":
     switch (arch) {
-      case 'x64':
+      case "x64":
         localFileExisted = existsSync(
-          join(__dirname, 'lancedb-nodejs.win32-x64-msvc.node')
-        )
+          join(__dirname, "lancedb-nodejs.win32-x64-msvc.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.win32-x64-msvc.node')
+            nativeBinding = require("./lancedb-nodejs.win32-x64-msvc.node");
           } else {
-            nativeBinding = require('lancedb-win32-x64-msvc')
+            nativeBinding = require("lancedb-win32-x64-msvc");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
-      case 'ia32':
+        break;
+      case "ia32":
         localFileExisted = existsSync(
-          join(__dirname, 'lancedb-nodejs.win32-ia32-msvc.node')
-        )
+          join(__dirname, "lancedb-nodejs.win32-ia32-msvc.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.win32-ia32-msvc.node')
+            nativeBinding = require("./lancedb-nodejs.win32-ia32-msvc.node");
           } else {
-            nativeBinding = require('lancedb-win32-ia32-msvc')
+            nativeBinding = require("lancedb-win32-ia32-msvc");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
-      case 'arm64':
+        break;
+      case "arm64":
         localFileExisted = existsSync(
-          join(__dirname, 'lancedb-nodejs.win32-arm64-msvc.node')
-        )
+          join(__dirname, "lancedb-nodejs.win32-arm64-msvc.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.win32-arm64-msvc.node')
+            nativeBinding = require("./lancedb-nodejs.win32-arm64-msvc.node");
           } else {
-            nativeBinding = require('lancedb-win32-arm64-msvc')
+            nativeBinding = require("lancedb-win32-arm64-msvc");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
+        break;
       default:
-        throw new Error(`Unsupported architecture on Windows: ${arch}`)
+        throw new Error(`Unsupported architecture on Windows: ${arch}`);
     }
-    break
-  case 'darwin':
-    localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-universal.node'))
+    break;
+  case "darwin":
+    localFileExisted = existsSync(
+      join(__dirname, "lancedb-nodejs.darwin-universal.node"),
+    );
     try {
       if (localFileExisted) {
-        nativeBinding = require('./lancedb-nodejs.darwin-universal.node')
+        nativeBinding = require("./lancedb-nodejs.darwin-universal.node");
       } else {
-        nativeBinding = require('lancedb-darwin-universal')
+        nativeBinding = require("lancedb-darwin-universal");
      }
-      break
+      break;
     } catch {}
     switch (arch) {
-      case 'x64':
-        localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-x64.node'))
-        try {
-          if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.darwin-x64.node')
-          } else {
-            nativeBinding = require('lancedb-darwin-x64')
-          }
-        } catch (e) {
-          loadError = e
-        }
-        break
-      case 'arm64':
+      case "x64":
         localFileExisted = existsSync(
-          join(__dirname, 'lancedb-nodejs.darwin-arm64.node')
-        )
+          join(__dirname, "lancedb-nodejs.darwin-x64.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.darwin-arm64.node')
+            nativeBinding = require("./lancedb-nodejs.darwin-x64.node");
           } else {
-            nativeBinding = require('lancedb-darwin-arm64')
+            nativeBinding = require("lancedb-darwin-x64");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
+        break;
+      case "arm64":
+        localFileExisted = existsSync(
+          join(__dirname, "lancedb-nodejs.darwin-arm64.node"),
+        );
+        try {
+          if (localFileExisted) {
+            nativeBinding = require("./lancedb-nodejs.darwin-arm64.node");
+          } else {
+            nativeBinding = require("lancedb-darwin-arm64");
+          }
+        } catch (e) {
+          loadError = e;
+        }
+        break;
       default:
-        throw new Error(`Unsupported architecture on macOS: ${arch}`)
+        throw new Error(`Unsupported architecture on macOS: ${arch}`);
     }
-    break
-  case 'freebsd':
-    if (arch !== 'x64') {
-      throw new Error(`Unsupported architecture on FreeBSD: ${arch}`)
+    break;
+  case "freebsd":
+    if (arch !== "x64") {
+      throw new Error(`Unsupported architecture on FreeBSD: ${arch}`);
     }
-    localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.freebsd-x64.node'))
+    localFileExisted = existsSync(
+      join(__dirname, "lancedb-nodejs.freebsd-x64.node"),
+    );
     try {
       if (localFileExisted) {
-        nativeBinding = require('./lancedb-nodejs.freebsd-x64.node')
+        nativeBinding = require("./lancedb-nodejs.freebsd-x64.node");
       } else {
-        nativeBinding = require('lancedb-freebsd-x64')
+        nativeBinding = require("lancedb-freebsd-x64");
       }
     } catch (e) {
-      loadError = e
+      loadError = e;
     }
-    break
-  case 'linux':
+    break;
+  case "linux":
     switch (arch) {
-      case 'x64':
+      case "x64":
         if (isMusl()) {
           localFileExisted = existsSync(
-            join(__dirname, 'lancedb-nodejs.linux-x64-musl.node')
-          )
+            join(__dirname, "lancedb-nodejs.linux-x64-musl.node"),
+          );
           try {
             if (localFileExisted) {
-              nativeBinding = require('./lancedb-nodejs.linux-x64-musl.node')
+              nativeBinding = require("./lancedb-nodejs.linux-x64-musl.node");
             } else {
-              nativeBinding = require('lancedb-linux-x64-musl')
+              nativeBinding = require("lancedb-linux-x64-musl");
             }
           } catch (e) {
-            loadError = e
+            loadError = e;
           }
         } else {
           localFileExisted = existsSync(
-            join(__dirname, 'lancedb-nodejs.linux-x64-gnu.node')
-          )
+            join(__dirname, "lancedb-nodejs.linux-x64-gnu.node"),
+          );
           try {
             if (localFileExisted) {
-              nativeBinding = require('./lancedb-nodejs.linux-x64-gnu.node')
+              nativeBinding = require("./lancedb-nodejs.linux-x64-gnu.node");
             } else {
-              nativeBinding = require('lancedb-linux-x64-gnu')
+              nativeBinding = require("lancedb-linux-x64-gnu");
             }
           } catch (e) {
-            loadError = e
+            loadError = e;
           }
         }
-        break
-      case 'arm64':
+        break;
+      case "arm64":
         if (isMusl()) {
           localFileExisted = existsSync(
-            join(__dirname, 'lancedb-nodejs.linux-arm64-musl.node')
-          )
+            join(__dirname, "lancedb-nodejs.linux-arm64-musl.node"),
+          );
           try {
             if (localFileExisted) {
-              nativeBinding = require('./lancedb-nodejs.linux-arm64-musl.node')
+              nativeBinding = require("./lancedb-nodejs.linux-arm64-musl.node");
             } else {
-              nativeBinding = require('lancedb-linux-arm64-musl')
+              nativeBinding = require("lancedb-linux-arm64-musl");
             }
           } catch (e) {
-            loadError = e
+            loadError = e;
           }
         } else {
           localFileExisted = existsSync(
-            join(__dirname, 'lancedb-nodejs.linux-arm64-gnu.node')
-          )
+            join(__dirname, "lancedb-nodejs.linux-arm64-gnu.node"),
+          );
           try {
             if (localFileExisted) {
-              nativeBinding = require('./lancedb-nodejs.linux-arm64-gnu.node')
+              nativeBinding = require("./lancedb-nodejs.linux-arm64-gnu.node");
             } else {
-              nativeBinding = require('lancedb-linux-arm64-gnu')
+              nativeBinding = require("lancedb-linux-arm64-gnu");
             }
           } catch (e) {
-            loadError = e
+            loadError = e;
           }
         }
-        break
-      case 'arm':
+        break;
+      case "arm":
         localFileExisted = existsSync(
-          join(__dirname, 'lancedb-nodejs.linux-arm-gnueabihf.node')
-        )
+          join(__dirname, "lancedb-nodejs.linux-arm-gnueabihf.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.linux-arm-gnueabihf.node')
+            nativeBinding = require("./lancedb-nodejs.linux-arm-gnueabihf.node");
           } else {
-            nativeBinding = require('lancedb-linux-arm-gnueabihf')
+            nativeBinding = require("lancedb-linux-arm-gnueabihf");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
-      case 'riscv64':
+        break;
+      case "riscv64":
         if (isMusl()) {
           localFileExisted = existsSync(
-            join(__dirname, 'lancedb-nodejs.linux-riscv64-musl.node')
-          )
+            join(__dirname, "lancedb-nodejs.linux-riscv64-musl.node"),
+          );
           try {
             if (localFileExisted) {
-              nativeBinding = require('./lancedb-nodejs.linux-riscv64-musl.node')
+              nativeBinding = require("./lancedb-nodejs.linux-riscv64-musl.node");
             } else {
-              nativeBinding = require('lancedb-linux-riscv64-musl')
+              nativeBinding = require("lancedb-linux-riscv64-musl");
            }
           } catch (e) {
-            loadError = e
+            loadError = e;
           }
         } else {
           localFileExisted = existsSync(
-            join(__dirname, 'lancedb-nodejs.linux-riscv64-gnu.node')
-          )
+            join(__dirname, "lancedb-nodejs.linux-riscv64-gnu.node"),
+          );
           try {
             if (localFileExisted) {
-              nativeBinding = require('./lancedb-nodejs.linux-riscv64-gnu.node')
+              nativeBinding = require("./lancedb-nodejs.linux-riscv64-gnu.node");
             } else {
-              nativeBinding = require('lancedb-linux-riscv64-gnu')
+              nativeBinding = require("lancedb-linux-riscv64-gnu");
             }
           } catch (e) {
-            loadError = e
+            loadError = e;
           }
         }
-        break
-      case 's390x':
+        break;
+      case "s390x":
         localFileExisted = existsSync(
-          join(__dirname, 'lancedb-nodejs.linux-s390x-gnu.node')
-        )
+          join(__dirname, "lancedb-nodejs.linux-s390x-gnu.node"),
+        );
         try {
           if (localFileExisted) {
-            nativeBinding = require('./lancedb-nodejs.linux-s390x-gnu.node')
+            nativeBinding = require("./lancedb-nodejs.linux-s390x-gnu.node");
           } else {
-            nativeBinding = require('lancedb-linux-s390x-gnu')
+            nativeBinding = require("lancedb-linux-s390x-gnu");
           }
         } catch (e) {
-          loadError = e
+          loadError = e;
         }
-        break
+        break;
       default:
-        throw new Error(`Unsupported architecture on Linux: ${arch}`)
+        throw new Error(`Unsupported architecture on Linux: ${arch}`);
     }
-    break
+    break;
   default:
-    throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`)
+    throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`);
 }

 if (!nativeBinding) {
   if (loadError) {
-    throw loadError
+    throw loadError;
   }
-  throw new Error(`Failed to load native binding`)
+  throw new Error(`Failed to load native binding`);
 }

-const { Connection, Index, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
+const {
+  Connection,
+  Index,
+  RecordBatchIterator,
+  Query,
+  VectorQuery,
+  Table,
+  WriteMode,
+  connect,
+} = nativeBinding;

-module.exports.Connection = Connection
-module.exports.Index = Index
-module.exports.RecordBatchIterator = RecordBatchIterator
-module.exports.Query = Query
-module.exports.Table = Table
-module.exports.WriteMode = WriteMode
-module.exports.connect = connect
+module.exports.Connection = Connection;
+module.exports.Index = Index;
+module.exports.RecordBatchIterator = RecordBatchIterator;
+module.exports.Query = Query;
+module.exports.VectorQuery = VectorQuery;
+module.exports.Table = Table;
+module.exports.WriteMode = WriteMode;
+module.exports.connect = connect;
@@ -17,18 +17,15 @@ import {
   RecordBatchIterator as NativeBatchIterator,
   Query as NativeQuery,
   Table as NativeTable,
+  VectorQuery as NativeVectorQuery,
 } from "./native";
+import { type IvfPqOptions } from "./indices";
 class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;

-  constructor(
-    inner?: NativeBatchIterator,
-    promise?: Promise<NativeBatchIterator>,
-  ) {
+  constructor(promise?: Promise<NativeBatchIterator>) {
     // TODO: check promise reliably so we dont need to pass two arguments.
-    this.inner = inner;
     this.promisedInner = promise;
   }
@@ -53,82 +50,113 @@ class RecordBatchIterator implements AsyncIterator<RecordBatch> {
 }
 /* eslint-enable */

-/** Query executor */
-export class Query implements AsyncIterable<RecordBatch> {
-  private readonly inner: NativeQuery;
+/** Common methods supported by all query types */
+export class QueryBase<
+  NativeQueryType extends NativeQuery | NativeVectorQuery,
+  QueryType,
+> implements AsyncIterable<RecordBatch>
+{
+  protected constructor(protected inner: NativeQueryType) {}

-  constructor(tbl: NativeTable) {
-    this.inner = tbl.query();
+  /**
+   * A filter statement to be applied to this query.
+   *
+   * The filter should be supplied as an SQL query string. For example:
+   * @example
+   * x > 10
+   * y > 0 AND y < 100
+   * x > 5 OR y = 'test'
+   *
+   * Filtering performance can often be improved by creating a scalar index
+   * on the filter column(s).
+   */
+  where(predicate: string): QueryType {
+    this.inner.onlyIf(predicate);
+    return this as unknown as QueryType;
   }

-  /** Set the column to run query. */
-  column(column: string): Query {
-    this.inner.column(column);
-    return this;
+  /**
+   * Return only the specified columns.
+   *
+   * By default a query will return all columns from the table. However, this can have
+   * a very significant impact on latency. LanceDb stores data in a columnar fashion. This
+   * means we can finely tune our I/O to select exactly the columns we need.
+   *
+   * As a best practice you should always limit queries to the columns that you need. If you
+   * pass in an array of column names then only those columns will be returned.
+   *
+   * You can also use this method to create new "dynamic" columns based on your existing columns.
+   * For example, you may not care about "a" or "b" but instead simply want "a + b". This is often
+   * seen in the SELECT clause of an SQL query (e.g. `SELECT a+b FROM my_table`).
+   *
+   * To create dynamic columns you can pass in a Map<string, string>. A column will be returned
+   * for each entry in the map. The key provides the name of the column. The value is
+   * an SQL string used to specify how the column is calculated.
+   *
+   * For example, an SQL query might state `SELECT a + b AS combined, c`. The equivalent
+   * input to this method would be:
+   * @example
+   * new Map([["combined", "a + b"], ["c", "c"]])
+   *
+   * Columns will always be returned in the order given, even if that order is different than
+   * the order used when adding the data.
+   *
+   * Note that you can pass in a `Record<string, string>` (e.g. an object literal). This method
+   * uses `Object.entries` which should preserve the insertion order of the object. However,
+   * object insertion order is easy to get wrong and `Map` is more foolproof.
+   */
+  select(
+    columns: string[] | Map<string, string> | Record<string, string>,
+  ): QueryType {
+    let columnTuples: [string, string][];
+    if (Array.isArray(columns)) {
+      columnTuples = columns.map((c) => [c, c]);
+    } else if (columns instanceof Map) {
+      columnTuples = Array.from(columns.entries());
+    } else {
+      columnTuples = Object.entries(columns);
+    }
+    this.inner.select(columnTuples);
+    return this as unknown as QueryType;
   }

-  /** Set the filter predicate, only returns the results that satisfy the filter.
+  /**
+   * Set the maximum number of results to return.
+   *
+   * By default, a plain search has no limit. If this method is not
+   * called then every valid row from the table will be returned.
+   */
+  limit(limit: number): QueryType {
+    this.inner.limit(limit);
+    return this as unknown as QueryType;
+  }
+
+  protected nativeExecute(): Promise<NativeBatchIterator> {
+    return this.inner.execute();
+  }
+
+  /**
+   * Execute the query and return the results as an @see {@link AsyncIterator}
+   * of @see {@link RecordBatch}.
+   *
+   * By default, LanceDb will use many threads to calculate results and, when
+   * the result set is large, multiple batches will be processed at one time.
+   * This readahead is limited however and backpressure will be applied if this
+   * stream is consumed slowly (this constrains the maximum memory used by a
+   * single query)
    *
    */
-  filter(predicate: string): Query {
-    this.inner.filter(predicate);
-    return this;
+  protected execute(): RecordBatchIterator {
+    return new RecordBatchIterator(this.nativeExecute());
   }

-  /**
-   * Select the columns to return. If not set, all columns are returned.
-   */
-  select(columns: string[]): Query {
-    this.inner.select(columns);
-    return this;
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
+    const promise = this.nativeExecute();
+    return new RecordBatchIterator(promise);
   }

-  /**
-   * Set the limit of rows to return.
-   */
-  limit(limit: number): Query {
-    this.inner.limit(limit);
-    return this;
-  }
-
-  prefilter(prefilter: boolean): Query {
-    this.inner.prefilter(prefilter);
-    return this;
-  }
-
-  /**
-   * Set the query vector.
-   */
-  nearestTo(vector: number[]): Query {
-    this.inner.nearestTo(Float32Array.from(vector));
-    return this;
-  }
-
-  /**
-   * Set the number of IVF partitions to use for the query.
-   */
-  nprobes(nprobes: number): Query {
-    this.inner.nprobes(nprobes);
-    return this;
-  }
-
-  /**
-   * Set the refine factor for the query.
-   */
-  refineFactor(refineFactor: number): Query {
-    this.inner.refineFactor(refineFactor);
-    return this;
-  }
-
-  /**
-   * Execute the query and return the results as an AsyncIterator.
-   */
-  async executeStream(): Promise<RecordBatchIterator> {
-    const inner = await this.inner.executeStream();
-    return new RecordBatchIterator(inner);
-  }
-
-  /** Collect the results as an Arrow Table. */
+  /** Collect the results as an Arrow @see {@link ArrowTable}. */
   async toArrow(): Promise<ArrowTable> {
     const batches = [];
     for await (const batch of this) {
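The new `QueryBase` surface in one hedged sketch (column names illustrative), combining the SQL-style filter, a dynamic column, and a limit:

    const rows = await tbl
      .query()
      .where("y > 0 AND y < 100")
      .select(new Map([["combined", "a + b"], ["c", "c"]])) // dynamic column
      .limit(50)
      .toArray();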
@@ -137,18 +165,211 @@ export class Query implements AsyncIterable<RecordBatch> {
     return new ArrowTable(batches);
   }

-  /** Returns a JSON Array of All results.
-   *
-   */
+  /** Collect the results as an array of objects. */
   async toArray(): Promise<unknown[]> {
     const tbl = await this.toArrow();
     // eslint-disable-next-line @typescript-eslint/no-unsafe-return
     return tbl.toArray();
   }
+}

-  // eslint-disable-next-line @typescript-eslint/no-explicit-any
-  [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
-    const promise = this.inner.executeStream();
-    return new RecordBatchIterator(undefined, promise);
+/**
+ * An interface for a query that can be executed
+ *
+ * Supported by all query types
+ */
+export interface ExecutableQuery {}
+
+/**
+ * A builder used to construct a vector search
+ *
+ * This builder can be reused to execute the query many times.
+ */
+export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
+  constructor(inner: NativeVectorQuery) {
+    super(inner);
+  }
+
+  /**
+   * Set the number of partitions to search (probe)
+   *
+   * This argument is only used when the vector column has an IVF PQ index.
+   * If there is no index then this value is ignored.
+   *
+   * The IVF stage of IVF PQ divides the input into partitions (clusters) of
+   * related values.
+   *
+   * The partition whose centroids are closest to the query vector will be
+   * exhaustiely searched to find matches. This parameter controls how many
+   * partitions should be searched.
+   *
+   * Increasing this value will increase the recall of your query but will
+   * also increase the latency of your query. The default value is 20. This
+   * default is good for many cases but the best value to use will depend on
+   * your data and the recall that you need to achieve.
+   *
+   * For best results we recommend tuning this parameter with a benchmark against
+   * your actual data to find the smallest possible value that will still give
+   * you the desired recall.
+   */
+  nprobes(nprobes: number): VectorQuery {
+    this.inner.nprobes(nprobes);
+    return this;
+  }
+
+  /**
+   * Set the vector column to query
+   *
+   * This controls which column is compared to the query vector supplied in
+   * the call to @see {@link Query#nearestTo}
+   *
+   * This parameter must be specified if the table has more than one column
+   * whose data type is a fixed-size-list of floats.
+   */
+  column(column: string): VectorQuery {
+    this.inner.column(column);
+    return this;
+  }
+
+  /**
+   * Set the distance metric to use
+   *
+   * When performing a vector search we try and find the "nearest" vectors according
+   * to some kind of distance metric. This parameter controls which distance metric to
+   * use. See @see {@link IvfPqOptions.distanceType} for more details on the different
+   * distance metrics available.
+   *
+   * Note: if there is a vector index then the distance type used MUST match the distance
+   * type used to train the vector index. If this is not done then the results will be
+   * invalid.
+   *
+   * By default "l2" is used.
+   */
+  distanceType(distanceType: string): VectorQuery {
+    this.inner.distanceType(distanceType);
+    return this;
+  }
+
+  /**
+   * A multiplier to control how many additional rows are taken during the refine step
+   *
+   * This argument is only used when the vector column has an IVF PQ index.
+   * If there is no index then this value is ignored.
+   *
+   * An IVF PQ index stores compressed (quantized) values. They query vector is compared
+   * against these values and, since they are compressed, the comparison is inaccurate.
+   *
+   * This parameter can be used to refine the results. It can improve both improve recall
+   * and correct the ordering of the nearest results.
+   *
+   * To refine results LanceDb will first perform an ANN search to find the nearest
+   * `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
+   * `limit` is the default (10) then the first 30 results will be selected. LanceDb
+   * then fetches the full, uncompressed, values for these 30 results. The results are
+   * then reordered by the true distance and only the nearest 10 are kept.
+   *
+   * Note: there is a difference between calling this method with a value of 1 and never
+   * calling this method at all. Calling this method with any value will have an impact
+   * on your search latency. When you call this method with a `refine_factor` of 1 then
+   * LanceDb still needs to fetch the full, uncompressed, values so that it can potentially
+   * reorder the results.
+   *
+   * Note: if this method is NOT called then the distances returned in the _distance column
+   * will be approximate distances based on the comparison of the quantized query vector
+   * and the quantized result vectors. This can be considerably different than the true
+   * distance between the query vector and the actual uncompressed vector.
+   */
+  refineFactor(refineFactor: number): VectorQuery {
+    this.inner.refineFactor(refineFactor);
+    return this;
+  }
+
+  /**
+   * If this is called then filtering will happen after the vector search instead of
+   * before.
+   *
+   * By default filtering will be performed before the vector search. This is how
+   * filtering is typically understood to work. This prefilter step does add some
+   * additional latency. Creating a scalar index on the filter column(s) can
+   * often improve this latency. However, sometimes a filter is too complex or scalar
+   * indices cannot be applied to the column. In these cases postfiltering can be
+   * used instead of prefiltering to improve latency.
+   *
+   * Post filtering applies the filter to the results of the vector search. This means
+   * we only run the filter on a much smaller set of data. However, it can cause the
+   * query to return fewer than `limit` results (or even no results) if none of the nearest
+   * results match the filter.
+   *
+   * Post filtering happens during the "refine stage" (described in more detail in
+   * @see {@link VectorQuery#refineFactor}). This means that setting a higher refine
+   * factor can often help restore some of the results lost by post filtering.
+   */
+  postfilter(): VectorQuery {
+    this.inner.postfilter();
+    return this;
+  }
+
+  /**
+   * If this is called then any vector index is skipped
+   *
+   * An exhaustive (flat) search will be performed. The query vector will
+   * be compared to every vector in the table. At high scales this can be
+   * expensive. However, this is often still useful. For example, skipping
+   * the vector index can give you ground truth results which you can use to
+   * calculate your recall to select an appropriate value for nprobes.
|
||||||
|
*/
|
||||||
|
bypassVectorIndex(): VectorQuery {
|
||||||
|
this.inner.bypassVectorIndex();
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** A builder for LanceDB queries. */
|
||||||
|
export class Query extends QueryBase<NativeQuery, Query> {
|
||||||
|
constructor(tbl: NativeTable) {
|
||||||
|
super(tbl.query());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find the nearest vectors to the given query vector.
|
||||||
|
*
|
||||||
|
* This converts the query from a plain query to a vector query.
|
||||||
|
*
|
||||||
|
* This method will attempt to convert the input to the query vector
|
||||||
|
* expected by the embedding model. If the input cannot be converted
|
||||||
|
* then an error will be thrown.
|
||||||
|
*
|
||||||
|
* By default, there is no embedding model, and the input should be
|
||||||
|
* an array-like object of numbers (something that can be used as input
|
||||||
|
* to Float32Array.from)
|
||||||
|
*
|
||||||
|
* If there is only one vector column (a column whose data type is a
|
||||||
|
* fixed size list of floats) then the column does not need to be specified.
|
||||||
|
* If there is more than one vector column you must use
|
||||||
|
* @see {@link VectorQuery#column} to specify which column you would like
|
||||||
|
* to compare with.
|
||||||
|
*
|
||||||
|
* If no index has been created on the vector column then a vector query
|
||||||
|
* will perform a distance comparison between the query vector and every
|
||||||
|
* vector in the database and then sort the results. This is sometimes
|
||||||
|
* called a "flat search"
|
||||||
|
*
|
||||||
|
* For small databases, with a few hundred thousand vectors or less, this can
|
||||||
|
* be reasonably fast. In larger databases you should create a vector index
|
||||||
|
* on the column. If there is a vector index then an "approximate" nearest
|
||||||
|
* neighbor search (frequently called an ANN search) will be performed. This
|
||||||
|
* search is much faster, but the results will be approximate.
|
||||||
|
*
|
||||||
|
* The query can be further parameterized using the returned builder. There
|
||||||
|
* are various ANN search parameters that will let you fine tune your recall
|
||||||
|
* accuracy vs search latency.
|
||||||
|
*
|
||||||
|
* Vector searches always have a `limit`. If `limit` has not been called then
|
||||||
|
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||||
|
*/
|
||||||
|
nearestTo(vector: unknown): VectorQuery {
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
|
||||||
|
return new VectorQuery(vectorQuery);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
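For reference, here is how the builder methods above compose end to end. This is a minimal usage sketch, not part of the diff: the table handle, vector values, and the assumption that an IVF PQ index already exists on a `vector` column are all illustrative.

```typescript
// Probe more partitions and re-rank the top `limit * refineFactor`
// candidates by true distance, trading a little latency for recall.
for await (const batch of table
  .query()
  .nearestTo([0.1, 0.2, 0.3, 0.4])
  .column("vector") // only needed when several vector columns exist
  .distanceType("cosine") // must match the type the index was trained with
  .nprobes(50)
  .refineFactor(5)
  .limit(10)) {
  console.log(batch);
}
```

To obtain ground-truth results for measuring recall, the same query can be run with `.bypassVectorIndex()`, which forces an exhaustive flat search.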
@@ -481,6 +481,13 @@ function sanitizeField(fieldLike: unknown): Field {
  return new Field(name, type, nullable, metadata);
}

/**
 * Convert something schemaLike into a Schema instance
 *
 * This method is often needed even when the caller is using a Schema
 * instance because they might be using a different instance of apache-arrow
 * than lancedb is using.
 */
export function sanitizeSchema(schemaLike: unknown): Schema {
  if (schemaLike instanceof Schema) {
    return schemaLike;
@@ -19,7 +19,7 @@ import {
  IndexConfig,
  Table as _NativeTable,
} from "./native";
import { Query } from "./query";
import { Query, VectorQuery } from "./query";
import { IndexOptions } from "./indices";
import { Data, fromDataToBuffer } from "./arrow";
@@ -28,7 +28,8 @@ export { IndexConfig } from "./native";
 * Options for adding data to a table.
 */
export interface AddDataOptions {
  /** If "append" (the default) then the new data will be added to the table
  /**
   * If "append" (the default) then the new data will be added to the table
   *
   * If "overwrite" then the new data will replace the existing data in the table.
   */
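As a quick illustration of the two modes (a sketch; the row shape is hypothetical):

```typescript
// Append new rows (the default mode).
await table.add([{ id: 1, vector: [0.1, 0.2] }]);

// Replace the table's existing data entirely.
await table.add([{ id: 2, vector: [0.3, 0.4] }], { mode: "overwrite" });
```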
@@ -74,7 +75,8 @@ export class Table {
    return this.inner.isOpen();
  }

  /** Close the table, releasing any underlying resources.
  /**
   * Close the table, releasing any underlying resources.
   *
   * It is safe to call this method multiple times.
   *
@@ -98,9 +100,7 @@ export class Table {

  /**
   * Insert records into this Table.
   *
   * @param {Data} data Records to be inserted into the Table
   * @return The number of rows added to the table
   */
  async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
    const mode = options?.mode ?? "append";
@@ -124,15 +124,15 @@ export class Table {
   * you are updating many rows (with different ids) then you will get
   * better performance with a single [`merge_insert`] call instead of
   * repeatedly calling this method.
   *
   * @param updates the columns to update
   * @param {Map<string, string> | Record<string, string>} updates - the
   * columns to update
   *
   * Keys in the map should specify the name of the column to update.
   * Values in the map provide the new value of the column. These can
   * be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions
   * based on the row being updated (e.g. "my_col + 1")
   *
   * @param options additional options to control the update behavior
   * @param {Partial<UpdateOptions>} options - additional options to control
   * the update behavior
   */
  async update(
    updates: Map<string, string> | Record<string, string>,
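A short sketch of the updated signature (column names are hypothetical; note the values are SQL expressions, not plain values):

```typescript
// Set a literal value and an expression-based value on every row.
await table.update({ name: "'foo'", clicks: "clicks + 1" });

// A Map works too, e.g. when column names are computed at runtime.
await table.update(new Map([["price", "price * 2"]]));
```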
@@ -158,37 +158,28 @@ export class Table {
    await this.inner.delete(predicate);
  }

  /** Create an index to speed up queries.
  /**
   * Create an index to speed up queries.
   *
   * Indices can be created on vector columns or scalar columns.
   * Indices on vector columns will speed up vector searches.
   * Indices on scalar columns will speed up filtering (in both
   * vector and non-vector searches)
   *
   * @example
   *
   * If the column has a vector (fixed size list) data type then
   * an IvfPq vector index will be created.
   * // If the column has a vector (fixed size list) data type then
   * // an IvfPq vector index will be created.
   *
   * ```typescript
   * const table = await conn.openTable("my_table");
   * await table.createIndex(["vector"]);
   * ```
   * @example
   *
   * For advanced control over vector index creation you can specify
   * the index type and options.
   * // For advanced control over vector index creation you can specify
   * // the index type and options.
   * ```typescript
   * const table = await conn.openTable("my_table");
   * await table.createIndex(["vector"], I)
   * .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
   * .build();
   * ```
   * @example
   *
   * Or create a Scalar index
   * // Or create a Scalar index
   *
   * ```typescript
   * await table.createIndex("my_float_col").build();
   * ```
   */
  async createIndex(column: string, options?: Partial<IndexOptions>) {
    // Bit of a hack to get around the fact that TS has no package-scope.
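Against the signature shown here (`createIndex(column: string, options?: Partial<IndexOptions>)`), the simplest calls look like the sketch below; the default index choices described in the doc comment above are assumed, and the column names are illustrative.

```typescript
// Vector column (fixed size list of floats): an IvfPq index is created.
await table.createIndex("vector");

// Scalar column: a scalar index to speed up filtering.
await table.createIndex("my_float_col");
```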
@@ -198,69 +189,74 @@ export class Table {
  }

  /**
   * Create a generic {@link Query} Builder.
   * Create a {@link Query} Builder.
   *
   * Queries allow you to search your existing data. By default the query will
   * return all the data in the table in no particular order. The builder
   * returned by this method can be used to control the query using filtering,
   * vector similarity, sorting, and more.
   *
   * Note: By default, all columns are returned. For best performance, you should
   * only fetch the columns you need. See [`Query::select_with_projection`] for
   * more details.
   *
   * When appropriate, various indices and statistics based pruning will be used to
   * accelerate the query.
   *
   * @example
   *
   * ### Run a SQL-style query
   * ```typescript
   * // SQL-style filtering
   * //
   * // This query will return up to 20 rows whose value in the `id` column
   * // is greater than 1. LanceDb supports a broad set of filtering functions.
   * for await (const batch of table.query()
   * .filter("id > 1").select(["id"]).limit(20)) {
   * console.log(batch);
   * }
   * ```
   * @example
   *
   * ### Run Top-10 vector similarity search
   * ```typescript
   * // Vector Similarity Search
   * //
   * // This example will find the 10 rows whose value in the "vector" column are
   * // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
   * // on the "vector" column then this will perform an ANN search.
   * //
   * // The `refineFactor` and `nprobes` methods are used to control the recall /
   * // latency tradeoff of the search.
   * for await (const batch of table.query()
   * .nearestTo([1, 2, 3])
   * .refineFactor(5).nprobes(10)
   * .limit(10)) {
   * console.log(batch);
   * }
   * ```
   * @example
   *
   * ### Scan the full dataset
   * ```typescript
   * // Scan the full dataset
   * //
   * // This query will return everything in the table in no particular order.
   * for await (const batch of table.query()) {
   * console.log(batch);
   * }
   *
   * ### Return the full dataset as Arrow Table
   * ```typescript
   * let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow();
   * ```
   *
   * @returns {@link Query}
   * @returns {Query} A builder that can be used to parameterize the query
   */
  query(): Query {
    return new Query(this.inner);
  }

  /** Search the table with a given query vector.
  /**
   * Search the table with a given query vector.
   *
   * This is a convenience method for preparing an ANN {@link Query}.
   * This is a convenience method for preparing a vector query and
   * is the same thing as calling `nearestTo` on the builder returned
   * by `query`. @see {@link Query#nearestTo} for more details.
   */
  search(vector: number[], column?: string): Query {
    const q = this.query();
    q.nearestTo(vector);
    if (column !== undefined) {
      q.column(column);
    }
    return q;
  }
  vectorSearch(vector: unknown): VectorQuery {
    return this.query().nearestTo(vector);
  }

  // TODO: Support BatchUDF
  /**
   * Add new columns with defined values.
   *
   * @param newColumnTransforms pairs of column names and the SQL expression to use
   * to calculate the value of the new column. These
   * expressions will be evaluated for each row in the
   * table, and can reference existing columns in the table.
   * @param {AddColumnsSql[]} newColumnTransforms pairs of column names and
   * the SQL expression to use to calculate the value of the new column. These
   * expressions will be evaluated for each row in the table, and can
   * reference existing columns in the table.
   */
  async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> {
    await this.inner.addColumns(newColumnTransforms);
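The renamed convenience method and the explicit builder route are interchangeable, as the doc comment notes:

```typescript
// These two queries are equivalent:
const a = table.vectorSearch([1.0, 2.0, 3.0]).limit(10);
const b = table.query().nearestTo([1.0, 2.0, 3.0]).limit(10);
```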
@@ -268,8 +264,8 @@ export class Table {

  /**
   * Alter the name or nullability of columns.
   *
   * @param columnAlterations One or more alterations to apply to columns.
   * @param {ColumnAlteration[]} columnAlterations One or more alterations to
   * apply to columns.
   */
  async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> {
    await this.inner.alterColumns(columnAlterations);
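A hypothetical call, assuming each `ColumnAlteration` carries a column `path` plus a new name and/or nullability (the exact shape is not shown in this diff, so the field names below are assumptions):

```typescript
await table.alterColumns([
  { path: "id", rename: "user_id" }, // assumed field names
  { path: "score", nullable: true },
]);
```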
@@ -282,16 +278,16 @@ export class Table {
   * underlying storage. In order to remove the data, you must subsequently
   * call ``compact_files`` to rewrite the data without the removed columns and
   * then call ``cleanup_files`` to remove the old files.
   *
   * @param columnNames The names of the columns to drop. These can be nested
   * column references (e.g. "a.b.c") or top-level column
   * names (e.g. "a").
   * @param {string[]} columnNames The names of the columns to drop. These can
   * be nested column references (e.g. "a.b.c") or top-level column names
   * (e.g. "a").
   */
  async dropColumns(columnNames: string[]): Promise<void> {
    await this.inner.dropColumns(columnNames);
  }

  /** Retrieve the version of the table
  /**
   * Retrieve the version of the table
   *
   * LanceDb supports versioning. Every operation that modifies the table increases
   * the version. As long as a version hasn't been deleted you can `[Self::checkout]` that
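For example, dropping a top-level column and a nested field in one call (the column names are illustrative):

```typescript
await table.dropColumns(["embeddings", "metadata.source"]);
```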
@@ -302,7 +298,8 @@ export class Table {
    return await this.inner.version();
  }

  /** Checks out a specific version of the Table
  /**
   * Checks out a specific version of the Table
   *
   * Any read operation on the table will now access the data at the checked out version.
   * As a consequence, calling this method will disable any read consistency interval
@@ -321,7 +318,8 @@ export class Table {
    await this.inner.checkout(version);
  }

  /** Ensures the table is pointing at the latest version
  /**
   * Ensures the table is pointing at the latest version
   *
   * This can be used to manually update a table when the read_consistency_interval is None.
   * It can also be used to undo a `[Self::checkout]` operation
@@ -330,7 +328,8 @@ export class Table {
    await this.inner.checkoutLatest();
  }

  /** Restore the table to the currently checked out version
  /**
   * Restore the table to the currently checked out version
   *
   * This operation will fail if checkout has not been called previously
   *
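Taken together, the versioning methods support a workflow like this sketch (the data and version numbers are illustrative):

```typescript
const v = await table.version(); // e.g. 5
await table.add([{ id: 3, vector: [0.5, 0.6] }]); // modifies the table, bumping the version

await table.checkout(v); // read at the older version (disables any read consistency interval)
await table.restore(); // make the checked-out version the current version again
// ...or call checkoutLatest() to return to the newest version instead.
```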
nodejs/package-lock.json (generated, 120 lines changed)
@@ -26,6 +26,7 @@
    "apache-arrow-old": "npm:apache-arrow@13.0.0",
    "eslint": "^8.57.0",
    "eslint-config-prettier": "^9.1.0",
    "eslint-plugin-jsdoc": "^48.2.1",
    "jest": "^29.7.0",
    "prettier": "^3.1.0",
    "tmp": "^0.2.3",
@@ -755,6 +756,20 @@
      "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==",
      "dev": true
    },
    "node_modules/@es-joy/jsdoccomment": {
      "version": "0.42.0",
      "resolved": "https://registry.npmjs.org/@es-joy/jsdoccomment/-/jsdoccomment-0.42.0.tgz",
      "integrity": "sha512-R1w57YlVA6+YE01wch3GPYn6bCsrOV3YW/5oGGE2tmX6JcL9Nr+b5IikrjMPF+v9CV3ay+obImEdsDhovhJrzw==",
      "dev": true,
      "dependencies": {
        "comment-parser": "1.4.1",
        "esquery": "^1.5.0",
        "jsdoc-type-pratt-parser": "~4.0.0"
      },
      "engines": {
        "node": ">=16"
      }
    },
    "node_modules/@eslint-community/eslint-utils": {
      "version": "4.4.0",
      "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz",
@@ -1948,6 +1963,15 @@
      "integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==",
      "dev": true
    },
    "node_modules/are-docs-informative": {
      "version": "0.0.2",
      "resolved": "https://registry.npmjs.org/are-docs-informative/-/are-docs-informative-0.0.2.tgz",
      "integrity": "sha512-ixiS0nLNNG5jNQzgZJNoUpBKdo9yTYZMGJ+QgT2jmjR7G7+QHRCc4v6LQ3NgE7EBJq+o0ams3waJwkrlBom8Ig==",
      "dev": true,
      "engines": {
        "node": ">=14"
      }
    },
    "node_modules/argparse": {
      "version": "1.0.10",
      "resolved": "https://registry.npmjs.org/argparse/-/argparse-1.0.10.tgz",
@@ -2189,6 +2213,18 @@
      "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
      "dev": true
    },
    "node_modules/builtin-modules": {
      "version": "3.3.0",
      "resolved": "https://registry.npmjs.org/builtin-modules/-/builtin-modules-3.3.0.tgz",
      "integrity": "sha512-zhaCDicdLuWN5UbN5IMnFqNMhNfo919sH85y2/ea+5Yg9TsTkeZxpL+JLbp6cgYFS4sRLp3YV4S6yDuqVWHYOw==",
      "dev": true,
      "engines": {
        "node": ">=6"
      },
      "funding": {
        "url": "https://github.com/sponsors/sindresorhus"
      }
    },
    "node_modules/camelcase": {
      "version": "5.3.1",
      "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.3.1.tgz",
@@ -2373,6 +2409,15 @@
        "node": ">=12.17"
      }
    },
    "node_modules/comment-parser": {
      "version": "1.4.1",
      "resolved": "https://registry.npmjs.org/comment-parser/-/comment-parser-1.4.1.tgz",
      "integrity": "sha512-buhp5kePrmda3vhc5B9t7pUQXAb2Tnd0qgpkIhPhkHXxJpiPJ11H0ZEU0oBpJ2QztSbzG/ZxMj/CHsYJqRHmyg==",
      "dev": true,
      "engines": {
        "node": ">= 12.0.0"
      }
    },
    "node_modules/concat-map": {
      "version": "0.0.1",
      "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
@@ -2660,6 +2705,29 @@
        "eslint": ">=7.0.0"
      }
    },
    "node_modules/eslint-plugin-jsdoc": {
      "version": "48.2.1",
      "resolved": "https://registry.npmjs.org/eslint-plugin-jsdoc/-/eslint-plugin-jsdoc-48.2.1.tgz",
      "integrity": "sha512-iUvbcyDZSO/9xSuRv2HQBw++8VkV/pt3UWtX9cpPH0l7GKPq78QC/6+PmyQHHvNZaTjAce6QVciEbnc6J/zH5g==",
      "dev": true,
      "dependencies": {
        "@es-joy/jsdoccomment": "~0.42.0",
        "are-docs-informative": "^0.0.2",
        "comment-parser": "1.4.1",
        "debug": "^4.3.4",
        "escape-string-regexp": "^4.0.0",
        "esquery": "^1.5.0",
        "is-builtin-module": "^3.2.1",
        "semver": "^7.6.0",
        "spdx-expression-parse": "^4.0.0"
      },
      "engines": {
        "node": ">=18"
      },
      "peerDependencies": {
        "eslint": "^7.0.0 || ^8.0.0 || ^9.0.0"
      }
    },
    "node_modules/eslint-scope": {
      "version": "7.2.2",
      "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz",
@@ -3299,6 +3367,21 @@
      "integrity": "sha512-NcdALwpXkTm5Zvvbk7owOUSvVvBKDgKP5/ewfXEznmQFfs4ZRmanOeKBTjRVjka3QFoN6XJ+9F3USqfHqTaU5w==",
      "optional": true
    },
    "node_modules/is-builtin-module": {
      "version": "3.2.1",
      "resolved": "https://registry.npmjs.org/is-builtin-module/-/is-builtin-module-3.2.1.tgz",
      "integrity": "sha512-BSLE3HnV2syZ0FK0iMA/yUGplUeMmNz4AW5fnTunbCIqZi4vG3WjJT9FHMy5D69xmAYBHXQhJdALdpwVxV501A==",
      "dev": true,
      "dependencies": {
        "builtin-modules": "^3.3.0"
      },
      "engines": {
        "node": ">=6"
      },
      "funding": {
        "url": "https://github.com/sponsors/sindresorhus"
      }
    },
    "node_modules/is-core-module": {
      "version": "2.13.1",
      "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.1.tgz",
@@ -4172,6 +4255,15 @@
        "js-yaml": "bin/js-yaml.js"
      }
    },
    "node_modules/jsdoc-type-pratt-parser": {
      "version": "4.0.0",
      "resolved": "https://registry.npmjs.org/jsdoc-type-pratt-parser/-/jsdoc-type-pratt-parser-4.0.0.tgz",
      "integrity": "sha512-YtOli5Cmzy3q4dP26GraSOeAhqecewG04hoO8DY56CH4KJ9Fvv5qKWUCCo3HZob7esJQHCv6/+bnTy72xZZaVQ==",
      "dev": true,
      "engines": {
        "node": ">=12.0.0"
      }
    },
    "node_modules/jsesc": {
      "version": "2.5.2",
      "resolved": "https://registry.npmjs.org/jsesc/-/jsesc-2.5.2.tgz",
@@ -5018,9 +5110,9 @@
      }
    },
    "node_modules/semver": {
      "version": "7.5.4",
      "version": "7.6.0",
      "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
      "resolved": "https://registry.npmjs.org/semver/-/semver-7.6.0.tgz",
      "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
      "integrity": "sha512-EnwXhrlwXMk9gKu5/flx5sv/an57AkRplG3hTK68W7FRDN+k+OWBj65M7719OkA82XLBxrcX0KSHj+X5COhOVg==",
      "dev": true,
      "dependencies": {
        "lru-cache": "^6.0.0"
@@ -5105,6 +5197,28 @@
        "source-map": "^0.6.0"
      }
    },
    "node_modules/spdx-exceptions": {
      "version": "2.5.0",
      "resolved": "https://registry.npmjs.org/spdx-exceptions/-/spdx-exceptions-2.5.0.tgz",
      "integrity": "sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==",
      "dev": true
    },
    "node_modules/spdx-expression-parse": {
      "version": "4.0.0",
      "resolved": "https://registry.npmjs.org/spdx-expression-parse/-/spdx-expression-parse-4.0.0.tgz",
      "integrity": "sha512-Clya5JIij/7C6bRR22+tnGXbc4VKlibKSVj2iHvVeX5iMW7s1SIQlqu699JkODJJIhh/pUu8L0/VLh8xflD+LQ==",
      "dev": true,
      "dependencies": {
        "spdx-exceptions": "^2.1.0",
        "spdx-license-ids": "^3.0.0"
      }
    },
    "node_modules/spdx-license-ids": {
      "version": "3.0.17",
      "resolved": "https://registry.npmjs.org/spdx-license-ids/-/spdx-license-ids-3.0.17.tgz",
      "integrity": "sha512-sh8PWc/ftMqAAdFiBu6Fy6JUOYjqDJBJvIhpfDMyHrr0Rbp5liZqd4TjtQ/RgfLjKFZb+LMx5hpml5qOWy0qvg==",
      "dev": true
    },
    "node_modules/sprintf-js": {
      "version": "1.0.3",
      "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz",
@@ -25,6 +25,7 @@
    "apache-arrow-old": "npm:apache-arrow@13.0.0",
    "eslint": "^8.57.0",
    "eslint-config-prettier": "^9.1.0",
    "eslint-plugin-jsdoc": "^48.2.1",
    "jest": "^29.7.0",
    "prettier": "^3.1.0",
    "tmp": "^0.2.3",
@@ -124,7 +124,7 @@ impl Connection {
        let mode = Self::parse_create_mode_str(&mode)?;
        let tbl = self
            .get_inner()?
            .create_table(&name, Box::new(batches))
            .create_table(&name, batches)
            .mode(mode)
            .execute()
            .await
@@ -17,9 +17,10 @@ use std::sync::Mutex;
use lancedb::index::scalar::BTreeIndexBuilder;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index as LanceDbIndex;
use lancedb::DistanceType;
use napi_derive::napi;

use crate::util::parse_distance_type;

#[napi]
pub struct Index {
    inner: Mutex<Option<LanceDbIndex>>,
@@ -49,15 +50,7 @@ impl Index {
    ) -> napi::Result<Self> {
        let mut ivf_pq_builder = IvfPqIndexBuilder::default();
        if let Some(distance_type) = distance_type {
            let distance_type = match distance_type.as_str() {
                "l2" => Ok(DistanceType::L2),
                "cosine" => Ok(DistanceType::Cosine),
                "dot" => Ok(DistanceType::Dot),
                _ => Err(napi::Error::from_reason(format!(
                    "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
                    distance_type
                ))),
            }?;
            let distance_type = parse_distance_type(distance_type)?;
            ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
        }
        if let Some(num_partitions) = num_partitions {
@@ -21,6 +21,7 @@ mod index;
mod iterator;
mod query;
mod table;
mod util;

#[napi(object)]
#[derive(Debug)]
@@ -12,36 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use lancedb::query::Query as LanceDBQuery;
use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase;
use lancedb::query::Select;
use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*;
use napi_derive::napi;

use crate::error::NapiErrorExt;
use crate::iterator::RecordBatchIterator;
use crate::util::parse_distance_type;

#[napi]
pub struct Query {
    inner: LanceDBQuery,
    inner: LanceDbQuery,
}

#[napi]
impl Query {
    pub fn new(query: LanceDBQuery) -> Self {
    pub fn new(query: LanceDbQuery) -> Self {
        Self { inner: query }
    }

    // We cannot call this r#where because NAPI gets confused by the r#
    #[napi]
    pub fn column(&mut self, column: String) {
        self.inner = self.inner.clone().column(&column);
    }
    pub fn only_if(&mut self, predicate: String) {
        self.inner = self.inner.clone().only_if(predicate);
    }

    #[napi]
    pub fn filter(&mut self, filter: String) {
        self.inner = self.inner.clone().filter(filter);
    }
    pub fn select(&mut self, columns: Vec<(String, String)>) {
        self.inner = self.inner.clone().select(Select::dynamic(&columns));
    }

    #[napi]
    pub fn select(&mut self, columns: Vec<String>) {
        self.inner = self.inner.clone().select(&columns);
    }

    #[napi]
@@ -50,13 +52,46 @@ impl Query {
    }

    #[napi]
    pub fn prefilter(&mut self, prefilter: bool) {
        self.inner = self.inner.clone().prefilter(prefilter);
    }
    pub fn nearest_to(&mut self, vector: Float32Array) -> Result<VectorQuery> {
        let inner = self
            .inner
            .clone()
            .nearest_to(vector.as_ref())
            .default_error()?;
        Ok(VectorQuery { inner })
    }

    #[napi]
    pub fn nearest_to(&mut self, vector: Float32Array) {
        self.inner = self.inner.clone().nearest_to(&vector);
    }
    pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
        let inner_stream = self.inner.execute().await.map_err(|e| {
            napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
        })?;
        Ok(RecordBatchIterator::new(inner_stream))
    }
}

#[napi]
pub struct VectorQuery {
    inner: LanceDbVectorQuery,
}

#[napi]
impl VectorQuery {
    #[napi]
    pub fn column(&mut self, column: String) {
        self.inner = self.inner.clone().column(&column);
    }

    #[napi]
    pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
        let distance_type = parse_distance_type(distance_type)?;
        self.inner = self.inner.clone().distance_type(distance_type);
        Ok(())
    }

    #[napi]
    pub fn postfilter(&mut self) {
        self.inner = self.inner.clone().postfilter();
    }

    #[napi]
@@ -70,8 +105,28 @@ impl Query {
    }

    #[napi]
    pub async fn execute_stream(&self) -> napi::Result<RecordBatchIterator> {
    pub fn bypass_vector_index(&mut self) {
        self.inner = self.inner.clone().bypass_vector_index()
    }

    #[napi]
    pub fn only_if(&mut self, predicate: String) {
        self.inner = self.inner.clone().only_if(predicate);
    }

    #[napi]
    pub fn select(&mut self, columns: Vec<(String, String)>) {
        self.inner = self.inner.clone().select(Select::dynamic(&columns));
    }

    #[napi]
    pub fn limit(&mut self, limit: u32) {
        self.inner = self.inner.clone().limit(limit as usize);
    }

    #[napi]
    pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
        let inner_stream = self.inner.execute().await.map_err(|e| {
            napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
        })?;
        Ok(RecordBatchIterator::new(inner_stream))
@@ -23,7 +23,7 @@ use napi_derive::napi;

use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::query::Query;
use crate::query::{Query, VectorQuery};

#[napi]
pub struct Table {
@@ -89,7 +89,7 @@ impl Table {
    pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> {
        let batches = ipc_file_to_batches(buf.to_vec())
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
        let mut op = self.inner_ref()?.add(Box::new(batches));
        let mut op = self.inner_ref()?.add(batches);

        op = if mode == "append" {
            op.mode(AddDataMode::Append)
@@ -171,6 +171,11 @@ impl Table {
        Ok(Query::new(self.inner_ref()?.query()))
    }

    #[napi]
    pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
        self.query()?.nearest_to(vector)
    }

    #[napi]
    pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
        let transforms = transforms
nodejs/src/util.rs (new file, 13 lines)
@@ -0,0 +1,13 @@
use lancedb::DistanceType;

pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
    match distance_type.as_ref().to_lowercase().as_str() {
        "l2" => Ok(DistanceType::L2),
        "cosine" => Ok(DistanceType::Cosine),
        "dot" => Ok(DistanceType::Dot),
        _ => Err(napi::Error::from_reason(format!(
            "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
            distance_type.as_ref()
        ))),
    }
}
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.4
current_version = 0.6.5
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True
@@ -22,6 +22,9 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }

# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
pin-project = "1.1.5"
futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] }

[build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [
@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.6.4"
version = "0.6.5"
dependencies = [
    "deprecation",
    "pylance==0.10.4",
    "pylance==0.10.5",
    "ratelimiter~=1.0",
    "retry>=0.9.2",
    "tqdm>=4.27.0",
@@ -94,13 +94,11 @@ lancedb = "lancedb.cli.cli:cli"
requires = ["maturin>=1.4"]
build-backend = "maturin"

[tool.ruff.lint]
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]

[tool.pytest.ini_options]
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
    "asyncio",
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import pyarrow as pa

@@ -40,6 +40,8 @@ class Table:
    async def checkout_latest(self): ...
    async def restore(self): ...
    async def list_indices(self) -> List[IndexConfig]: ...
    def query(self) -> Query: ...
    def vector_search(self) -> VectorQuery: ...

class IndexConfig:
    index_type: str
@@ -52,3 +54,27 @@ async def connect(
    host_override: Optional[str],
    read_consistency_interval: Optional[float],
) -> Connection: ...

class RecordBatchStream:
    def schema(self) -> pa.Schema: ...
    async def next(self) -> Optional[pa.RecordBatch]: ...

class Query:
    def where(self, filter: str): ...
    def select(self, columns: Tuple[str, str]): ...
    def limit(self, limit: int): ...
    def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
    async def execute(self) -> RecordBatchStream: ...

class VectorQuery:
    async def execute(self) -> RecordBatchStream: ...
    def where(self, filter: str): ...
    def select(self, columns: List[str]): ...
    def select_with_projection(self, columns: Tuple[str, str]): ...
    def limit(self, limit: int): ...
    def column(self, column: str): ...
    def distance_type(self, distance_type: str): ...
    def postfilter(self): ...
    def refine_factor(self, refine_factor: int): ...
    def nprobes(self, nprobes: int): ...
    def bypass_vector_index(self): ...
python/python/lancedb/arrow.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from typing import List

import pyarrow as pa

from ._lancedb import RecordBatchStream


class AsyncRecordBatchReader:
    """
    An async iterator over a stream of RecordBatches.

    Also allows access to the schema of the stream
    """

    def __init__(self, inner: RecordBatchStream):
        self.inner_ = inner

    @property
    def schema(self) -> pa.Schema:
        """
        Get the schema of the batches produced by the stream

        Accessing the schema does not consume any data from the stream
        """
        return self.inner_.schema()

    async def read_all(self) -> List[pa.RecordBatch]:
        """
        Read all the record batches from the stream

        This consumes the entire stream and returns a list of record batches

        If there are a lot of results this may consume a lot of memory
        """
        return [batch async for batch in self]

    def __aiter__(self):
        return self

    async def __anext__(self) -> pa.RecordBatch:
        next = await self.inner_.next()
        if next is None:
            raise StopAsyncIteration
        return next
@@ -28,7 +28,9 @@ except ImportError:
from .table import LanceTable


def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
def create_index(
    index_path: str, text_fields: List[str], ordering_fields: List[str] = None
) -> tantivy.Index:
    """
    Create a new Index (not populated)

@@ -38,12 +40,16 @@ def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
        Path to the index directory
    text_fields : List[str]
        List of text fields to index
    ordering_fields: List[str]
        List of unsigned type fields to order by at search time

    Returns
    -------
    index : tantivy.Index
        The index object (not yet populated)
    """
    if ordering_fields is None:
        ordering_fields = []
    # Declaring our schema.
    schema_builder = tantivy.SchemaBuilder()
    # special field that we'll populate with row_id
@@ -51,6 +57,9 @@ def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
    # data fields
    for name in text_fields:
        schema_builder.add_text_field(name, stored=True)
    if ordering_fields:
        for name in ordering_fields:
            schema_builder.add_unsigned_field(name, fast=True)
    schema = schema_builder.build()
    os.makedirs(index_path, exist_ok=True)
    index = tantivy.Index(schema, path=index_path)
@@ -62,6 +71,7 @@ def populate_index(
    table: LanceTable,
    fields: List[str],
    writer_heap_size: int = 1024 * 1024 * 1024,
    ordering_fields: List[str] = None,
) -> int:
    """
    Populate an index with data from a LanceTable
@@ -82,8 +92,11 @@ def populate_index(
    int
        The number of rows indexed
    """
    if ordering_fields is None:
        ordering_fields = []
    # first check the fields exist and are string or large string type
    nested = []

    for name in fields:
        try:
            f = table.schema.field(name)  # raises KeyError if not found
@@ -104,7 +117,7 @@ def populate_index(
    if len(nested) > 0:
        max_nested_level = max([len(name.split(".")) for name in nested])

    for b in dataset.to_batches(columns=fields):
    for b in dataset.to_batches(columns=fields + ordering_fields):
        if max_nested_level > 0:
            b = pa.Table.from_batches([b])
            for _ in range(max_nested_level - 1):
@@ -115,6 +128,10 @@ def populate_index(
                value = b[name][i].as_py()
                if value is not None:
                    doc.add_text(name, value)
            for name in ordering_fields:
                value = b[name][i].as_py()
                if value is not None:
                    doc.add_unsigned(name, value)
            if not doc.is_empty:
                doc.add_integer("doc_id", row_id)
                writer.add_document(doc)
@@ -149,7 +166,7 @@ def resolve_path(schema, field_name: str) -> pa.Field:


def search_index(
    index: tantivy.Index, query: str, limit: int = 10
    index: tantivy.Index, query: str, limit: int = 10, ordering_field=None
) -> Tuple[Tuple[int], Tuple[float]]:
    """
    Search an index for a query

@@ -172,7 +189,10 @@ def search_index(
    searcher = index.searcher()
    query = index.parse_query(query)
    # get top results
    results = searcher.search(query, limit)
    if ordering_field:
        results = searcher.search(query, limit, order_by_field=ordering_field)
    else:
        results = searcher.search(query, limit)
    if results.count == 0:
        return tuple(), tuple()
    return tuple(
@@ -16,7 +16,16 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import deprecation
 import numpy as np
@@ -24,6 +33,7 @@ import pyarrow as pa
 import pydantic

 from . import __version__
+from .arrow import AsyncRecordBatchReader
 from .common import VEC
 from .rerankers.base import Reranker
 from .rerankers.linear_combination import LinearCombinationReranker
@@ -33,6 +43,8 @@ if TYPE_CHECKING:
     import PIL
     import polars as pl

+    from ._lancedb import Query as LanceQuery
+    from ._lancedb import VectorQuery as LanceVectorQuery
     from .pydantic import LanceModel
     from .table import Table

@@ -117,6 +129,7 @@ class LanceQueryBuilder(ABC):
         query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
         query_type: str,
         vector_column_name: str,
+        ordering_field_name: str = None,
     ) -> LanceQueryBuilder:
         """
         Create a query builder based on the given query and query type.
@@ -141,6 +154,9 @@ class LanceQueryBuilder(ABC):
             # hybrid fts and vector query
             return LanceHybridQueryBuilder(table, query, vector_column_name)

+        # remember the string query for reranking purpose
+        str_query = query if isinstance(query, str) else None
+
         # convert "auto" query_type to "vector", "fts"
         # or "hybrid" and convert the query to vector if needed
         query, query_type = cls._resolve_query(
@@ -152,7 +168,9 @@ class LanceQueryBuilder(ABC):

         if isinstance(query, str):
             # fts
-            return LanceFtsQueryBuilder(table, query)
+            return LanceFtsQueryBuilder(
+                table, query, ordering_field_name=ordering_field_name
+            )

         if isinstance(query, list):
             query = np.array(query, dtype=np.float32)
@@ -161,7 +179,7 @@ class LanceQueryBuilder(ABC):
         else:
             raise TypeError(f"Unsupported query type: {type(query)}")

-        return LanceVectorQueryBuilder(table, query, vector_column_name)
+        return LanceVectorQueryBuilder(table, query, vector_column_name, str_query)

     @classmethod
     def _resolve_query(cls, table, query, query_type, vector_column_name):
@@ -425,6 +443,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         table: "Table",
         query: Union[np.ndarray, list, "PIL.Image.Image"],
         vector_column: str,
+        str_query: Optional[str] = None,
     ):
         super().__init__(table)
         self._query = query
@@ -433,6 +452,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._refine_factor = None
         self._vector_column = vector_column
         self._prefilter = False
+        self._reranker = None
+        self._str_query = str_query

     def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
         """Set the distance metric to use.
@@ -503,6 +524,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         and also the "_distance" column which is the distance between the query
         vector and the returned vectors.
         """
+        return self.to_batches().read_all()
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        """
+        Execute the query and return the result as a RecordBatchReader object.
+
+        Parameters
+        ----------
+        batch_size: int
+            The maximum number of selected records in a RecordBatch object.
+
+        Returns
+        -------
+        pa.RecordBatchReader
+        """
         vector = self._query if isinstance(self._query, list) else self._query.tolist()
         if isinstance(vector[0], np.ndarray):
             vector = [v.tolist() for v in vector]
@@ -518,7 +554,16 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
         )
-        return self._table._execute_query(query)
+        result_set = self._table._execute_query(query, batch_size)
+        if self._reranker is not None:
+            rs_table = result_set.read_all()
+            result_set = self._reranker.rerank_vector(self._str_query, rs_table)
+            # convert result_set back to RecordBatchReader
+            result_set = pa.RecordBatchReader.from_batches(
+                result_set.schema, result_set.to_batches()
+            )
+
+        return result_set

     def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
         """Set the where clause.
@@ -544,14 +589,52 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._prefilter = prefilter
         return self

+    def rerank(
+        self, reranker: Reranker, query_string: Optional[str] = None
+    ) -> LanceVectorQueryBuilder:
+        """Rerank the results using the specified reranker.
+
+        Parameters
+        ----------
+        reranker: Reranker
+            The reranker to use.
+
+        query_string: Optional[str]
+            The query to use for reranking. This needs to be specified explicitly here
+            as the query used for vector search may already be vectorized and the
+            reranker requires a string query.
+            This is only required if the query used for vector search is not a string.
+            Note: This doesn't yet support the case where the query is multimodal or a
+            list of vectors.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._reranker = reranker
+        if self._str_query is None and query_string is None:
+            raise ValueError(
+                """
+                The query used for vector search is not a string.
+                In this case, the reranker query needs to be specified explicitly.
+                """
+            )
+        if query_string is not None and not isinstance(query_string, str):
+            raise ValueError("Reranking currently only supports string queries")
+        self._str_query = query_string if query_string is not None else self._str_query
+        return self
+

 class LanceFtsQueryBuilder(LanceQueryBuilder):
     """A builder for full text search for LanceDB."""

-    def __init__(self, table: "Table", query: str):
+    def __init__(self, table: "Table", query: str, ordering_field_name: str = None):
         super().__init__(table)
         self._query = query
         self._phrase_query = False
+        self.ordering_field_name = ordering_field_name
+        self._reranker = None

     def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
         """Set whether to use phrase query.
@@ -596,7 +679,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         if self._phrase_query:
             query = query.replace('"', "'")
             query = f'"{query}"'
-        row_ids, scores = search_index(index, query, self._limit)
+        row_ids, scores = search_index(
+            index, query, self._limit, ordering_field=self.ordering_field_name
+        )
         if len(row_ids) == 0:
             empty_schema = pa.schema([pa.field("score", pa.float32())])
             return pa.Table.from_pylist([], schema=empty_schema)
@@ -638,8 +723,27 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):

         if self._with_row_id:
             output_tbl = output_tbl.append_column("_rowid", row_ids)
+
+        if self._reranker is not None:
+            output_tbl = self._reranker.rerank_fts(self._query, output_tbl)
         return output_tbl

+    def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder:
+        """Rerank the results using the specified reranker.
+
+        Parameters
+        ----------
+        reranker: Reranker
+            The reranker to use.
+
+        Returns
+        -------
+        LanceFtsQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._reranker = reranker
+        return self
+

 class LanceEmptyQueryBuilder(LanceQueryBuilder):
     def to_arrow(self) -> pa.Table:
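With the change above, `to_arrow` on the vector builder is now a thin wrapper over `to_batches`, so large result sets can be streamed instead of fully materialized. A minimal sketch, assuming `tbl` is an open table and `process` is your own (hypothetical) consumer:

    # Stream results in bounded RecordBatches rather than one big Table.
    reader = tbl.search([0.4, 1.4, 2.4]).limit(100_000).to_batches(batch_size=512)
    for batch in reader:
        process(batch)  # hypothetical per-batch consumer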
@@ -921,3 +1025,334 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         """
         self._vector_query.refine_factor(refine_factor)
         return self
+
+
+class AsyncQueryBase(object):
+    def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
+        """
+        Construct an AsyncQueryBase
+
+        This method is not intended to be called directly. Instead, use the
+        [Table.query][] method to create a query.
+        """
+        self._inner = inner
+
+    def where(self, predicate: str) -> AsyncQuery:
+        """
+        Only return rows matching the given predicate
+
+        The predicate should be supplied as an SQL query string. For example:
+
+        >>> predicate = "x > 10"
+        >>> predicate = "y > 0 AND y < 100"
+        >>> predicate = "x > 5 OR y = 'test'"
+
+        Filtering performance can often be improved by creating a scalar index
+        on the filter column(s).
+        """
+        self._inner.where(predicate)
+        return self
+
+    def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery:
+        """
+        Return only the specified columns.
+
+        By default a query will return all columns from the table. However, this can
+        have a very significant impact on latency. LanceDb stores data in a columnar
+        fashion. This means we can finely tune our I/O to select exactly the columns
+        we need.
+
+        As a best practice you should always limit queries to the columns that you
+        need. If you pass in a list of column names then only those columns will be
+        returned.
+
+        You can also use this method to create new "dynamic" columns based on your
+        existing columns. For example, you may not care about "a" or "b" but instead
+        simply want "a + b". This is often seen in the SELECT clause of an SQL query
+        (e.g. `SELECT a+b FROM my_table`).
+
+        To create dynamic columns you can pass in a dict[str, str]. A column will be
+        returned for each entry in the map. The key provides the name of the column.
+        The value is an SQL string used to specify how the column is calculated.
+
+        For example, an SQL query might state `SELECT a + b AS combined, c`. The
+        equivalent input to this method would be `{"combined": "a + b", "c": "c"}`.
+
+        Columns will always be returned in the order given, even if that order is
+        different than the order used when adding the data.
+        """
+        if isinstance(columns, dict):
+            column_tuples = list(columns.items())
+        else:
+            try:
+                column_tuples = [(c, c) for c in columns]
+            except TypeError:
+                raise TypeError("columns must be a list of column names or a dict")
+        self._inner.select(column_tuples)
+        return self
+
+    def limit(self, limit: int) -> AsyncQuery:
+        """
+        Set the maximum number of results to return.
+
+        By default, a plain search has no limit. If this method is not
+        called then every valid row from the table will be returned.
+        """
+        self._inner.limit(limit)
+        return self
+
+    async def to_batches(self) -> AsyncRecordBatchReader:
+        """
+        Execute the query and return the results as an Apache Arrow RecordBatchReader.
+        """
+        return AsyncRecordBatchReader(await self._inner.execute())
+
+    async def to_arrow(self) -> pa.Table:
+        """
+        Execute the query and collect the results into an Apache Arrow Table.
+
+        This method will collect all results into memory before returning. If
+        you expect a large number of results, you may want to use [to_batches][]
+        """
+        batch_iter = await self.to_batches()
+        return pa.Table.from_batches(
+            await batch_iter.read_all(), schema=batch_iter.schema
+        )
+
+    async def to_pandas(self) -> "pd.DataFrame":
+        """
+        Execute the query and collect the results into a pandas DataFrame.
+
+        This method will collect all results into memory before returning. If
+        you expect a large number of results, you may want to use [to_batches][]
+        and convert each batch to pandas separately.
+
+        Example
+        -------
+
+        >>> import asyncio
+        >>> from lancedb import connect_async
+        >>> async def doctest_example():
+        ...     conn = await connect_async("./.lancedb")
+        ...     table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
+        ...     async for batch in await table.query().to_batches():
+        ...         batch_df = batch.to_pandas()
+        >>> asyncio.run(doctest_example())
+        """
+        return (await self.to_arrow()).to_pandas()
+
+
+class AsyncQuery(AsyncQueryBase):
+    def __init__(self, inner: LanceQuery):
+        """
+        Construct an AsyncQuery
+
+        This method is not intended to be called directly. Instead, use the
+        [Table.query][] method to create a query.
+        """
+        super().__init__(inner)
+        self._inner = inner
+
+    @classmethod
+    def _query_vec_to_array(self, vec: Union[VEC, Tuple]):
+        if isinstance(vec, list):
+            return pa.array(vec)
+        if isinstance(vec, np.ndarray):
+            return pa.array(vec)
+        if isinstance(vec, pa.Array):
+            return vec
+        if isinstance(vec, pa.ChunkedArray):
+            return vec.combine_chunks()
+        if isinstance(vec, tuple):
+            return pa.array(vec)
+        # We've checked everything we formally support in our typings
+        # but, as a fallback, let pyarrow try and convert it anyway.
+        # This can allow for some more exotic things like iterables
+        return pa.array(vec)
+
+    def nearest_to(
+        self, query_vector: Optional[Union[VEC, Tuple]] = None
+    ) -> AsyncVectorQuery:
+        """
+        Find the nearest vectors to the given query vector.
+
+        This converts the query from a plain query to a vector query.
+
+        This method will attempt to convert the input to the query vector
+        expected by the embedding model. If the input cannot be converted
+        then an error will be thrown.
+
+        By default, there is no embedding model, and the input should be
+        something that can be converted to a pyarrow array of floats. This
+        includes lists, numpy arrays, and tuples.
+
+        If there is only one vector column (a column whose data type is a
+        fixed size list of floats) then the column does not need to be specified.
+        If there is more than one vector column you must use
+        [AsyncVectorQuery::column][] to specify which column you would like to
+        compare with.
+
+        If no index has been created on the vector column then a vector query
+        will perform a distance comparison between the query vector and every
+        vector in the database and then sort the results. This is sometimes
+        called a "flat search".
+
+        For small databases, with tens of thousands of vectors or less, this can
+        be reasonably fast. In larger databases you should create a vector index
+        on the column. If there is a vector index then an "approximate" nearest
+        neighbor search (frequently called an ANN search) will be performed. This
+        search is much faster, but the results will be approximate.
+
+        The query can be further parameterized using the returned builder. There
+        are various ANN search parameters that will let you fine tune your recall
+        accuracy vs search latency.
+
+        Vector searches always have a [limit][]. If `limit` has not been called then
+        a default `limit` of 10 will be used.
+        """
+        return AsyncVectorQuery(
+            self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
+        )
+
+
+class AsyncVectorQuery(AsyncQueryBase):
+    def __init__(self, inner: LanceVectorQuery):
+        """
+        Construct an AsyncVectorQuery
+
+        This method is not intended to be called directly. Instead, create
+        a query first with [Table.query][] and then use [AsyncQuery.nearest_to][]
+        to convert to a vector query.
+        """
+        super().__init__(inner)
+        self._inner = inner
+
+    def column(self, column: str) -> AsyncVectorQuery:
+        """
+        Set the vector column to query
+
+        This controls which column is compared to the query vector supplied in
+        the call to [Query.nearest_to][].
+
+        This parameter must be specified if the table has more than one column
+        whose data type is a fixed-size-list of floats.
+        """
+        self._inner.column(column)
+        return self
+
+    def nprobes(self, nprobes: int) -> AsyncVectorQuery:
+        """
+        Set the number of partitions to search (probe)
+
+        This argument is only used when the vector column has an IVF PQ index.
+        If there is no index then this value is ignored.
+
+        The IVF stage of IVF PQ divides the input into partitions (clusters) of
+        related values.
+
+        The partition whose centroids are closest to the query vector will be
+        exhaustively searched to find matches. This parameter controls how many
+        partitions should be searched.
+
+        Increasing this value will increase the recall of your query but will
+        also increase the latency of your query. The default value is 20. This
+        default is good for many cases but the best value to use will depend on
+        your data and the recall that you need to achieve.
+
+        For best results we recommend tuning this parameter with a benchmark against
+        your actual data to find the smallest possible value that will still give
+        you the desired recall.
+        """
+        self._inner.nprobes(nprobes)
+        return self
+
+    def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
+        """
+        A multiplier to control how many additional rows are taken during the refine
+        step
+
+        This argument is only used when the vector column has an IVF PQ index.
+        If there is no index then this value is ignored.
+
+        An IVF PQ index stores compressed (quantized) values. The query vector is
+        compared against these values and, since they are compressed, the comparison is
+        inaccurate.
+
+        This parameter can be used to refine the results. It can both improve
+        recall and correct the ordering of the nearest results.
+
+        To refine results LanceDb will first perform an ANN search to find the nearest
+        `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
+        `limit` is the default (10) then the first 30 results will be selected. LanceDb
+        then fetches the full, uncompressed, values for these 30 results. The results
+        are then reordered by the true distance and only the nearest 10 are kept.
+
+        Note: there is a difference between calling this method with a value of 1 and
+        never calling this method at all. Calling this method with any value will have
+        an impact on your search latency. When you call this method with a
+        `refine_factor` of 1 then LanceDb still needs to fetch the full, uncompressed,
+        values so that it can potentially reorder the results.
+
+        Note: if this method is NOT called then the distances returned in the _distance
+        column will be approximate distances based on the comparison of the quantized
+        query vector and the quantized result vectors. This can be considerably
+        different than the true distance between the query vector and the actual
+        uncompressed vector.
+        """
+        self._inner.refine_factor(refine_factor)
+        return self
+
+    def distance_type(self, distance_type: str) -> AsyncVectorQuery:
+        """
+        Set the distance metric to use
+
+        When performing a vector search we try and find the "nearest" vectors according
+        to some kind of distance metric. This parameter controls which distance metric
+        to use. See `IvfPqOptions.distance_type` for more details on the
+        different distance metrics available.
+
+        Note: if there is a vector index then the distance type used MUST match the
+        distance type used to train the vector index. If this is not done then the
+        results will be invalid.
+
+        By default "l2" is used.
+        """
+        self._inner.distance_type(distance_type)
+        return self
+
+    def postfilter(self) -> AsyncVectorQuery:
+        """
+        If this is called then filtering will happen after the vector search instead of
+        before.
+
+        By default filtering will be performed before the vector search. This is how
+        filtering is typically understood to work. This prefilter step does add some
+        additional latency. Creating a scalar index on the filter column(s) can
+        often improve this latency. However, sometimes a filter is too complex or
+        scalar indices cannot be applied to the column. In these cases postfiltering
+        can be used instead of prefiltering to improve latency.
+
+        Post filtering applies the filter to the results of the vector search. This
+        means we only run the filter on a much smaller set of data. However, it can
+        cause the query to return fewer than `limit` results (or even no results) if
+        none of the nearest results match the filter.
+
+        Post filtering happens during the "refine stage" (described in more detail in
+        `refine_factor`). This means that setting a higher refine
+        factor can often help restore some of the results lost by post filtering.
+        """
+        self._inner.postfilter()
+        return self
+
+    def bypass_vector_index(self) -> AsyncVectorQuery:
+        """
+        If this is called then any vector index is skipped
+
+        An exhaustive (flat) search will be performed. The query vector will
+        be compared to every vector in the table. At high scales this can be
+        expensive. However, this is often still useful. For example, skipping
+        the vector index can give you ground truth results which you can use to
+        calculate your recall to select an appropriate value for nprobes.
+        """
+        self._inner.bypass_vector_index()
+        return self
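The new async builders mirror the sync API: `where`/`select`/`limit` chain, `nearest_to` switches to a vector query, and execution is awaited. A minimal sketch, assuming `connect_async` against a local database (table and column names are illustrative):

    import asyncio
    from lancedb import connect_async

    async def main():
        conn = await connect_async("./.lancedb")
        table = await conn.create_table(
            "async_demo", data=[{"id": 1, "vector": [0.1, 0.2]}]
        )
        # Plain scan with a filter, streamed batch by batch.
        async for batch in await table.query().where("id > 0").to_batches():
            print(batch.num_rows)
        # Vector search via the same builder.
        results = await table.query().nearest_to([0.1, 0.2]).limit(5).to_arrow()

    asyncio.run(main())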
@@ -295,7 +295,9 @@ class RemoteTable(Table):
             vector_column_name = inf_vector_column_query(self.schema)
         return LanceVectorQueryBuilder(self, query, vector_column_name)

-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         if (
             query.vector is not None
             and len(query.vector) > 0
@@ -321,13 +323,12 @@ class RemoteTable(Table):
                 q = query.copy()
                 q.vector = v
                 results.append(submit(self._name, q))
-
             return pa.concat_tables(
                 [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
-            )
+            ).to_reader()
         else:
             result = self._conn._client.query(self._name, query)
-            return result.to_arrow()
+            return result.to_arrow().to_reader()

     def _do_merge(
         self,
@@ -24,8 +24,59 @@ class Reranker(ABC):
             raise ValueError("score must be either 'relevance' or 'all'")
         self.score = return_score

+    def rerank_vector(
+        self,
+        query: str,
+        vector_results: pa.Table,
+    ):
+        """
+        Rerank function receives the result from the vector search.
+        This isn't mandatory to implement
+
+        Parameters
+        ----------
+        query : str
+            The input query
+        vector_results : pa.Table
+            The results from the vector search
+
+        Returns
+        -------
+        pa.Table
+            The reranked results
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement rerank_vector"
+        )
+
+    def rerank_fts(
+        self,
+        query: str,
+        fts_results: pa.Table,
+    ):
+        """
+        Rerank function receives the result from the FTS search.
+        This isn't mandatory to implement
+
+        Parameters
+        ----------
+        query : str
+            The input query
+        fts_results : pa.Table
+            The results from the FTS search
+
+        Returns
+        -------
+        pa.Table
+            The reranked results
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement rerank_fts"
+        )
+
     @abstractmethod
     def rerank_hybrid(
+        self,
         query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
@@ -43,6 +94,11 @@ class Reranker(ABC):
             The results from the vector search
         fts_results : pa.Table
             The results from the FTS search
+
+        Returns
+        -------
+        pa.Table
+            The reranked results
         """
         pass

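With `rerank_vector` and `rerank_fts` optional and only `rerank_hybrid` abstract, a custom reranker only has to emit a `_relevance_score` column. A hedged sketch against this interface (the import path and the `text` column are assumptions, and the scoring rule is deliberately trivial):

    import pyarrow as pa
    from lancedb.rerankers import Reranker  # import path assumed

    class LengthReranker(Reranker):
        """Toy reranker: scores each hit by the length of its text."""

        def _score(self, results: pa.Table) -> pa.Table:
            scores = [float(len(t)) for t in results["text"].to_pylist()]
            results = results.append_column(
                "_relevance_score", pa.array(scores, type=pa.float32())
            )
            return results.sort_by([("_relevance_score", "descending")])

        def rerank_vector(self, query: str, vector_results: pa.Table) -> pa.Table:
            return self._score(vector_results)

        def rerank_fts(self, query: str, fts_results: pa.Table) -> pa.Table:
            return self._score(fts_results)

        def rerank_hybrid(
            self, query: str, vector_results: pa.Table, fts_results: pa.Table
        ) -> pa.Table:
            # merge_results comes from the Reranker base class.
            return self._score(self.merge_results(vector_results, fts_results))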
@@ -49,14 +49,8 @@ class CohereReranker(Reranker):
             )
         return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)

-    def rerank_hybrid(
-        self,
-        query: str,
-        vector_results: pa.Table,
-        fts_results: pa.Table,
-    ):
-        combined_results = self.merge_results(vector_results, fts_results)
-        docs = combined_results[self.column].to_pylist()
+    def _rerank(self, result_set: pa.Table, query: str):
+        docs = result_set[self.column].to_pylist()
         results = self._client.rerank(
             query=query,
             documents=docs,
@@ -66,12 +60,22 @@ class CohereReranker(Reranker):
         indices, scores = list(
             zip(*[(result.index, result.relevance_score) for result in results])
         )  # tuples
-        combined_results = combined_results.take(list(indices))
+        result_set = result_set.take(list(indices))
         # add the scores
-        combined_results = combined_results.append_column(
+        result_set = result_set.append_column(
             "_relevance_score", pa.array(scores, type=pa.float32())
         )
+
+        return result_set
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        combined_results = self._rerank(combined_results, query)
         if self.score == "relevance":
             combined_results = combined_results.drop_columns(["score", "_distance"])
         elif self.score == "all":
@@ -79,3 +83,25 @@ class CohereReranker(Reranker):
                 "return_score='all' not implemented for cohere reranker"
             )
         return combined_results
+
+    def rerank_vector(
+        self,
+        query: str,
+        vector_results: pa.Table,
+    ):
+        result_set = self._rerank(vector_results, query)
+        if self.score == "relevance":
+            result_set = result_set.drop_columns(["_distance"])
+
+        return result_set
+
+    def rerank_fts(
+        self,
+        query: str,
+        fts_results: pa.Table,
+    ):
+        result_set = self._rerank(fts_results, query)
+        if self.score == "relevance":
+            result_set = result_set.drop_columns(["score"])
+
+        return result_set
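Factoring the Cohere call into `_rerank` is what lets one reranker serve the new single-mode entry points. A hedged usage sketch (requires the `cohere` package and a `COHERE_API_KEY`; the import path and `tbl` are assumptions):

    from lancedb.rerankers import CohereReranker  # import path assumed

    reranker = CohereReranker()

    # FTS results reranked; "score" is dropped when return_score="relevance".
    fts_hits = tbl.search("puppy").rerank(reranker).limit(10).to_pandas()

    # Vector results reranked; the text query must be passed explicitly
    # because the search input here is already a vector.
    vec_hits = (
        tbl.search([0.4, 1.4, 2.4])
        .rerank(reranker, query_string="puppy")
        .limit(10)
        .to_pandas()
    )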
@@ -33,14 +33,8 @@ class ColbertReranker(Reranker):
             "torch"
         )  # import here for faster ops later

-    def rerank_hybrid(
-        self,
-        query: str,
-        vector_results: pa.Table,
-        fts_results: pa.Table,
-    ):
-        combined_results = self.merge_results(vector_results, fts_results)
-        docs = combined_results[self.column].to_pylist()
+    def _rerank(self, result_set: pa.Table, query: str):
+        docs = result_set[self.column].to_pylist()

         tokenizer, model = self._model

@@ -59,14 +53,25 @@ class ColbertReranker(Reranker):
             scores.append(score.item())

         # replace the self.column column with the docs
-        combined_results = combined_results.drop(self.column)
-        combined_results = combined_results.append_column(
+        result_set = result_set.drop(self.column)
+        result_set = result_set.append_column(
             self.column, pa.array(docs, type=pa.string())
         )
         # add the scores
-        combined_results = combined_results.append_column(
+        result_set = result_set.append_column(
             "_relevance_score", pa.array(scores, type=pa.float32())
         )
+
+        return result_set
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        combined_results = self._rerank(combined_results, query)
         if self.score == "relevance":
             combined_results = combined_results.drop_columns(["score", "_distance"])
         elif self.score == "all":
@@ -80,6 +85,32 @@ class ColbertReranker(Reranker):

         return combined_results

+    def rerank_vector(
+        self,
+        query: str,
+        vector_results: pa.Table,
+    ):
+        result_set = self._rerank(vector_results, query)
+        if self.score == "relevance":
+            result_set = result_set.drop_columns(["_distance"])
+
+        result_set = result_set.sort_by([("_relevance_score", "descending")])
+
+        return result_set
+
+    def rerank_fts(
+        self,
+        query: str,
+        fts_results: pa.Table,
+    ):
+        result_set = self._rerank(fts_results, query)
+        if self.score == "relevance":
+            result_set = result_set.drop_columns(["score"])
+
+        result_set = result_set.sort_by([("_relevance_score", "descending")])
+
+        return result_set
+
     @cached_property
     def _model(self):
         transformers = attempt_import_or_raise("transformers")
@@ -46,6 +46,16 @@ class CrossEncoderReranker(Reranker):

         return cross_encoder

+    def _rerank(self, result_set: pa.Table, query: str):
+        passages = result_set[self.column].to_pylist()
+        cross_inp = [[query, passage] for passage in passages]
+        cross_scores = self.model.predict(cross_inp)
+        result_set = result_set.append_column(
+            "_relevance_score", pa.array(cross_scores, type=pa.float32())
+        )
+
+        return result_set
+
     def rerank_hybrid(
         self,
         query: str,
@@ -53,13 +63,7 @@ class CrossEncoderReranker(Reranker):
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
-        passages = combined_results[self.column].to_pylist()
-        cross_inp = [[query, passage] for passage in passages]
-        cross_scores = self.model.predict(cross_inp)
-        combined_results = combined_results.append_column(
-            "_relevance_score", pa.array(cross_scores, type=pa.float32())
-        )
+        combined_results = self._rerank(combined_results, query)

         # sort the results by _score
         if self.score == "relevance":
             combined_results = combined_results.drop_columns(["score", "_distance"])
@@ -72,3 +76,27 @@ class CrossEncoderReranker(Reranker):
             )

         return combined_results
+
+    def rerank_vector(
+        self,
+        query: str,
+        vector_results: pa.Table,
+    ):
+        vector_results = self._rerank(vector_results, query)
+        if self.score == "relevance":
+            vector_results = vector_results.drop_columns(["_distance"])
+
+        vector_results = vector_results.sort_by([("_relevance_score", "descending")])
+        return vector_results
+
+    def rerank_fts(
+        self,
+        query: str,
+        fts_results: pa.Table,
+    ):
+        fts_results = self._rerank(fts_results, query)
+        if self.score == "relevance":
+            fts_results = fts_results.drop_columns(["score"])
+
+        fts_results = fts_results.sort_by([("_relevance_score", "descending")])
+        return fts_results
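Unlike the Cohere reranker, the cross-encoder scores locally via `sentence-transformers` and must sort by `_relevance_score` itself, which is why its `rerank_vector`/`rerank_fts` end with an explicit `sort_by`. A hedged sketch (the model choice is the library's default, not specified in this diff; `tbl` is assumed):

    from lancedb.rerankers import CrossEncoderReranker  # import path assumed

    reranker = CrossEncoderReranker()  # loads a cross-encoder model on first use

    hits = (
        tbl.search([0.4, 1.4, 2.4])
        .rerank(reranker, query_string="puppy")
        .limit(10)
        .to_pandas()
    )
    # Assuming return_score defaults to "relevance": "_distance" is dropped
    # and rows arrive sorted by "_relevance_score" descending.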
@@ -39,14 +39,8 @@ class OpenaiReranker(Reranker):
         self.column = column
         self.api_key = api_key

-    def rerank_hybrid(
-        self,
-        query: str,
-        vector_results: pa.Table,
-        fts_results: pa.Table,
-    ):
-        combined_results = self.merge_results(vector_results, fts_results)
-        docs = combined_results[self.column].to_pylist()
+    def _rerank(self, result_set: pa.Table, query: str):
+        docs = result_set[self.column].to_pylist()
         response = self._client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
@@ -70,14 +64,25 @@ class OpenaiReranker(Reranker):
             zip(*[(result["content"], result["relevance_score"]) for result in results])
         )  # tuples
         # replace the self.column column with the docs
-        combined_results = combined_results.drop(self.column)
-        combined_results = combined_results.append_column(
+        result_set = result_set.drop(self.column)
+        result_set = result_set.append_column(
             self.column, pa.array(docs, type=pa.string())
         )
         # add the scores
-        combined_results = combined_results.append_column(
+        result_set = result_set.append_column(
             "_relevance_score", pa.array(scores, type=pa.float32())
         )
+
+        return result_set
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        combined_results = self._rerank(combined_results, query)
         if self.score == "relevance":
             combined_results = combined_results.drop_columns(["score", "_distance"])
         elif self.score == "all":
@@ -91,6 +96,24 @@ class OpenaiReranker(Reranker):

         return combined_results

+    def rerank_vector(self, query: str, vector_results: pa.Table):
+        vector_results = self._rerank(vector_results, query)
+        if self.score == "relevance":
+            vector_results = vector_results.drop_columns(["_distance"])
+
+        vector_results = vector_results.sort_by([("_relevance_score", "descending")])
+
+        return vector_results
+
+    def rerank_fts(self, query: str, fts_results: pa.Table):
+        fts_results = self._rerank(fts_results, query)
+        if self.score == "relevance":
+            fts_results = fts_results.drop_columns(["score"])
+
+        fts_results = fts_results.sort_by([("_relevance_score", "descending")])
+
+        return fts_results
+
     @cached_property
     def _client(self):
         openai = attempt_import_or_raise(
@@ -37,13 +37,14 @@ import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.fs as pa_fs
 from lance import LanceDataset
+from lance.dependencies import _check_for_hugging_face
 from lance.vector import vec_to_table

 from .common import DATA, VEC, VECTOR_COLUMN_NAME
 from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
 from .merge import LanceMergeInsertBuilder
 from .pydantic import LanceModel, model_to_dict
-from .query import LanceQueryBuilder, Query
+from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
 from .util import (
     fs_from_uri,
     inf_vector_column_query,
@@ -74,6 +75,27 @@ def _sanitize_data(
     on_bad_vectors: str,
     fill_value: Any,
 ):
+    if _check_for_hugging_face(data):
+        # Huggingface datasets
+        from lance.dependencies import datasets
+
+        if isinstance(data, datasets.dataset_dict.DatasetDict):
+            if schema is None:
+                schema = _schema_from_hf(data, schema)
+            data = _to_record_batch_generator(
+                _to_batches_with_split(data),
+                schema,
+                metadata,
+                on_bad_vectors,
+                fill_value,
+            )
+        elif isinstance(data, datasets.Dataset):
+            if schema is None:
+                schema = data.features.arrow_schema
+            data = _to_record_batch_generator(
+                data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
+            )
+
     if isinstance(data, list):
         # convert to list of dict if data is a bunch of LanceModels
         if isinstance(data[0], LanceModel):
@@ -110,6 +132,37 @@ def _sanitize_data(
     return data


+def _schema_from_hf(data, schema):
+    """
+    Extract the pyarrow schema from a HuggingFace DatasetDict
+    and validate that all splits share the same schema
+    """
+    for dataset in data.values():
+        if schema is None:
+            schema = dataset.features.arrow_schema
+        elif schema != dataset.features.arrow_schema:
+            msg = "All datasets in a HuggingFace DatasetDict must have the same schema"
+            raise TypeError(msg)
+    return schema
+
+
+def _to_batches_with_split(data):
+    """
+    Return a generator of RecordBatches from a HuggingFace DatasetDict
+    with an extra `split` column
+    """
+    for key, dataset in data.items():
+        for batch in dataset.data.to_batches():
+            table = pa.Table.from_batches([batch])
+            if "split" not in table.column_names:
+                table = table.append_column(
+                    "split", pa.array([key] * batch.num_rows, pa.string())
+                )
+            for b in table.to_batches():
+                yield b
+
+
 def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
     """
     Use the embedding function to automatically embed the source column and add the
@@ -144,12 +197,13 @@ def _to_record_batch_generator(
     data: Iterable, schema, metadata, on_bad_vectors, fill_value
 ):
     for batch in data:
-        if not isinstance(batch, pa.RecordBatch):
-            table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-            for batch in table.to_batches():
-                yield batch
-        else:
-            yield batch
+        # always convert to table because we need to sanitize the data
+        # and do things like add the vector column etc
+        if isinstance(batch, pa.RecordBatch):
+            batch = pa.Table.from_batches([batch])
+        batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
+        for b in batch.to_batches():
+            yield b


 class Table(ABC):
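`_sanitize_data` now short-circuits on HuggingFace inputs: a `Dataset` streams through as record batches, and a `DatasetDict` is flattened by `_to_batches_with_split` with a `split` column naming each subset. A hedged sketch (the dataset name is illustrative; requires the `datasets` package):

    import lancedb
    from datasets import load_dataset

    db = lancedb.connect("./.lancedb")

    # A DatasetDict with multiple splits (example dataset, assumed available).
    dd = load_dataset("rotten_tomatoes")
    tbl = db.create_table("reviews", data=dd)

    # Every row carries the split it came from.
    print(tbl.to_pandas()["split"].unique())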
@@ -514,7 +568,9 @@ class Table(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(
|
||||||
|
self, query: Query, batch_size: Optional[int] = None
|
||||||
|
) -> pa.RecordBatchReader:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -1105,6 +1161,7 @@ class LanceTable(Table):
|
|||||||
def create_fts_index(
|
def create_fts_index(
|
||||||
self,
|
self,
|
||||||
field_names: Union[str, List[str]],
|
field_names: Union[str, List[str]],
|
||||||
|
ordering_field_names: Union[str, List[str]] = None,
|
||||||
*,
|
*,
|
||||||
replace: bool = False,
|
replace: bool = False,
|
||||||
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
|
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
|
||||||
@@ -1123,12 +1180,18 @@ class LanceTable(Table):
|
|||||||
not yet an atomic operation; the index will be temporarily
|
not yet an atomic operation; the index will be temporarily
|
||||||
unavailable while the new index is being created.
|
unavailable while the new index is being created.
|
||||||
writer_heap_size: int, default 1GB
|
writer_heap_size: int, default 1GB
|
||||||
|
ordering_field_names:
|
||||||
|
A list of unsigned type fields to index to optionally order
|
||||||
|
results on at search time
|
||||||
"""
|
"""
|
||||||
from .fts import create_index, populate_index
|
from .fts import create_index, populate_index
|
||||||
|
|
||||||
if isinstance(field_names, str):
|
if isinstance(field_names, str):
|
||||||
field_names = [field_names]
|
field_names = [field_names]
|
||||||
|
|
||||||
|
if isinstance(ordering_field_names, str):
|
||||||
|
ordering_field_names = [ordering_field_names]
|
||||||
|
|
||||||
fs, path = fs_from_uri(self._get_fts_index_path())
|
fs, path = fs_from_uri(self._get_fts_index_path())
|
||||||
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
|
||||||
if index_exists:
|
if index_exists:
|
||||||
@@ -1136,8 +1199,18 @@ class LanceTable(Table):
|
|||||||
raise ValueError("Index already exists. Use replace=True to overwrite.")
|
raise ValueError("Index already exists. Use replace=True to overwrite.")
|
||||||
fs.delete_dir(path)
|
fs.delete_dir(path)
|
||||||
|
|
||||||
index = create_index(self._get_fts_index_path(), field_names)
|
index = create_index(
|
||||||
populate_index(index, self, field_names, writer_heap_size=writer_heap_size)
|
self._get_fts_index_path(),
|
||||||
|
field_names,
|
||||||
|
ordering_fields=ordering_field_names,
|
||||||
|
)
|
||||||
|
populate_index(
|
||||||
|
index,
|
||||||
|
self,
|
||||||
|
field_names,
|
||||||
|
ordering_fields=ordering_field_names,
|
||||||
|
writer_heap_size=writer_heap_size,
|
||||||
|
)
|
||||||
register_event("create_fts_index")
|
register_event("create_fts_index")
|
||||||
|
|
||||||
def _get_fts_index_path(self):
|
def _get_fts_index_path(self):
|
||||||
@@ -1271,6 +1344,7 @@ class LanceTable(Table):
|
|||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
|
ordering_field_name: Optional[str] = None,
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector. We currently support [vector search][search]
|
of the given query vector. We currently support [vector search][search]
|
||||||
@@ -1338,7 +1412,11 @@ class LanceTable(Table):
|
|||||||
vector_column_name = inf_vector_column_query(self.schema)
|
vector_column_name = inf_vector_column_query(self.schema)
|
||||||
register_event("search_table")
|
register_event("search_table")
|
||||||
return LanceQueryBuilder.create(
|
return LanceQueryBuilder.create(
|
||||||
self, query, query_type, vector_column_name=vector_column_name
|
self,
|
||||||
|
query,
|
||||||
|
query_type,
|
||||||
|
vector_column_name=vector_column_name,
|
||||||
|
ordering_field_name=ordering_field_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1520,10 +1598,11 @@ class LanceTable(Table):
|
|||||||
self._dataset_mut.update(values_sql, where)
|
self._dataset_mut.update(values_sql, where)
|
||||||
register_event("update")
|
register_event("update")
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(
|
||||||
|
self, query: Query, batch_size: Optional[int] = None
|
||||||
|
) -> pa.RecordBatchReader:
|
||||||
ds = self.to_lance()
|
ds = self.to_lance()
|
||||||
|
return ds.scanner(
|
||||||
return ds.to_table(
|
|
||||||
columns=query.columns,
|
columns=query.columns,
|
||||||
filter=query.filter,
|
filter=query.filter,
|
||||||
prefilter=query.prefilter,
|
prefilter=query.prefilter,
|
||||||
@@ -1536,7 +1615,8 @@ class LanceTable(Table):
|
|||||||
"refine_factor": query.refine_factor,
|
"refine_factor": query.refine_factor,
|
||||||
},
|
},
|
||||||
with_row_id=query.with_row_id,
|
with_row_id=query.with_row_id,
|
||||||
)
|
batch_size=batch_size,
|
||||||
|
).to_reader()
|
||||||
|
|
||||||
def _do_merge(
|
def _do_merge(
|
||||||
self,
|
self,
|
||||||
@@ -1907,6 +1987,9 @@ class AsyncTable:
|
|||||||
"""
|
"""
|
||||||
return await self._inner.count_rows(filter)
|
return await self._inner.count_rows(filter)
|
||||||
|
|
||||||
|
def query(self) -> AsyncQuery:
|
||||||
|
return AsyncQuery(self._inner.query())
|
||||||
|
|
||||||
async def to_pandas(self) -> "pd.DataFrame":
|
async def to_pandas(self) -> "pd.DataFrame":
|
||||||
"""Return the table as a pandas DataFrame.
|
"""Return the table as a pandas DataFrame.
|
||||||
|
|
||||||
@@ -1914,7 +1997,7 @@ class AsyncTable:
|
|||||||
-------
|
-------
|
||||||
pd.DataFrame
|
pd.DataFrame
|
||||||
"""
|
"""
|
||||||
return self.to_arrow().to_pandas()
|
return (await self.to_arrow()).to_pandas()
|
||||||
|
|
||||||
async def to_arrow(self) -> pa.Table:
|
async def to_arrow(self) -> pa.Table:
|
||||||
"""Return the table as a pyarrow Table.
|
"""Return the table as a pyarrow Table.
|
||||||
@@ -1923,7 +2006,7 @@ class AsyncTable:
|
|||||||
-------
|
-------
|
||||||
pa.Table
|
pa.Table
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
return await self.query().to_arrow()
|
||||||
|
|
||||||
async def create_index(
|
async def create_index(
|
||||||
self,
|
self,
|
||||||
@@ -2076,89 +2159,21 @@ class AsyncTable:
         return LanceMergeInsertBuilder(self, on)
 
-    async def search(
+    def vector_search(
         self,
-        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
-        vector_column_name: Optional[str] = None,
-        query_type: str = "auto",
-    ) -> LanceQueryBuilder:
-        """Create a search query to find the nearest neighbors
-        of the given query vector. We currently support [vector search][search]
-        and [full-text search][experimental-full-text-search].
-
-        All query options are defined in [Query][lancedb.query.Query].
-
-        Examples
-        --------
-        >>> import lancedb
-        >>> db = lancedb.connect("./.lancedb")
-        >>> data = [
-        ...     {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
-        ...     {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
-        ...     {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
-        ... ]
-        >>> table = db.create_table("my_table", data)
-        >>> query = [0.4, 1.4, 2.4]
-        >>> (table.search(query)
-        ...     .where("original_width > 1000", prefilter=True)
-        ...     .select(["caption", "original_width", "vector"])
-        ...     .limit(2)
-        ...     .to_pandas())
-          caption  original_width           vector  _distance
-        0     foo            2000  [0.5, 3.4, 1.3]   5.220000
-        1    test            3000  [0.3, 6.2, 2.6]  23.089996
-
-        Parameters
-        ----------
-        query: list/np.ndarray/str/PIL.Image.Image, default None
-            The targeted vector to search for.
-
-            - *default None*.
-            Acceptable types are: list, np.ndarray, PIL.Image.Image
-
-            - If None then the select/where/limit clauses are applied to filter
-            the table
-        vector_column_name: str, optional
-            The name of the vector column to search.
-
-            The vector column needs to be a pyarrow fixed size list type
-
-            - If not specified then the vector column is inferred from
-            the table schema
-
-            - If the table has multiple vector columns then the *vector_column_name*
-            needs to be specified. Otherwise, an error is raised.
-        query_type: str
-            *default "auto"*.
-            Acceptable types are: "vector", "fts", "hybrid", or "auto"
-
-            - If "auto" then the query type is inferred from the query;
-
-            - If `query` is a list/np.ndarray then the query type is
-            "vector";
-
-            - If `query` is a PIL.Image.Image then either do vector search,
-            or raise an error if no corresponding embedding function is found.
-
-            - If `query` is a string, then the query type is "vector" if the
-            table has embedding functions else the query type is "fts"
-
-        Returns
-        -------
-        LanceQueryBuilder
-            A query builder object representing the query.
-            Once executed, the query returns
-
-            - selected columns
-
-            - the vector
-
-            - and also the "_distance" column which is the distance between the query
-            vector and the returned vector.
-        """
-        raise NotImplementedError
+        query_vector: Optional[Union[VEC, Tuple]] = None,
+    ) -> AsyncVectorQuery:
+        """
+        Search the table with a given query vector.
+
+        This is a convenience method for preparing a vector query and
+        is the same thing as calling `nearest_to` on the builder returned
+        by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
+        """
+        return self.query().nearest_to(query_vector)
 
-    async def _execute_query(self, query: Query) -> pa.Table:
+    async def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         pass
 
     async def _do_merge(
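The new `vector_search` on `AsyncTable` is sugar for `query().nearest_to(...)`. A minimal sketch of the resulting call pattern, assuming an existing table named "my_table" with a 2-dimensional `vector` column and assuming the async connection exposes `open_table` like the sync API (the names and data here are illustrative only):

```python
import asyncio

import lancedb


async def main():
    conn = await lancedb.connect_async("./.lancedb")
    table = await conn.open_table("my_table")
    # vector_search(v) prepares the same query as table.query().nearest_to(v)
    results = await table.vector_search([1.0, 2.0]).limit(2).to_arrow()
    print(results)


asyncio.run(main())
```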
@@ -43,6 +43,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
         )
         for _ in range(100)
     ]
+    count = [random.randint(1, 10000) for _ in range(100)]
    table = db.create_table(
        "test",
        data=pd.DataFrame(
@@ -52,6 +53,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
                "text": text,
                "text2": text,
                "nested": [{"text": t} for t in text],
+                "count": count,
            }
        ),
    )
@@ -79,6 +81,39 @@ def test_search_index(tmp_path, table):
    assert len(results[1]) == 10  # _distance
 
 
+def test_search_ordering_field_index_table(tmp_path, table):
+    table.create_fts_index("text", ordering_field_names=["count"])
+    rows = (
+        table.search("puppy", ordering_field_name="count")
+        .limit(20)
+        .select(["text", "count"])
+        .to_list()
+    )
+    for r in rows:
+        assert "puppy" in r["text"]
+    assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
+
+
+def test_search_ordering_field_index(tmp_path, table):
+    index = ldb.fts.create_index(
+        str(tmp_path / "index"), ["text"], ordering_fields=["count"]
+    )
+
+    ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
+    index.reload()
+    results = ldb.fts.search_index(
+        index, query="puppy", limit=10, ordering_field="count"
+    )
+    assert len(results) == 2
+    assert len(results[0]) == 10  # row_ids
+    assert len(results[1]) == 10  # _distance
+    rows = table.to_lance().take(results[0]).to_pylist()
+
+    for r in rows:
+        assert "puppy" in r["text"]
+    assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
+
+
 def test_create_index_from_table(tmp_path, table):
     table.create_fts_index("text")
     df = table.search("puppy").limit(10).select(["text"]).to_pandas()
@@ -94,6 +129,7 @@ def test_create_index_from_table(tmp_path, table):
             "text": "gorilla",
             "text2": "gorilla",
             "nested": {"text": "gorilla"},
+            "count": 10,
         }
     ]
 )
@@ -166,6 +202,7 @@ def test_null_input(table):
             "text": None,
             "text2": None,
             "nested": {"text": None},
+            "count": 7,
         }
     ]
 )
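The ordering-field support exercised above lets full-text matches be returned in the order of a numeric column rather than by relevance score. A short sketch mirroring the test (the table and column names are illustrative):

```python
import lancedb

db = lancedb.connect("./.lancedb")
table = db.open_table("docs")  # assumes "text" and a numeric "count" column

# Register "count" as an ordering field when building the FTS index
table.create_fts_index("text", ordering_field_names=["count"])

# Hits for "puppy", ordered by descending "count" instead of BM25 score
rows = (
    table.search("puppy", ordering_field_name="count")
    .limit(20)
    .select(["text", "count"])
    .to_list()
)
```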
python/python/tests/test_huggingface.py (new file, 126 lines)
@@ -0,0 +1,126 @@
+# Copyright 2024 Lance Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+import lancedb
+import numpy as np
+import pyarrow as pa
+import pytest
+from lancedb.embeddings import get_registry
+from lancedb.embeddings.base import TextEmbeddingFunction
+from lancedb.embeddings.registry import register
+from lancedb.pydantic import LanceModel, Vector
+
+datasets = pytest.importorskip("datasets")
+
+
+@pytest.fixture(scope="session")
+def mock_embedding_function():
+    @register("random")
+    class MockTextEmbeddingFunction(TextEmbeddingFunction):
+        def generate_embeddings(self, texts):
+            return [np.random.randn(128).tolist() for _ in range(len(texts))]
+
+        def ndims(self):
+            return 128
+
+
+@pytest.fixture
+def mock_hf_dataset():
+    # Create pyarrow table with `text` and `label` columns
+    train = datasets.Dataset(
+        pa.table(
+            {
+                "text": ["foo", "bar"],
+                "label": [0, 1],
+            }
+        ),
+        split="train",
+    )
+
+    test = datasets.Dataset(
+        pa.table(
+            {
+                "text": ["fizz", "buzz"],
+                "label": [0, 1],
+            }
+        ),
+        split="test",
+    )
+    return datasets.DatasetDict({"train": train, "test": test})
+
+
+@pytest.fixture
+def hf_dataset_with_split():
+    # Create pyarrow table with `text` and `label` columns
+    train = datasets.Dataset(
+        pa.table(
+            {"text": ["foo", "bar"], "label": [0, 1], "split": ["train", "train"]}
+        ),
+        split="train",
+    )
+
+    test = datasets.Dataset(
+        pa.table(
+            {"text": ["fizz", "buzz"], "label": [0, 1], "split": ["test", "test"]}
+        ),
+        split="test",
+    )
+    return datasets.DatasetDict({"train": train, "test": test})
+
+
+def test_write_hf_dataset(tmp_path: Path, mock_embedding_function, mock_hf_dataset):
+    db = lancedb.connect(tmp_path)
+    emb = get_registry().get("random").create()
+
+    class Schema(LanceModel):
+        text: str = emb.SourceField()
+        label: int
+        vector: Vector(emb.ndims()) = emb.VectorField()
+
+    train_table = db.create_table("train", schema=Schema)
+    train_table.add(mock_hf_dataset["train"])
+
+    class WithSplit(LanceModel):
+        text: str = emb.SourceField()
+        label: int
+        vector: Vector(emb.ndims()) = emb.VectorField()
+        split: str
+
+    full_table = db.create_table("full", schema=WithSplit)
+    full_table.add(mock_hf_dataset)
+
+    assert len(train_table) == mock_hf_dataset["train"].num_rows
+    assert len(full_table) == sum(ds.num_rows for ds in mock_hf_dataset.values())
+
+    rt_train_table = full_table.to_lance().to_table(
+        columns=["text", "label"], filter="split='train'"
+    )
+    assert rt_train_table.to_pylist() == mock_hf_dataset["train"].data.to_pylist()
+
+
+def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with_split):
+    db = lancedb.connect(tmp_path)
+    emb = get_registry().get("random").create()
+
+    class Schema(LanceModel):
+        text: str = emb.SourceField()
+        label: int
+        vector: Vector(emb.ndims()) = emb.VectorField()
+        split: str
+
+    train_table = db.create_table("train", schema=Schema)
+    # this should still work because we don't add the split column
+    # if it already exists
+    train_table.add(hf_dataset_with_split)
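The new test file above covers the first-class path for Hugging Face datasets: a `datasets.Dataset`, or a whole `DatasetDict`, can be passed straight to `Table.add`. A sketch under the same assumptions as the tests (an embedding function registered as "random", and dataset columns that match the schema):

```python
import datasets
import lancedb
import pyarrow as pa
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("./.lancedb")
emb = get_registry().get("random").create()  # assumes this function is registered


class Schema(LanceModel):
    text: str = emb.SourceField()
    label: int
    vector: Vector(emb.ndims()) = emb.VectorField()
    split: str  # filled in per-split when adding a whole DatasetDict


ds = datasets.DatasetDict(
    {
        "train": datasets.Dataset(pa.table({"text": ["foo"], "label": [0]}), split="train"),
        "test": datasets.Dataset(pa.table({"text": ["bar"], "label": [1]}), split="test"),
    }
)
table = db.create_table("hf_data", schema=Schema)
table.add(ds)  # each row is tagged with its split name
```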
@@ -12,16 +12,20 @@
 # limitations under the License.
 
 import unittest.mock as mock
+from datetime import timedelta
+from typing import Optional
 
 import lance
+import lancedb
 import numpy as np
 import pandas.testing as tm
 import pyarrow as pa
 import pytest
+import pytest_asyncio
 from lancedb.db import LanceDBConnection
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.query import LanceVectorQueryBuilder, Query
+from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
-from lancedb.table import LanceTable
+from lancedb.table import AsyncTable, LanceTable
 
 
 class MockTable:
@@ -32,9 +36,9 @@ class MockTable:
     def to_lance(self):
         return lance.dataset(self.uri)
 
-    def _execute_query(self, query):
+    def _execute_query(self, query, batch_size: Optional[int] = None):
         ds = self.to_lance()
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -46,7 +50,8 @@ class MockTable:
                 "nprobes": query.nprobes,
                 "refine_factor": query.refine_factor,
             },
-        )
+            batch_size=batch_size,
+        ).to_reader()
 
 
 @pytest.fixture
@@ -65,6 +70,24 @@ def table(tmp_path) -> MockTable:
     return MockTable(tmp_path)
 
 
+@pytest_asyncio.fixture
+async def table_async(tmp_path) -> AsyncTable:
+    conn = await lancedb.connect_async(
+        tmp_path, read_consistency_interval=timedelta(seconds=0)
+    )
+    data = pa.table(
+        {
+            "vector": pa.array(
+                [[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
+            ),
+            "id": pa.array([1, 2]),
+            "str_field": pa.array(["a", "b"]),
+            "float_field": pa.array([1.0, 2.0]),
+        }
+    )
+    return await conn.create_table("test", data)
+
+
 def test_cast(table):
     class TestModel(LanceModel):
         vector: Vector(2)
@@ -94,6 +117,25 @@ def test_query_builder(table):
     assert all(np.array(rs[0]["vector"]) == [1, 2])
 
 
+def test_query_builder_batches(table):
+    rs = (
+        LanceVectorQueryBuilder(table, [0, 0], "vector")
+        .limit(2)
+        .select(["id", "vector"])
+        .to_batches(1)
+    )
+    rs_list = []
+    for item in rs:
+        rs_list.append(item)
+        assert isinstance(item, pa.RecordBatch)
+    assert len(rs_list) == 1
+    assert len(rs_list[0]["id"]) == 2
+    assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
+    assert rs_list[0].to_pandas()["id"][0] == 1
+    assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
+    assert rs_list[0].to_pandas()["id"][1] == 2
+
+
 def test_dynamic_projection(table):
     rs = (
         LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -178,9 +220,116 @@ def test_query_builder_with_different_vector_column():
             nprobes=20,
             refine_factor=None,
             vector_column="foo_vector",
-        )
+        ),
+        None,
     )
 
 
 def cosine_distance(vec1, vec2):
     return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+
+async def check_query(
+    query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None
+):
+    num_rows = 0
+    results = await query.to_batches()
+    async for batch in results:
+        if expected_columns is not None:
+            assert batch.schema.names == expected_columns
+        num_rows += batch.num_rows
+    if expected_num_rows is not None:
+        assert num_rows == expected_num_rows
+
+
+@pytest.mark.asyncio
+async def test_query_async(table_async: AsyncTable):
+    await check_query(
+        table_async.query(),
+        expected_num_rows=2,
+        expected_columns=["vector", "id", "str_field", "float_field"],
+    )
+    await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
+    await check_query(
+        table_async.query().select(["id", "vector"]), expected_columns=["id", "vector"]
+    )
+    await check_query(
+        table_async.query().select({"foo": "id", "bar": "id + 1"}),
+        expected_columns=["foo", "bar"],
+    )
+    await check_query(table_async.query().limit(1), expected_num_rows=1)
+    await check_query(
+        table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
+    )
+    # Support different types of inputs for the vector query
+    for vector_query in [
+        [1, 2],
+        [1.0, 2.0],
+        np.array([1, 2]),
+        (1, 2),
+    ]:
+        await check_query(
+            table_async.query().nearest_to(vector_query), expected_num_rows=2
+        )
+
+    # No easy way to check these vector query parameters are doing what they say. We
+    # just check that they don't raise exceptions and assume this is tested at a lower
+    # level.
+    await check_query(
+        table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter(),
+        expected_num_rows=1,
+    )
+    await check_query(
+        table_async.query().nearest_to(pa.array([1, 2])).refine_factor(1),
+        expected_num_rows=2,
+    )
+    await check_query(
+        table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
+        expected_num_rows=2,
+    )
+    await check_query(
+        table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
+        expected_num_rows=2,
+    )
+    await check_query(
+        table_async.query().nearest_to(pa.array([1, 2])).distance_type("dot"),
+        expected_num_rows=2,
+    )
+    await check_query(
+        table_async.query().nearest_to(pa.array([1, 2])).distance_type("DoT"),
+        expected_num_rows=2,
+    )
+
+    # Make sure we can use a vector query as a base query (e.g. call limit on it)
+    # Also make sure `vector_search` works
+    await check_query(table_async.vector_search([1, 2]).limit(1), expected_num_rows=1)
+
+    # Also check an empty query
+    await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
+
+
+@pytest.mark.asyncio
+async def test_query_to_arrow_async(table_async: AsyncTable):
+    table = await table_async.to_arrow()
+    assert table.num_rows == 2
+    assert table.num_columns == 4
+
+    table = await table_async.query().to_arrow()
+    assert table.num_rows == 2
+    assert table.num_columns == 4
+
+    table = await table_async.query().where("id < 0").to_arrow()
+    assert table.num_rows == 0
+    assert table.num_columns == 4
+
+
+@pytest.mark.asyncio
+async def test_query_to_pandas_async(table_async: AsyncTable):
+    df = await table_async.to_pandas()
+    assert df.shape == (2, 4)
+
+    df = await table_async.query().to_pandas()
+    assert df.shape == (2, 4)
+
+    df = await table_async.query().where("id < 0").to_pandas()
+    assert df.shape == (0, 4)
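Both query paths now stream: the sync builder gained `to_batches`, and the async query yields an async iterator of `pyarrow.RecordBatch`. A sketch of each, assuming a regular LanceDB table with a 2-dimensional vector column for `table`, and a `table_async` like the fixture above:

```python
import pyarrow as pa

# Sync: pull results in RecordBatches of (up to) one row each
for batch in table.search([0, 0]).limit(2).to_batches(1):
    assert isinstance(batch, pa.RecordBatch)


# Async: consume batches without materializing the whole result set
async def scan(table_async):
    async for batch in await table_async.query().to_batches():
        print(batch.num_rows)
```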
@@ -124,8 +124,9 @@ def test_linear_combination(tmp_path):
 )
 def test_cohere_reranker(tmp_path):
     pytest.importorskip("cohere")
+    reranker = CohereReranker()
     table, schema = get_test_table(tmp_path)
-    # The default reranker
+    # Hybrid search setting
     result1 = (
         table.search("Our father who art in heaven", query_type="hybrid")
         .rerank(normalize="score", reranker=CohereReranker())
@@ -133,7 +134,7 @@ def test_cohere_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(reranker=CohereReranker())
+        .rerank(reranker=reranker)
         .to_pydantic(schema)
     )
     assert result1 == result2
@@ -143,64 +144,120 @@ def test_cohere_reranker(tmp_path):
     result = (
         table.search((query_vector, query))
         .limit(30)
-        .rerank(reranker=CohereReranker())
+        .rerank(reranker=reranker)
         .to_arrow()
     )
 
     assert len(result) == 30
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+    err = (
         "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+
+    # Vector search setting
+    query = "Our father who art in heaven"
+    result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
+    assert len(result) == 30
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+    result_explicit = (
+        table.search(query_vector)
+        .rerank(reranker=reranker, query_string=query)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result_explicit) == 30
+    with pytest.raises(
+        ValueError
+    ):  # This raises an error because vector query is provided without reranking query
+        table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
+
+    # FTS search setting
+    result = (
+        table.search(query, query_type="fts")
+        .rerank(reranker=reranker)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result) > 0
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
 
 
 def test_cross_encoder_reranker(tmp_path):
     pytest.importorskip("sentence_transformers")
+    reranker = CrossEncoderReranker()
     table, schema = get_test_table(tmp_path)
     result1 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="score", reranker=CrossEncoderReranker())
+        .rerank(normalize="score", reranker=reranker)
         .to_pydantic(schema)
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(reranker=CrossEncoderReranker())
+        .rerank(reranker=reranker)
         .to_pydantic(schema)
     )
     assert result1 == result2
 
-    # test explicit hybrid query
     query = "Our father who art in heaven"
     query_vector = table.to_pandas()["vector"][0]
     result = (
         table.search((query_vector, query), query_type="hybrid")
         .limit(30)
-        .rerank(reranker=CrossEncoderReranker())
+        .rerank(reranker=reranker)
         .to_arrow()
     )
 
     assert len(result) == 30
 
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+    err = (
         "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
        "be descending."
     )
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+
+    # Vector search setting
+    result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
+    assert len(result) == 30
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+
+    result_explicit = (
+        table.search(query_vector)
+        .rerank(reranker=reranker, query_string=query)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result_explicit) == 30
+    with pytest.raises(
+        ValueError
+    ):  # This raises an error because vector query is provided without reranking query
+        table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
+
+    # FTS search setting
+    result = (
+        table.search(query, query_type="fts")
+        .rerank(reranker=reranker)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result) > 0
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
 
 
 def test_colbert_reranker(tmp_path):
     pytest.importorskip("transformers")
+    reranker = ColbertReranker()
     table, schema = get_test_table(tmp_path)
     result1 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="score", reranker=ColbertReranker())
+        .rerank(normalize="score", reranker=reranker)
         .to_pydantic(schema)
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(reranker=ColbertReranker())
+        .rerank(reranker=reranker)
         .to_pydantic(schema)
     )
     assert result1 == result2
@@ -211,17 +268,43 @@ def test_colbert_reranker(tmp_path):
     result = (
         table.search((query_vector, query))
         .limit(30)
-        .rerank(reranker=ColbertReranker())
+        .rerank(reranker=reranker)
         .to_arrow()
     )
 
     assert len(result) == 30
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+    err = (
         "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+
+    # Vector search setting
+    result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
+    assert len(result) == 30
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+    result_explicit = (
+        table.search(query_vector)
+        .rerank(reranker=reranker, query_string=query)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result_explicit) == 30
+    with pytest.raises(
+        ValueError
+    ):  # This raises an error because vector query is provided without reranking query
+        table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
+
+    # FTS search setting
+    result = (
+        table.search(query, query_type="fts")
+        .rerank(reranker=reranker)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result) > 0
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
 
 
 @pytest.mark.skipif(
@@ -230,9 +313,10 @@ def test_colbert_reranker(tmp_path):
 def test_openai_reranker(tmp_path):
     pytest.importorskip("openai")
     table, schema = get_test_table(tmp_path)
+    reranker = OpenaiReranker()
     result1 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="score", reranker=OpenaiReranker())
+        .rerank(normalize="score", reranker=reranker)
         .to_pydantic(schema)
     )
     result2 = (
@@ -248,14 +332,40 @@ def test_openai_reranker(tmp_path):
     result = (
         table.search((query_vector, query))
         .limit(30)
-        .rerank(reranker=OpenaiReranker())
+        .rerank(reranker=reranker)
         .to_arrow()
     )
 
     assert len(result) == 30
 
-    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+    err = (
         "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+
+    # Vector search setting
+    result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
+    assert len(result) == 30
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
+    result_explicit = (
+        table.search(query_vector)
+        .rerank(reranker=reranker, query_string=query)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result_explicit) == 30
+    with pytest.raises(
+        ValueError
+    ):  # This raises an error because vector query is provided without reranking query
+        table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
+    # FTS search setting
+    result = (
+        table.search(query, query_type="fts")
+        .rerank(reranker=reranker)
+        .limit(30)
+        .to_arrow()
+    )
+    assert len(result) > 0
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
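Rerankers are no longer tied to hybrid search: the same instance can rescore vector and FTS results too, with one caveat shown in the tests above, namely that a raw vector query needs an explicit `query_string`. A sketch assuming `table`, `schema`, and `query_vector` as produced by the `get_test_table` helper (and a Cohere API key for this particular reranker):

```python
from lancedb.rerankers import CohereReranker

reranker = CohereReranker()
query = "Our father who art in heaven"

# Hybrid search setting
hybrid = (
    table.search(query, query_type="hybrid")
    .rerank(normalize="score", reranker=reranker)
    .to_arrow()
)

# Vector search setting: a string query is reused for reranking; a raw
# vector needs an explicit query_string or rerank() raises ValueError
vec = (
    table.search(query_vector)
    .rerank(reranker=reranker, query_string=query)
    .limit(30)
    .to_arrow()
)

# FTS search setting
fts = (
    table.search(query, query_type="fts")
    .rerank(reranker=reranker)
    .limit(30)
    .to_arrow()
)
```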
python/src/arrow.rs (new file, 51 lines)
@@ -0,0 +1,51 @@
+// use arrow::datatypes::SchemaRef;
+// use lancedb::arrow::SendableRecordBatchStream;
+
+use std::sync::Arc;
+
+use arrow::{
+    datatypes::SchemaRef,
+    pyarrow::{IntoPyArrow, ToPyArrow},
+};
+use futures::stream::StreamExt;
+use lancedb::arrow::SendableRecordBatchStream;
+use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python};
+use pyo3_asyncio::tokio::future_into_py;
+
+use crate::error::PythonErrorExt;
+
+#[pyclass]
+pub struct RecordBatchStream {
+    schema: SchemaRef,
+    inner: Arc<tokio::sync::Mutex<SendableRecordBatchStream>>,
+}
+
+impl RecordBatchStream {
+    pub fn new(inner: SendableRecordBatchStream) -> Self {
+        let schema = inner.schema().clone();
+        Self {
+            schema,
+            inner: Arc::new(tokio::sync::Mutex::new(inner)),
+        }
+    }
+}
+
+#[pymethods]
+impl RecordBatchStream {
+    pub fn schema(&self, py: Python) -> PyResult<PyObject> {
+        (*self.schema).clone().into_pyarrow(py)
+    }
+
+    pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let inner_next = inner.lock().await.next().await;
+            inner_next
+                .map(|item| {
+                    let item = item.infer_error()?;
+                    Python::with_gil(|py| item.to_pyarrow(py))
+                })
+                .transpose()
+        })
+    }
+}
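On the Python side, this `RecordBatchStream` is what the async query's batch reader wraps: `schema()` hands back a `pyarrow.Schema`, and each `next()` call resolves to one `pyarrow.RecordBatch` (or `None` at the end). A sketch of the intended consumption through the public wrapper, with the wrapper's exact surface being an assumption not shown in this diff:

```python
async def collect(table_async):
    reader = await table_async.query().to_batches()
    out = []
    async for batch in reader:  # each iteration awaits RecordBatchStream.next()
        out.append(batch)
    return out
```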
@@ -95,7 +95,7 @@ impl Connection {
 
         let mode = Self::parse_create_mode_str(mode)?;
 
-        let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
+        let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
         future_into_py(self_.py(), async move {
             let table = inner
                 .create_table(name, batches)
@@ -12,15 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use arrow::RecordBatchStream;
 use connection::{connect, Connection};
 use env_logger::Env;
 use index::{Index, IndexConfig};
 use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
+use query::{Query, VectorQuery};
 use table::Table;
 
+pub mod arrow;
 pub mod connection;
 pub mod error;
 pub mod index;
+pub mod query;
 pub mod table;
 pub mod util;
 
@@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<Table>()?;
     m.add_class::<Index>()?;
     m.add_class::<IndexConfig>()?;
+    m.add_class::<Query>()?;
+    m.add_class::<VectorQuery>()?;
+    m.add_class::<RecordBatchStream>()?;
     m.add_function(wrap_pyfunction!(connect, m)?)?;
     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
     Ok(())
python/src/query.rs (new file, 125 lines)
@@ -0,0 +1,125 @@
+// Copyright 2024 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::array::make_array;
+use arrow::array::ArrayData;
+use arrow::pyarrow::FromPyArrow;
+use lancedb::query::{
+    ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
+};
+use pyo3::pyclass;
+use pyo3::pymethods;
+use pyo3::PyAny;
+use pyo3::PyRef;
+use pyo3::PyResult;
+use pyo3_asyncio::tokio::future_into_py;
+
+use crate::arrow::RecordBatchStream;
+use crate::error::PythonErrorExt;
+use crate::util::parse_distance_type;
+
+#[pyclass]
+pub struct Query {
+    inner: LanceDbQuery,
+}
+
+impl Query {
+    pub fn new(query: LanceDbQuery) -> Self {
+        Self { inner: query }
+    }
+}
+
+#[pymethods]
+impl Query {
+    pub fn r#where(&mut self, predicate: String) {
+        self.inner = self.inner.clone().only_if(predicate);
+    }
+
+    pub fn select(&mut self, columns: Vec<(String, String)>) {
+        self.inner = self.inner.clone().select(Select::dynamic(&columns));
+    }
+
+    pub fn limit(&mut self, limit: u32) {
+        self.inner = self.inner.clone().limit(limit as usize);
+    }
+
+    pub fn nearest_to(&mut self, vector: &PyAny) -> PyResult<VectorQuery> {
+        let data: ArrayData = ArrayData::from_pyarrow(vector)?;
+        let array = make_array(data);
+        let inner = self.inner.clone().nearest_to(array).infer_error()?;
+        Ok(VectorQuery { inner })
+    }
+
+    pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let inner_stream = inner.execute().await.infer_error()?;
+            Ok(RecordBatchStream::new(inner_stream))
+        })
+    }
+}
+
+#[pyclass]
+pub struct VectorQuery {
+    inner: LanceDbVectorQuery,
+}
+
+#[pymethods]
+impl VectorQuery {
+    pub fn r#where(&mut self, predicate: String) {
+        self.inner = self.inner.clone().only_if(predicate);
+    }
+
+    pub fn select(&mut self, columns: Vec<(String, String)>) {
+        self.inner = self.inner.clone().select(Select::dynamic(&columns));
+    }
+
+    pub fn limit(&mut self, limit: u32) {
+        self.inner = self.inner.clone().limit(limit as usize);
+    }
+
+    pub fn column(&mut self, column: String) {
+        self.inner = self.inner.clone().column(&column);
+    }
+
+    pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> {
+        let distance_type = parse_distance_type(distance_type)?;
+        self.inner = self.inner.clone().distance_type(distance_type);
+        Ok(())
+    }
+
+    pub fn postfilter(&mut self) {
+        self.inner = self.inner.clone().postfilter();
+    }
+
+    pub fn refine_factor(&mut self, refine_factor: u32) {
+        self.inner = self.inner.clone().refine_factor(refine_factor);
+    }
+
+    pub fn nprobes(&mut self, nprobe: u32) {
+        self.inner = self.inner.clone().nprobes(nprobe as usize);
+    }
+
+    pub fn bypass_vector_index(&mut self) {
+        self.inner = self.inner.clone().bypass_vector_index()
+    }
+
+    pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let inner_stream = inner.execute().await.infer_error()?;
+            Ok(RecordBatchStream::new(inner_stream))
+        })
+    }
+}
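These two `#[pyclass]` builders surface in Python as the chainable async query API: `where`/`select`/`limit` on the base query, plus `column`, `distance_type`, `postfilter`, `refine_factor`, `nprobes`, and `bypass_vector_index` once `nearest_to` upgrades it to a vector query. A sketch of the full chain, assuming a `table_async` with an indexed `vector` column:

```python
results = await (
    table_async.query()
    .where("id > 0")
    .nearest_to([1.0, 2.0])
    .distance_type("cosine")
    .nprobes(10)
    .refine_factor(2)
    .postfilter()  # apply the filter after the ANN search instead of before
    .limit(5)
    .to_arrow()
)
```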
@@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py;
 use crate::{
     error::PythonErrorExt,
     index::{Index, IndexConfig},
+    query::Query,
 };
 
 #[pyclass]
@@ -63,7 +64,7 @@ impl Table {
     }
 
     pub fn add<'a>(self_: PyRef<'a, Self>, data: &PyAny, mode: String) -> PyResult<&'a PyAny> {
-        let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
+        let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
         let mut op = self_.inner_ref()?.add(batches);
         if mode == "append" {
             op = op.mode(AddDataMode::Append);
@@ -179,4 +180,8 @@ impl Table {
             async move { inner.restore().await.infer_error() },
         )
     }
+
+    pub fn query(&self) -> Query {
+        Query::new(self.inner_ref().unwrap().query())
+    }
 }
@@ -1,6 +1,10 @@
 use std::sync::Mutex;
 
-use pyo3::{exceptions::PyRuntimeError, PyResult};
+use lancedb::DistanceType;
+use pyo3::{
+    exceptions::{PyRuntimeError, PyValueError},
+    PyResult,
+};
 
 /// A wrapper around a rust builder
 ///
@@ -33,3 +37,15 @@ impl<T> BuilderWrapper<T> {
         Ok(result)
     }
 }
+
+pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
+    match distance_type.as_ref().to_lowercase().as_str() {
+        "l2" => Ok(DistanceType::L2),
+        "cosine" => Ok(DistanceType::Cosine),
+        "dot" => Ok(DistanceType::Dot),
+        _ => Err(PyValueError::new_err(format!(
+            "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
+            distance_type.as_ref()
+        ))),
+    }
+}
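`parse_distance_type` is what makes the Python `distance_type(...)` argument case-insensitive, mapping anything unknown to a `ValueError`. A quick sketch of the behavior from Python, assuming a `table_async` as before:

```python
q = table_async.query().nearest_to([1.0, 2.0])
q.distance_type("dot")   # accepted
q.distance_type("DoT")   # also accepted - lowercased before matching
try:
    q.distance_type("manhattan")  # not one of l2 / cosine / dot
except ValueError as e:
    print(e)  # Invalid distance type 'manhattan'. Must be one of l2, cosine, or dot
```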
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use lance_linalg::distance::MetricType;
 use lancedb::index::vector::IvfPqIndexBuilder;
 use lancedb::index::Index;
+use lancedb::DistanceType;
 use neon::context::FunctionContext;
 use neon::prelude::*;
 use std::convert::TryFrom;
@@ -72,8 +72,8 @@ fn get_index_params_builder(
     }
     let mut builder = IvfPqIndexBuilder::default();
     if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
-        let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
-        builder = builder.distance_type(metric_type);
+        let distance_type = DistanceType::try_from(metric_type.value(cx).as_str())?;
+        builder = builder.distance_type(distance_type);
     }
     if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
         builder = builder.num_partitions(np);
@@ -2,7 +2,8 @@ use std::convert::TryFrom;
 use std::ops::Deref;
 
 use futures::{TryFutureExt, TryStreamExt};
-use lance_linalg::distance::MetricType;
+use lancedb::query::{ExecutableQuery, QueryBase, Select};
+use lancedb::DistanceType;
 use neon::context::FunctionContext;
 use neon::handle::Handle;
 use neon::prelude::*;
@@ -56,53 +57,72 @@ impl JsQuery {
         let channel = cx.channel();
         let table = js_table.table.clone();
 
-        let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
         let mut builder = table.query();
-        if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
-            builder = builder.nearest_to(&query);
-            if let Some(metric_type) = query_obj
-                .get_opt::<JsString, _, _>(&mut cx, "_metricType")?
-                .map(|s| s.value(&mut cx))
-                .map(|s| MetricType::try_from(s.as_str()).unwrap())
-            {
-                builder = builder.metric_type(metric_type);
-            }
-
-            let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
-            builder = builder.nprobes(nprobes);
-        };
-
         if let Some(filter) = query_obj
             .get_opt::<JsString, _, _>(&mut cx, "_filter")?
             .map(|s| s.value(&mut cx))
         {
-            builder = builder.filter(filter);
+            builder = builder.only_if(filter);
         }
         if let Some(select) = select {
-            builder = builder.select(select.as_slice());
+            builder = builder.select(Select::columns(select.as_slice()));
         }
         if let Some(limit) = limit {
             builder = builder.limit(limit as usize);
         };
 
-        builder = builder.prefilter(prefilter);
-
-        rt.spawn(async move {
-            let record_batch_stream = builder.execute_stream();
-            let results = record_batch_stream
-                .and_then(|stream| {
-                    stream
-                        .try_collect::<Vec<_>>()
-                        .map_err(lancedb::error::Error::from)
-                })
-                .await;
-
-            deferred.settle_with(&channel, move |mut cx| {
-                let results = results.or_throw(&mut cx)?;
-                let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
-                convert::new_js_buffer(buffer, &mut cx, is_electron)
-            });
-        });
+        let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
+        if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
+            let mut vector_builder = builder.nearest_to(query).unwrap();
+            if let Some(distance_type) = query_obj
+                .get_opt::<JsString, _, _>(&mut cx, "_metricType")?
+                .map(|s| s.value(&mut cx))
+                .map(|s| DistanceType::try_from(s.as_str()).unwrap())
+            {
+                vector_builder = vector_builder.distance_type(distance_type);
+            }
+
+            let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
+            vector_builder = vector_builder.nprobes(nprobes);
+
+            if !prefilter {
+                vector_builder = vector_builder.postfilter();
+            }
+            rt.spawn(async move {
+                let results = vector_builder
+                    .execute()
+                    .and_then(|stream| {
+                        stream
+                            .try_collect::<Vec<_>>()
+                            .map_err(lancedb::error::Error::from)
+                    })
+                    .await;
+
+                deferred.settle_with(&channel, move |mut cx| {
+                    let results = results.or_throw(&mut cx)?;
+                    let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
+                    convert::new_js_buffer(buffer, &mut cx, is_electron)
+                });
+            });
+        } else {
+            rt.spawn(async move {
+                let results = builder
+                    .execute()
+                    .and_then(|stream| {
+                        stream
+                            .try_collect::<Vec<_>>()
+                            .map_err(lancedb::error::Error::from)
+                    })
+                    .await;
+
+                deferred.settle_with(&channel, move |mut cx| {
+                    let results = results.or_throw(&mut cx)?;
+                    let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?;
+                    convert::new_js_buffer(buffer, &mut cx, is_electron)
+                });
+            });
+        };
 
         Ok(promise)
     }
 }
@@ -80,7 +80,7 @@ impl JsTable {
         rt.spawn(async move {
             let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
             let table_rst = database
-                .create_table(&table_name, Box::new(batch_reader))
+                .create_table(&table_name, batch_reader)
                 .write_options(WriteOptions {
                     lance_write_params: Some(params),
                 })
@@ -126,7 +126,7 @@ impl JsTable {
         rt.spawn(async move {
             let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
             let add_result = table
-                .add(Box::new(batch_reader))
+                .add(batch_reader)
                 .write_options(WriteOptions {
                     lance_write_params: Some(params),
                 })
rust/lancedb/examples/ivf_pq.rs (new file, 165 lines)
@@ -0,0 +1,165 @@
+// Copyright 2024 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This example demonstrates setting advanced parameters when building an IVF PQ index
+//!
+//! Snippets from this example are used in the documentation on ANN indices.
+
+use std::sync::Arc;
+
+use arrow_array::types::Float32Type;
+use arrow_array::{
+    FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader,
+};
+use arrow_schema::{DataType, Field, Schema};
+
+use futures::TryStreamExt;
+use lancedb::connection::Connection;
+use lancedb::index::vector::IvfPqIndexBuilder;
+use lancedb::index::Index;
+use lancedb::query::{ExecutableQuery, QueryBase};
+use lancedb::{connect, DistanceType, Result, Table};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    if std::path::Path::new("data").exists() {
+        std::fs::remove_dir_all("data").unwrap();
+    }
+    let uri = "data/sample-lancedb";
+    let db = connect(uri).execute().await?;
+    let tbl = create_table(&db).await?;
+
+    create_index(&tbl).await?;
+    search_index(&tbl).await?;
+    Ok(())
+}
+
+fn create_some_records() -> Result<Box<dyn RecordBatchReader + Send>> {
+    const TOTAL: usize = 1000;
+    const DIM: usize = 128;
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new(
+            "vector",
+            DataType::FixedSizeList(
+                Arc::new(Field::new("item", DataType::Float32, true)),
+                DIM as i32,
+            ),
+            true,
+        ),
+    ]));
+
+    // Create a RecordBatch stream.
+    let batches = RecordBatchIterator::new(
+        vec![RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
+                Arc::new(
+                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+                        (0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
+                        DIM as i32,
+                    ),
+                ),
+            ],
+        )
+        .unwrap()]
+        .into_iter()
+        .map(Ok),
+        schema.clone(),
+    );
+    Ok(Box::new(batches))
+}
+
+async fn create_table(db: &Connection) -> Result<Table> {
+    let initial_data: Box<dyn RecordBatchReader + Send> = create_some_records()?;
+    let tbl = db
+        .create_table("my_table", Box::new(initial_data))
+        .execute()
+        .await
+        .unwrap();
+    Ok(tbl)
+}
+
+async fn create_index(table: &Table) -> Result<()> {
+    // --8<-- [start:create_index]
+    // For this example, `table` is a lancedb::Table with a column named
+    // "vector" that is a vector column with dimension 128.
+
+    // By default, if the column "vector" appears to be a vector column,
+    // then an IVF_PQ index with reasonable defaults is created.
+    table
+        .create_index(&["vector"], Index::Auto)
+        .execute()
+        .await?;
+    // For advanced cases, it is also possible to specifically request an
+    // IVF_PQ index and provide custom parameters.
+    table
+        .create_index(
+            &["vector"],
+            Index::IvfPq(
+                // Here we specify advanced indexing parameters. In this case
+                // we are creating an index that may have better recall than the
+                // default but is also larger and slower.
+                IvfPqIndexBuilder::default()
+                    // This overrides the default distance type of L2
+                    .distance_type(DistanceType::Cosine)
+                    // With 1000 rows this would have been ~31 by default
+                    .num_partitions(50)
+                    // With dimension 128 this would have been 8 by default
+                    .num_sub_vectors(16),
+            ),
+        )
+        .execute()
+        .await?;
+    // --8<-- [end:create_index]
+    Ok(())
+}
+
+async fn search_index(table: &Table) -> Result<()> {
+    // --8<-- [start:search1]
+    let query_vector = [1.0; 128];
+    // By default the index will find the 10 closest results using default
+    // search parameters that give a reasonable tradeoff between accuracy
+    // and search latency
+    let mut results = table
+        .vector_search(&query_vector)?
+        // Note: you should always set the distance_type to match the value used
+        // to train the index
+        .distance_type(DistanceType::Cosine)
+        .execute()
+        .await?;
+    while let Some(batch) = results.try_next().await? {
+        println!("{:?}", batch);
+    }
+    // We can also provide custom search parameters. Here we perform a
+    // slower but more accurate search
+    let mut results = table
+        .vector_search(&query_vector)?
+        .distance_type(DistanceType::Cosine)
+        // Override the default of 10 to get more rows
+        .limit(15)
+        // Override the default of 20 to search more partitions
+        .nprobes(30)
+        // Override the default of None to apply a refine step
+        .refine_factor(1)
+        .execute()
+        .await?;
+    while let Some(batch) = results.try_next().await? {
+        println!("{:?}", batch);
+    }
+    Ok(())
+    // --8<-- [end:search1]
+}
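The same tuning knobs are available from Python with the sync API; a rough sketch of the equivalent of the Rust example above, assuming a 128-dimensional `vector` column and the existing `create_index` parameter names (the metric used to build the index should be repeated at query time):

```python
import lancedb

db = lancedb.connect("data/sample-lancedb")
table = db.open_table("my_table")

table.create_index(
    metric="cosine",     # overrides the default L2
    num_partitions=50,   # would default to roughly sqrt(num_rows)
    num_sub_vectors=16,  # would default to 8 for dimension 128
)

results = (
    table.search([1.0] * 128)
    .metric("cosine")    # match the metric the index was trained with
    .nprobes(30)
    .refine_factor(1)
    .limit(15)
    .to_pandas()
)
```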
@@ -12,6 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+//! This example demonstrates basic usage of LanceDb.
+//!
+//! Snippets from this example are used in the quickstart documentation.
+
 use std::sync::Arc;
 
 use arrow_array::types::Float32Type;
@@ -19,8 +23,10 @@ use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterat
 use arrow_schema::{DataType, Field, Schema};
 use futures::TryStreamExt;
 
+use lancedb::arrow::IntoArrow;
 use lancedb::connection::Connection;
 use lancedb::index::Index;
+use lancedb::query::{ExecutableQuery, QueryBase};
 use lancedb::{connect, Result, Table as LanceDbTable};
 
 #[tokio::main]
@@ -57,14 +63,14 @@ async fn main() -> Result<()> {
 async fn open_with_existing_tbl() -> Result<()> {
     let uri = "data/sample-lancedb";
     let db = connect(uri).execute().await?;
-    // --8<-- [start:open_with_existing_file]
-    let _ = db.open_table("my_table").execute().await.unwrap();
-    // --8<-- [end:open_with_existing_file]
+    #[allow(unused_variables)]
+    // --8<-- [start:open_existing_tbl]
+    let table = db.open_table("my_table").execute().await.unwrap();
+    // --8<-- [end:open_existing_tbl]
     Ok(())
 }
 
-async fn create_table(db: &Connection) -> Result<LanceDbTable> {
-    // --8<-- [start:create_table]
+fn create_some_records() -> Result<impl IntoArrow> {
     const TOTAL: usize = 1000;
     const DIM: usize = 128;
 
@@ -99,33 +105,22 @@ async fn create_table(db: &Connection) -> Result<LanceDbTable> {
         .map(Ok),
         schema.clone(),
     );
+    Ok(Box::new(batches))
+}
+
+async fn create_table(db: &Connection) -> Result<LanceDbTable> {
+    // --8<-- [start:create_table]
+    let initial_data = create_some_records()?;
     let tbl = db
-        .create_table("my_table", Box::new(batches))
+        .create_table("my_table", initial_data)
         .execute()
         .await
         .unwrap();
     // --8<-- [end:create_table]
 
-    let new_batches = RecordBatchIterator::new(
-        vec![RecordBatch::try_new(
-            schema.clone(),
-            vec![
-                Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
-                Arc::new(
-                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
-                        (0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
-                        DIM as i32,
-                    ),
-                ),
-            ],
-        )
-        .unwrap()]
-        .into_iter()
-        .map(Ok),
-        schema.clone(),
-    );
     // --8<-- [start:add]
-    tbl.add(Box::new(new_batches)).execute().await.unwrap();
+    let new_data = create_some_records()?;
+    tbl.add(new_data).execute().await.unwrap();
    // --8<-- [end:add]
 
     Ok(tbl)
@@ -150,9 +145,10 @@ async fn create_index(table: &LanceDbTable) -> Result<()> {
 async fn search(table: &LanceDbTable) -> Result<Vec<RecordBatch>> {
     // --8<-- [start:search]
     table
-        .search(&[1.0; 128])
+        .query()
         .limit(2)
-        .execute_stream()
+        .nearest_to(&[1.0; 128])?
+        .execute()
         .await?
         .try_collect::<Vec<_>>()
         .await
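Condensing the quickstart snippets above into one flow, a hedged end-to-end sketch of the new API (the table name, row counts, and use of `expect` are our choices; assumes the `tokio` and `futures` crates are available):

```rust
use std::sync::Arc;

use arrow_array::types::Float32Type;
use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase};

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    let db = lancedb::connect("data/sample-lancedb").execute().await?;

    // A small batch: an id column plus a 128-dim vector column.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128),
            true,
        ),
    ]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from_iter_values(0..256)),
            Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
                (0..256).map(|_| Some(vec![Some(1.0); 128])),
                128,
            )),
        ],
    )
    .expect("valid batch");
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);

    // Any `impl IntoArrow` (such as a RecordBatchReader) is accepted directly;
    // the `Box::new` that older call sites needed is gone.
    let table = db.create_table("quickstart", reader).execute().await?;

    let batches = table
        .query()
        .limit(2)
        .nearest_to(&[1.0; 128])?
        .execute()
        .await?
        .try_collect::<Vec<_>>()
        .await?;
    println!("got {} batch(es)", batches.len());
    Ok(())
}
```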
@@ -101,3 +101,21 @@ impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> RecordBatchStream
         self.schema.clone()
     }
 }
+
+/// A trait for converting incoming data to Arrow
+///
+/// Integrations should implement this trait to allow data to be
+/// imported directly from the integration. For example, implementing
+/// this trait for `Vec<Vec<...>>` would allow the `Vec` to be directly
+/// used in methods like [`crate::connection::Connection::create_table`]
+/// or [`crate::table::Table::add`]
+pub trait IntoArrow {
+    /// Convert the data into an Arrow array
+    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
+}
+
+impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
+    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
+        Ok(Box::new(self))
+    }
+}
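The practical effect of the blanket impl is that any owned `RecordBatchReader` already satisfies `IntoArrow`, so callers no longer box readers themselves. A minimal sketch of the conversion in isolation (the schema and values are arbitrary):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use lancedb::arrow::IntoArrow;

fn main() -> lancedb::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])
            .expect("valid batch");
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);

    // The blanket impl makes every owned RecordBatchReader an IntoArrow, so
    // the conversion is a single box rather than a manual Box::new at every
    // create_table/add call site.
    let boxed = reader.into_arrow()?;
    assert_eq!(boxed.schema().fields().len(), 1);
    Ok(())
}
```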
@@ -27,6 +27,7 @@ use object_store::{
 };
 use snafu::prelude::*;
 
+use crate::arrow::IntoArrow;
 use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
 use crate::io::object_store::MirroringObjectStoreWrapper;
 use crate::table::{NativeTable, WriteOptions};
@@ -116,23 +117,27 @@ impl TableNamesBuilder {
     }
 }
 
+pub struct NoData {}
+
+impl IntoArrow for NoData {
+    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
+        unreachable!("NoData should never be converted to Arrow")
+    }
+}
+
 /// A builder for configuring a [`Connection::create_table`] operation
-pub struct CreateTableBuilder<const HAS_DATA: bool> {
+pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
     parent: Arc<dyn ConnectionInternal>,
     pub(crate) name: String,
-    pub(crate) data: Option<Box<dyn RecordBatchReader + Send>>,
+    pub(crate) data: Option<T>,
     pub(crate) schema: Option<SchemaRef>,
     pub(crate) mode: CreateTableMode,
     pub(crate) write_options: WriteOptions,
 }
 
 // Builder methods that only apply when we have initial data
-impl CreateTableBuilder<true> {
-    fn new(
-        parent: Arc<dyn ConnectionInternal>,
-        name: String,
-        data: Box<dyn RecordBatchReader + Send>,
-    ) -> Self {
+impl<T: IntoArrow> CreateTableBuilder<true, T> {
+    fn new(parent: Arc<dyn ConnectionInternal>, name: String, data: T) -> Self {
         Self {
             parent,
             name,
@@ -151,12 +156,32 @@ impl CreateTableBuilder<true> {
 
     /// Execute the create table operation
     pub async fn execute(self) -> Result<Table> {
-        self.parent.clone().do_create_table(self).await
+        let parent = self.parent.clone();
+        let (data, builder) = self.extract_data()?;
+        parent.do_create_table(builder, data).await
+    }
+
+    fn extract_data(
+        mut self,
+    ) -> Result<(
+        Box<dyn RecordBatchReader + Send>,
+        CreateTableBuilder<false, NoData>,
+    )> {
+        let data = self.data.take().unwrap().into_arrow()?;
+        let builder = CreateTableBuilder::<false, NoData> {
+            parent: self.parent,
+            name: self.name,
+            data: None,
+            schema: self.schema,
+            mode: self.mode,
+            write_options: self.write_options,
+        };
+        Ok((data, builder))
     }
 }
 
 // Builder methods that only apply when we do not have initial data
-impl CreateTableBuilder<false> {
+impl CreateTableBuilder<false, NoData> {
     fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
         Self {
             parent,
@@ -174,7 +199,7 @@ impl CreateTableBuilder<false> {
     }
 }
 
-impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
+impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
     /// Set the mode for creating the table
     ///
     /// This controls what happens if a table with the given name already exists
@@ -237,17 +262,24 @@ pub(crate) trait ConnectionInternal:
     Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
 {
     async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
-    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<Table>;
+    async fn do_create_table(
+        &self,
+        options: CreateTableBuilder<false, NoData>,
+        data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<Table>;
     async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table>;
     async fn drop_table(&self, name: &str) -> Result<()>;
     async fn drop_db(&self) -> Result<()>;
 
-    async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<Table> {
-        let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
-        let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
-            .mode(options.mode)
-            .write_options(options.write_options);
-        self.do_create_table(opts).await
+    async fn do_create_empty_table(
+        &self,
+        options: CreateTableBuilder<false, NoData>,
+    ) -> Result<Table> {
+        let batches = Box::new(RecordBatchIterator::new(
+            vec![],
+            options.schema.as_ref().unwrap().clone(),
+        ));
+        self.do_create_table(options, batches).await
     }
 }
@@ -285,12 +317,12 @@ impl Connection {
     ///
     /// * `name` - The name of the table
     /// * `initial_data` - The initial data to write to the table
-    pub fn create_table(
+    pub fn create_table<T: IntoArrow>(
         &self,
         name: impl Into<String>,
-        initial_data: Box<dyn RecordBatchReader + Send>,
-    ) -> CreateTableBuilder<true> {
-        CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
+        initial_data: T,
+    ) -> CreateTableBuilder<true, T> {
+        CreateTableBuilder::<true, T>::new(self.internal.clone(), name.into(), initial_data)
     }
 
     /// Create an empty table with a given schema
@@ -303,8 +335,8 @@ impl Connection {
         &self,
         name: impl Into<String>,
         schema: SchemaRef,
-    ) -> CreateTableBuilder<false> {
-        CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
+    ) -> CreateTableBuilder<false, NoData> {
+        CreateTableBuilder::<false, NoData>::new(self.internal.clone(), name.into(), schema)
     }
 
     /// Open an existing table in the database
@@ -694,7 +726,11 @@ impl ConnectionInternal for Database {
         Ok(f)
     }
 
-    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<Table> {
+    async fn do_create_table(
+        &self,
+        options: CreateTableBuilder<false, NoData>,
+        data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<Table> {
         let table_uri = self.table_uri(&options.name)?;
 
         let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
@@ -705,7 +741,7 @@ impl ConnectionInternal for Database {
         match NativeTable::create(
             &table_uri,
             &options.name,
-            options.data.unwrap(),
+            data,
             self.store_wrapper.clone(),
             Some(write_params),
             self.read_consistency_interval,
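The `HAS_DATA` const parameter together with `NoData` is a typestate: the with-data and without-data builders are distinct types, so each exposes only the methods that make sense for it. A stripped-down sketch of the pattern with toy types of our own (not LanceDB's):

```rust
// A toy typestate builder: HAS_DATA is part of the type, so "has data" and
// "no data" builders can expose different method sets at compile time.
struct Builder<const HAS_DATA: bool> {
    name: String,
    rows: Option<Vec<i32>>,
}

impl Builder<true> {
    fn with_data(name: &str, rows: Vec<i32>) -> Self {
        Self { name: name.into(), rows: Some(rows) }
    }

    // Only the with-data state can consume its payload, mirroring how only
    // CreateTableBuilder<true, T> has an extract_data step.
    fn execute(self) -> (String, Vec<i32>) {
        (self.name, self.rows.unwrap())
    }
}

impl Builder<false> {
    fn empty(name: &str) -> Self {
        Self { name: name.into(), rows: None }
    }
}

// Options shared by both states live on a generic impl, like mode() and
// write_options() do on CreateTableBuilder<HAS_DATA, T>.
impl<const HAS_DATA: bool> Builder<HAS_DATA> {
    fn renamed(mut self, name: &str) -> Self {
        self.name = name.into();
        self
    }
}

fn main() {
    let (name, rows) = Builder::with_data("t", vec![1, 2]).renamed("t2").execute();
    assert_eq!((name.as_str(), rows.len()), ("t2", 2));
    let _empty = Builder::<false>::empty("t3");
}
```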
@@ -342,7 +342,11 @@ mod test {
     use object_store::local::LocalFileSystem;
     use tempfile;
 
-    use crate::{connect, table::WriteOptions};
+    use crate::{
+        connect,
+        query::{ExecutableQuery, QueryBase},
+        table::WriteOptions,
+    };
 
     #[tokio::test]
     async fn test_e2e() {
@@ -381,9 +385,11 @@ mod test {
         assert_eq!(t.count_rows(None).await.unwrap(), 100);
 
         let q = t
-            .search(&[0.1, 0.1, 0.1, 0.1])
+            .query()
            .limit(10)
-            .execute_stream()
+            .nearest_to(&[0.1, 0.1, 0.1, 0.1])
+            .unwrap()
+            .execute()
             .await
             .unwrap();
 
@@ -36,8 +36,6 @@
 //!
 //! ### Quick Start
 //!
-//! <div class="warning">Rust API is not stable yet, please expect breaking changes.</div>
-//!
 //! #### Connect to a database.
 //!
 //! ```rust
@@ -150,6 +148,7 @@
 //! # use arrow_schema::{DataType, Schema, Field};
 //! # use arrow_array::{RecordBatch, RecordBatchIterator};
 //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
+//! # use lancedb::query::{ExecutableQuery, QueryBase};
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
 //! # let tmpdir = tempfile::tempdir().unwrap();
 //! # let db = lancedb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
@@ -170,8 +169,10 @@
 //! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
 //! # let table = db.open_table("my_table").execute().await.unwrap();
 //! let results = table
-//!     .search(&[1.0; 128])
-//!     .execute_stream()
+//!     .query()
+//!     .nearest_to(&[1.0; 128])
+//!     .unwrap()
+//!     .execute()
 //!     .await
 //!     .unwrap()
 //!     .try_collect::<Vec<_>>()
@@ -193,9 +194,72 @@ pub(crate) mod remote;
 pub mod table;
 pub mod utils;
 
+use std::fmt::Display;
+
+use serde::{Deserialize, Serialize};
+
+pub use connection::Connection;
 pub use error::{Error, Result};
-pub use lance_linalg::distance::DistanceType;
+use lance_linalg::distance::DistanceType as LanceDistanceType;
 pub use table::Table;
 
+#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
+#[non_exhaustive]
+pub enum DistanceType {
+    /// Euclidean distance. This is a very common distance metric that
+    /// accounts for both magnitude and direction when determining the distance
+    /// between vectors. L2 distance has a range of [0, ∞).
+    L2,
+    /// Cosine distance. Cosine distance is a distance metric
+    /// calculated from the cosine similarity between two vectors. Cosine
+    /// similarity is a measure of similarity between two non-zero vectors of an
+    /// inner product space. It is defined to equal the cosine of the angle
+    /// between them. Unlike L2, the cosine distance is not affected by the
+    /// magnitude of the vectors. Cosine distance has a range of [0, 2].
+    ///
+    /// Note: the cosine distance is undefined when one (or both) of the vectors
+    /// are all zeros (there is no direction). These vectors are invalid and may
+    /// never be returned from a vector search.
+    Cosine,
+    /// Dot product. Dot distance is the dot product of two vectors. Dot
+    /// distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
+    /// L2 norm is 1), then dot distance is equivalent to the cosine distance.
+    Dot,
+}
+
+impl From<DistanceType> for LanceDistanceType {
+    fn from(value: DistanceType) -> Self {
+        match value {
+            DistanceType::L2 => LanceDistanceType::L2,
+            DistanceType::Cosine => LanceDistanceType::Cosine,
+            DistanceType::Dot => LanceDistanceType::Dot,
+        }
+    }
+}
+
+impl From<LanceDistanceType> for DistanceType {
+    fn from(value: LanceDistanceType) -> Self {
+        match value {
+            LanceDistanceType::L2 => DistanceType::L2,
+            LanceDistanceType::Cosine => DistanceType::Cosine,
+            LanceDistanceType::Dot => DistanceType::Dot,
+        }
+    }
+}
+
+impl<'a> TryFrom<&'a str> for DistanceType {
+    type Error = <LanceDistanceType as TryFrom<&'a str>>::Error;
+
+    fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
+        LanceDistanceType::try_from(value).map(DistanceType::from)
+    }
+}
+
+impl Display for DistanceType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        LanceDistanceType::from(*self).fmt(f)
+    }
+}
+
 /// Connect to a database
 pub use connection::connect;

File diff suppressed because it is too large
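The new crate-local `DistanceType` round-trips through the lance type for parsing and display. A small sketch of the expected behavior (we assume the lowercase spellings such as "l2" that lance's `TryFrom<&str>` parser accepts):

```rust
use lancedb::DistanceType;

fn main() {
    // Display delegates to the lance formatter.
    println!("{}", DistanceType::Cosine);

    // Parsing delegates to lance's TryFrom<&str>; "l2" is assumed to be one
    // of the spellings that parser accepts.
    let parsed = DistanceType::try_from("l2");
    assert!(matches!(parsed, Ok(DistanceType::L2)));
}
```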
@@ -14,13 +14,14 @@
 
 use std::sync::Arc;
 
+use arrow_array::RecordBatchReader;
 use async_trait::async_trait;
 use reqwest::header::CONTENT_TYPE;
 use serde::Deserialize;
 use tokio::task::spawn_blocking;
 
 use crate::connection::{
-    ConnectionInternal, CreateTableBuilder, OpenTableBuilder, TableNamesBuilder,
+    ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
 };
 use crate::error::Result;
 use crate::Table;
@@ -74,8 +75,11 @@ impl ConnectionInternal for RemoteDatabase {
         Ok(rsp.json::<ListTablesResponse>().await?.tables)
     }
 
-    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<Table> {
-        let data = options.data.unwrap();
+    async fn do_create_table(
+        &self,
+        options: CreateTableBuilder<false, NoData>,
+        data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<Table> {
         // TODO: https://github.com/lancedb/lancedb/issues/1026
         // We should accept data from an async source. In the meantime, spawn this as blocking
         // to make sure we don't block the tokio runtime if the source is slow.
@@ -4,9 +4,10 @@ use async_trait::async_trait;
 use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
 
 use crate::{
+    connection::NoData,
     error::Result,
     index::{IndexBuilder, IndexConfig},
-    query::Query,
+    query::{Query, QueryExecutionOptions, VectorQuery},
     table::{
         merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
         TableInternal, UpdateBuilder,
@@ -63,10 +64,25 @@ impl TableInternal for RemoteTable {
     async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
         todo!()
     }
-    async fn add(&self, _add: AddDataBuilder) -> Result<()> {
+    async fn add(
+        &self,
+        _add: AddDataBuilder<NoData>,
+        _data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<()> {
         todo!()
     }
-    async fn query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
+    async fn plain_query(
+        &self,
+        _query: &Query,
+        _options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream> {
+        todo!()
+    }
+    async fn vector_query(
+        &self,
+        _query: &VectorQuery,
+        _options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream> {
         todo!()
     }
     async fn update(&self, _update: UpdateBuilder) -> Result<()> {
@@ -17,6 +17,8 @@
 use std::path::Path;
 use std::sync::Arc;
 
+use arrow::array::AsArray;
+use arrow::datatypes::Float32Type;
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
@@ -40,6 +42,8 @@ use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
 use log::info;
 use snafu::whatever;
 
+use crate::arrow::IntoArrow;
+use crate::connection::NoData;
 use crate::error::{Error, Result};
 use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
 use crate::index::IndexConfig;
@@ -47,7 +51,9 @@ use crate::index::{
     vector::{suggested_num_partitions, suggested_num_sub_vectors},
     Index, IndexBuilder,
 };
-use crate::query::{Query, Select, DEFAULT_TOP_K};
+use crate::query::{
+    IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
+};
 use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
 
 use self::dataset::DatasetConsistencyWrapper;
@@ -120,14 +126,14 @@ pub enum AddDataMode {
 
 /// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
 /// operation
-pub struct AddDataBuilder {
+pub struct AddDataBuilder<T: IntoArrow> {
     parent: Arc<dyn TableInternal>,
-    pub(crate) data: Box<dyn RecordBatchReader + Send>,
+    pub(crate) data: T,
     pub(crate) mode: AddDataMode,
     pub(crate) write_options: WriteOptions,
 }
 
-impl std::fmt::Debug for AddDataBuilder {
+impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("AddDataBuilder")
             .field("parent", &self.parent)
@@ -137,7 +143,7 @@ impl std::fmt::Debug for AddDataBuilder {
     }
 }
 
-impl AddDataBuilder {
+impl<T: IntoArrow> AddDataBuilder<T> {
     pub fn mode(mut self, mode: AddDataMode) -> Self {
         self.mode = mode;
         self
@@ -149,7 +155,15 @@ impl AddDataBuilder {
     }
 
     pub async fn execute(self) -> Result<()> {
-        self.parent.clone().add(self).await
+        let parent = self.parent.clone();
+        let data = self.data.into_arrow()?;
+        let without_data = AddDataBuilder::<NoData> {
+            data: NoData {},
+            mode: self.mode,
+            parent: self.parent,
+            write_options: self.write_options,
+        };
+        parent.add(without_data, data).await
     }
 }
 
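With `AddDataBuilder<T: IntoArrow>`, the boxing now happens inside `execute`, so appends read like this hedged sketch (the single-column schema and the values are our own; the table is assumed to already exist with a matching schema):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use lancedb::Table;

// Appends two rows to an existing table with an Int32 column "i".
async fn append_rows(table: &Table) -> lancedb::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![4, 5]))])
            .expect("valid batch");
    // No Box::new here: the reader satisfies IntoArrow via the blanket impl
    // and is boxed inside AddDataBuilder::execute.
    table
        .add(RecordBatchIterator::new(vec![Ok(batch)], schema))
        .execute()
        .await
}
```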
@@ -229,8 +243,21 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
     async fn schema(&self) -> Result<SchemaRef>;
     /// Count the number of rows in this table.
     async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
-    async fn add(&self, add: AddDataBuilder) -> Result<()>;
-    async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream>;
+    async fn plain_query(
+        &self,
+        query: &Query,
+        options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream>;
+    async fn vector_query(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream>;
+    async fn add(
+        &self,
+        add: AddDataBuilder<NoData>,
+        data: Box<dyn arrow_array::RecordBatchReader + Send>,
+    ) -> Result<()>;
     async fn delete(&self, predicate: &str) -> Result<()>;
     async fn update(&self, update: UpdateBuilder) -> Result<()>;
     async fn create_index(&self, index: IndexBuilder) -> Result<()>;
@@ -306,7 +333,7 @@ impl Table {
     ///
     /// * `batches` data to be added to the Table
     /// * `options` options to control how data is added
-    pub fn add(&self, batches: Box<dyn RecordBatchReader + Send>) -> AddDataBuilder {
+    pub fn add<T: IntoArrow>(&self, batches: T) -> AddDataBuilder<T> {
         AddDataBuilder {
             parent: self.inner.clone(),
             data: batches,
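`TableInternal` is held as a trait object, and trait objects cannot have generic methods; that is why `execute` peels the generic payload off into a boxed reader and hands the trait an `AddDataBuilder<NoData>`. A stripped-down sketch of that dispatch shape with toy types of our own:

```rust
// Object-safe trait: the generic payload has been erased to a boxed iterator.
trait Sink {
    fn consume(&self, data: Box<dyn Iterator<Item = i32> + Send>);
}

struct Printer;
impl Sink for Printer {
    fn consume(&self, data: Box<dyn Iterator<Item = i32> + Send>) {
        for v in data {
            println!("{v}");
        }
    }
}

// Generic front door: accepts any IntoIterator, boxes it, then dispatches
// through the object-safe trait -- the same shape as AddDataBuilder::execute.
fn add<T>(sink: &dyn Sink, data: T)
where
    T: IntoIterator<Item = i32>,
    T::IntoIter: Send + 'static,
{
    sink.consume(Box::new(data.into_iter()));
}

fn main() {
    add(&Printer, vec![1, 2, 3]);
}
```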
@@ -528,21 +555,30 @@ impl Table {
         )
     }
 
-    /// Search the table with a given query vector.
+    /// Create a [`Query`] Builder.
     ///
-    /// This is a convenience method for preparing an ANN query.
-    pub fn search(&self, query: &[f32]) -> Query {
-        self.query().nearest_to(query)
-    }
-
-    /// Create a generic [`Query`] Builder.
+    /// Queries allow you to search your existing data. By default the query will
+    /// return all the data in the table in no particular order. The builder
+    /// returned by this method can be used to control the query using filtering,
+    /// vector similarity, sorting, and more.
     ///
-    /// When appropriate, various indices and statistics based pruning will be used to
-    /// accelerate the query.
+    /// Note: By default, all columns are returned. For best performance, you should
+    /// only fetch the columns you need. See [`Query::select_with_projection`] for
+    /// more details.
+    ///
+    /// When appropriate, various indices and statistics will be used to accelerate
+    /// the query.
     ///
     /// # Examples
     ///
-    /// ## Run a vector search (ANN) query.
+    /// ## Vector search
+    ///
+    /// This example will find the 10 rows whose value in the "vector" column are
+    /// closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
+    /// on the "vector" column then this will perform an ANN search.
+    ///
+    /// The [`Query::refine_factor`] and [`Query::nprobes`] methods are used to
+    /// control the recall / latency tradeoff of the search.
     ///
     /// ```no_run
     /// # use arrow_array::RecordBatch;
@@ -551,19 +587,25 @@ impl Table {
     /// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
     /// # let tbl = conn.open_table("tbl").execute().await.unwrap();
     /// use crate::lancedb::Table;
+    /// use crate::lancedb::query::ExecutableQuery;
     /// let stream = tbl
     ///     .query()
     ///     .nearest_to(&[1.0, 2.0, 3.0])
+    ///     .unwrap()
     ///     .refine_factor(5)
     ///     .nprobes(10)
-    ///     .execute_stream()
+    ///     .execute()
     ///     .await
     ///     .unwrap();
     /// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
     /// # });
     /// ```
     ///
-    /// ## Run a SQL-style filter
+    /// ## SQL-style filter
+    ///
+    /// This query will return up to 1000 rows whose value in the `id` column
+    /// is greater than 5. LanceDb supports a broad set of filtering functions.
+    ///
     /// ```no_run
     /// # use arrow_array::RecordBatch;
     /// # use futures::TryStreamExt;
@@ -571,18 +613,23 @@ impl Table {
     /// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
     /// # let tbl = conn.open_table("tbl").execute().await.unwrap();
     /// use crate::lancedb::Table;
+    /// use crate::lancedb::query::{ExecutableQuery, QueryBase};
     /// let stream = tbl
     ///     .query()
-    ///     .filter("id > 5")
+    ///     .only_if("id > 5")
     ///     .limit(1000)
-    ///     .execute_stream()
+    ///     .execute()
     ///     .await
     ///     .unwrap();
     /// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
     /// # });
     /// ```
     ///
-    /// ## Run a full scan query.
+    /// ## Full scan
+    ///
+    /// This query will return everything in the table in no particular
+    /// order.
+    ///
     /// ```no_run
     /// # use arrow_array::RecordBatch;
     /// # use futures::TryStreamExt;
@@ -590,7 +637,8 @@ impl Table {
     /// # let conn = lancedb::connect("/tmp").execute().await.unwrap();
     /// # let tbl = conn.open_table("tbl").execute().await.unwrap();
     /// use crate::lancedb::Table;
-    /// let stream = tbl.query().execute_stream().await.unwrap();
+    /// use crate::lancedb::query::ExecutableQuery;
+    /// let stream = tbl.query().execute().await.unwrap();
     /// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
     /// # });
     /// ```
@@ -598,6 +646,15 @@
         Query::new(self.inner.clone())
     }
 
+    /// Search the table with a given query vector.
+    ///
+    /// This is a convenience method for preparing a vector query and
+    /// is the same thing as calling `nearest_to` on the builder returned
+    /// by `query`. See [`Query::nearest_to`] for more details.
+    pub fn vector_search(&self, query: impl IntoQueryVector) -> Result<VectorQuery> {
+        self.query().nearest_to(query)
+    }
+
     /// Optimize the on-disk data and indices for better performance.
     ///
     /// <section class="warning">Experimental API</section>
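Putting the renamed pieces together (`only_if` for filters, `nearest_to` for the vector, `execute` instead of `execute_stream`), a hedged sketch of a filtered ANN query (the `id` and 128-dim `vector` columns are assumptions about the table's schema):

```rust
use futures::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::Table;

// Returns the number of record batches produced by a filtered ANN search.
async fn filtered_ann(table: &Table) -> lancedb::Result<usize> {
    let batches = table
        .query()
        .only_if("id > 5")
        .limit(10)
        .nearest_to(&[1.0; 128])?
        .execute()
        .await?
        .try_collect::<Vec<_>>()
        .await?;
    Ok(batches.len())
}
```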
@@ -1051,7 +1108,7 @@ impl NativeTable {
             /*num_bits=*/ 8,
             num_sub_vectors as usize,
             false,
-            index.distance_type,
+            index.distance_type.into(),
             index.max_iterations as usize,
         );
         dataset
@@ -1107,6 +1164,86 @@ impl NativeTable {
             .await?;
         Ok(())
     }
+
+    async fn generic_query(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream> {
+        let ds_ref = self.dataset.get().await?;
+        let mut scanner: Scanner = ds_ref.scan();
+
+        if let Some(query_vector) = query.query_vector.as_ref() {
+            // If there is a vector query, default to limit=10 if unspecified
+            let column = if let Some(col) = query.column.as_ref() {
+                col.clone()
+            } else {
+                // Infer a vector column with the same dimension of the query vector.
+                let arrow_schema = Schema::from(ds_ref.schema());
+                default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
+            };
+            let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
+                message: format!("Column {} not found in dataset schema", column),
+            })?;
+            if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
+                if !f.data_type().is_floating() {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "The data type of the vector column '{}' is not a floating point type",
+                            column
+                        ),
+                    });
+                }
+                if dim != query_vector.len() as i32 {
+                    return Err(Error::InvalidInput {
+                        message: format!(
+                            "The dimension of the query vector does not match with the dimension of the vector column '{}': query dim={}, expected vector dim={}",
+                            column,
+                            query_vector.len(),
+                            dim,
+                        ),
+                    });
+                }
+            }
+            let query_vector = query_vector.as_primitive::<Float32Type>();
+            scanner.nearest(
+                &column,
+                query_vector,
+                query.base.limit.unwrap_or(DEFAULT_TOP_K),
+            )?;
+        } else {
+            // If there is no vector query, it's ok to not have a limit
+            scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
+        }
+        scanner.nprobs(query.nprobes);
+        scanner.use_index(query.use_index);
+        scanner.prefilter(query.prefilter);
+        scanner.batch_size(options.max_batch_length as usize);
+
+        match &query.base.select {
+            Select::Columns(select) => {
+                scanner.project(select.as_slice())?;
+            }
+            Select::Dynamic(select_with_transform) => {
+                scanner.project_with_transform(select_with_transform.as_slice())?;
+            }
+            Select::All => { /* Do nothing */ }
+        }
+
+        if let Some(filter) = &query.base.filter {
+            scanner.filter(filter)?;
+        }
+
+        if let Some(refine_factor) = query.refine_factor {
+            scanner.refine(refine_factor);
+        }
+
+        if let Some(distance_type) = query.distance_type {
+            scanner.distance_metric(distance_type.into());
+        }
+        Ok(scanner.try_into_stream().await?)
+    }
 }
 
 #[async_trait::async_trait]
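The up-front validation in `generic_query` means a mismatched query vector fails with a typed error instead of surfacing deep inside lance. The shape of the check, reproduced in isolation as a toy function of our own (string errors stand in for LanceDB's error enum):

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field};

// Mirrors the shape of the check: the column must be a FixedSizeList of
// floats whose dimension matches the query vector.
fn check_vector_column(field: &Field, query_dim: i32) -> Result<(), String> {
    match field.data_type() {
        DataType::FixedSizeList(inner, dim) => {
            if !inner.data_type().is_floating() {
                return Err(format!("column '{}' is not floating point", field.name()));
            }
            if *dim != query_dim {
                return Err(format!(
                    "query dim={} does not match column dim={}",
                    query_dim, dim
                ));
            }
            Ok(())
        }
        _ => Err(format!("column '{}' is not a vector column", field.name())),
    }
}

fn main() {
    let inner = Arc::new(Field::new("item", DataType::Float32, true));
    let field = Field::new("vector", DataType::FixedSizeList(inner, 128), true);
    assert!(check_vector_column(&field, 128).is_ok());
    assert!(check_vector_column(&field, 64).is_err());
}
```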
@@ -1176,7 +1313,11 @@ impl TableInternal for NativeTable {
         }
     }
 
-    async fn add(&self, add: AddDataBuilder) -> Result<()> {
+    async fn add(
+        &self,
+        add: AddDataBuilder<NoData>,
+        data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<()> {
         let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
             mode: match add.mode {
                 AddDataMode::Append => WriteMode::Append,
@@ -1193,7 +1334,7 @@ impl TableInternal for NativeTable {
 
         self.dataset.ensure_mutable().await?;
 
-        let dataset = Dataset::write(add.data, &self.uri, Some(lance_params)).await?;
+        let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;
         self.dataset.set_latest(dataset).await;
         Ok(())
     }
@@ -1232,63 +1373,21 @@ impl TableInternal for NativeTable {
         Ok(())
     }
 
-    async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream> {
-        let ds_ref = self.dataset.get().await?;
-        let mut scanner: Scanner = ds_ref.scan();
-
-        if let Some(query_vector) = query.query_vector.as_ref() {
-            // If there is a vector query, default to limit=10 if unspecified
-            let column = if let Some(col) = query.column.as_ref() {
-                col.clone()
-            } else {
-                // Infer a vector column with the same dimension of the query vector.
-                let arrow_schema = Schema::from(ds_ref.schema());
-                default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
-            };
-            let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
-                message: format!("Column {} not found in dataset schema", column),
-            })?;
-            if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
-            {
-                return Err(Error::Schema {
-                    message: format!(
-                        "Vector column '{}' does not match the dimension of the query vector: dim={}",
-                        column,
-                        query_vector.len(),
-                    ),
-                });
-            }
-            scanner.nearest(&column, query_vector, query.limit.unwrap_or(DEFAULT_TOP_K))?;
-        } else {
-            // If there is no vector query, it's ok to not have a limit
-            scanner.limit(query.limit.map(|limit| limit as i64), None)?;
-        }
-        scanner.nprobs(query.nprobes);
-        scanner.use_index(query.use_index);
-        scanner.prefilter(query.prefilter);
-
-        match &query.select {
-            Select::Simple(select) => {
-                scanner.project(select.as_slice())?;
-            }
-            Select::Projection(select_with_transform) => {
-                scanner.project_with_transform(select_with_transform.as_slice())?;
-            }
-            Select::All => { /* Do nothing */ }
-        }
-
-        if let Some(filter) = &query.filter {
-            scanner.filter(filter)?;
-        }
-
-        if let Some(refine_factor) = query.refine_factor {
-            scanner.refine(refine_factor);
-        }
-
-        if let Some(metric_type) = query.metric_type {
-            scanner.distance_metric(metric_type);
-        }
-        Ok(scanner.try_into_stream().await?)
+    async fn plain_query(
+        &self,
+        query: &Query,
+        options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream> {
+        self.generic_query(&query.clone().into_vector(), options)
+            .await
+    }
+
+    async fn vector_query(
+        &self,
+        query: &VectorQuery,
+        options: QueryExecutionOptions,
+    ) -> Result<DatasetRecordBatchStream> {
+        self.generic_query(query, options).await
     }
 
     async fn merge_insert(
@@ -1450,6 +1549,7 @@ mod tests {
     use crate::connect;
     use crate::connection::ConnectBuilder;
     use crate::index::scalar::BTreeIndexBuilder;
+    use crate::query::{ExecutableQuery, QueryBase};
 
     use super::*;
 
@@ -1512,11 +1612,7 @@ mod tests {
 
         let batches = make_test_batches();
         let schema = batches.schema().clone();
-        let table = conn
-            .create_table("test", Box::new(batches))
-            .execute()
-            .await
-            .unwrap();
+        let table = conn.create_table("test", batches).execute().await.unwrap();
         assert_eq!(table.count_rows(None).await.unwrap(), 10);
 
         let new_batches = RecordBatchIterator::new(
@@ -1530,7 +1626,7 @@ mod tests {
             schema.clone(),
         );
 
-        table.add(Box::new(new_batches)).execute().await.unwrap();
+        table.add(new_batches).execute().await.unwrap();
         assert_eq!(table.count_rows(None).await.unwrap(), 20);
         assert_eq!(table.name(), "test");
     }
@@ -1544,7 +1640,7 @@ mod tests {
         // Create a dataset with i=0..10
         let batches = merge_insert_test_batches(0, 0);
         let table = conn
-            .create_table("my_table", Box::new(batches))
+            .create_table("my_table", batches)
            .execute()
             .await
             .unwrap();
@@ -1592,11 +1688,7 @@ mod tests {
 
         let batches = make_test_batches();
         let schema = batches.schema().clone();
-        let table = conn
-            .create_table("test", Box::new(batches))
-            .execute()
-            .await
-            .unwrap();
+        let table = conn.create_table("test", batches).execute().await.unwrap();
         assert_eq!(table.count_rows(None).await.unwrap(), 10);
 
         let batches = vec![RecordBatch::try_new(
@@ -1611,7 +1703,7 @@ mod tests {
 
         // Can overwrite using AddDataOptions::mode
         table
-            .add(Box::new(new_batches))
+            .add(new_batches)
            .mode(AddDataMode::Overwrite)
             .execute()
             .await
@@ -1629,7 +1721,7 @@ mod tests {
 
         let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
         table
-            .add(Box::new(new_batches))
+            .add(new_batches)
             .write_options(WriteOptions {
                 lance_write_params: Some(param),
             })
@@ -1674,7 +1766,7 @@ mod tests {
         );
 
         let table = conn
-            .create_table("my_table", Box::new(record_batch_iter))
+            .create_table("my_table", record_batch_iter)
             .execute()
             .await
             .unwrap();
@@ -1689,8 +1781,8 @@ mod tests {
 
         let mut batches = table
            .query()
-            .select(&["id", "name"])
-            .execute_stream()
+            .select(Select::columns(&["id", "name"]))
+            .execute()
             .await
             .unwrap()
             .try_collect::<Vec<_>>()
@@ -1811,7 +1903,7 @@ mod tests {
         );
 
         let table = conn
-            .create_table("my_table", Box::new(record_batch_iter))
+            .create_table("my_table", record_batch_iter)
             .execute()
             .await
             .unwrap();
@@ -1841,7 +1933,7 @@ mod tests {
 
         let mut batches = table
             .query()
-            .select(&[
+            .select(Select::columns(&[
                 "string",
                 "large_string",
                 "int32",
@@ -1855,8 +1947,8 @@ mod tests {
                 "timestamp_ms",
                 "vec_f32",
                 "vec_f64",
-            ])
-            .execute_stream()
+            ]))
+            .execute()
             .await
             .unwrap()
             .try_collect::<Vec<_>>()
@@ -1932,7 +2024,7 @@ mod tests {
             .await
             .unwrap();
         let tbl = conn
-            .create_table("my_table", Box::new(make_test_batches()))
+            .create_table("my_table", make_test_batches())
             .execute()
             .await
             .unwrap();
@@ -1971,7 +2063,7 @@ mod tests {
 
         let batches = make_test_batches();
 
-        conn.create_table("my_table", Box::new(batches))
+        conn.create_table("my_table", batches)
             .execute()
             .await
             .unwrap();
@@ -2064,11 +2156,7 @@ mod tests {
             schema,
         );
 
-        let table = conn
-            .create_table("test", Box::new(batches))
-            .execute()
-            .await
-            .unwrap();
+        let table = conn.create_table("test", batches).execute().await.unwrap();
 
         assert_eq!(
             table
@@ -2139,7 +2227,7 @@ mod tests {
         Ok(FixedSizeListArray::from(data))
     }
 
-    fn some_sample_data() -> impl RecordBatchReader {
+    fn some_sample_data() -> Box<dyn RecordBatchReader + Send> {
         let batch = RecordBatch::try_new(
             Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])),
             vec![Arc::new(Int32Array::from(vec![1]))],
@@ -2148,7 +2236,7 @@ mod tests {
         let schema = batch.schema().clone();
         let batch = Ok(batch);
 
-        RecordBatchIterator::new(vec![batch], schema)
+        Box::new(RecordBatchIterator::new(vec![batch], schema))
     }
 
     #[tokio::test]
@@ -2165,10 +2253,7 @@ mod tests {
         let table = conn
             .create_table(
                 "my_table",
-                Box::new(RecordBatchIterator::new(
-                    vec![Ok(batch.clone())],
-                    batch.schema(),
-                )),
+                RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()),
             )
             .execute()
             .await
@@ -2232,7 +2317,7 @@ mod tests {
         assert_eq!(table1.count_rows(None).await.unwrap(), 0);
         assert_eq!(table2.count_rows(None).await.unwrap(), 0);
 
-        table1.add(Box::new(data)).execute().await.unwrap();
+        table1.add(data).execute().await.unwrap();
         assert_eq!(table1.count_rows(None).await.unwrap(), 1);
 
         match interval {
@@ -2265,21 +2350,13 @@ mod tests {
             .await
             .unwrap();
         let table = conn
-            .create_table("my_table", Box::new(some_sample_data()))
+            .create_table("my_table", some_sample_data())
             .execute()
             .await
             .unwrap();
         let version = table.version().await.unwrap();
-        table
-            .add(Box::new(some_sample_data()))
-            .execute()
-            .await
-            .unwrap();
+        table.add(some_sample_data()).execute().await.unwrap();
         table.checkout(version).await.unwrap();
-        assert!(table
-            .add(Box::new(some_sample_data()))
-            .execute()
-            .await
-            .is_err())
+        assert!(table.add(some_sample_data()).execute().await.is_err())
     }
 }