Compare commits

..

56 Commits

Author SHA1 Message Date
David Myriel
9e278fc5a6 fix small details 2025-05-05 23:03:17 +02:00
David Myriel
09fed1f286 add quickstart doc 2025-05-05 22:02:11 +02:00
Will Jones
cee2b5ea42 chore: upgrade pyarrow pin (#2192)
Closes #2191


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Chores**
- Updated the required version of the pyarrow package to version 16 or
higher.
- Adjusted automated testing workflows to install pyarrow version 16 for
compatibility checks.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-05-05 11:23:13 -07:00
Alex Pilon
f315f9665a feat: implement bindings to return merge stats (#2367)
Based on this comment:
https://github.com/lancedb/lancedb/issues/2228#issuecomment-2730463075
and https://github.com/lancedb/lance/pull/2357

Here is my attempt at implementing bindings for returning merge stats
from a `merge_insert.execute` call for lancedb.

Note: I have almost no idea what I am doing in Rust but tried to follow
existing code patterns and pay attention to compiler hints.
- The change in nodejs binding appeared to be necessary to get
compilation to work, presumably this could actual work properly by
returning some kind of NAPI JS object of the stats data?
- I am unsure of what to do with the remote/table.rs changes -
necessarily for compilation to work; I assume this is related to LanceDB
cloud, but unsure the best way to handle that at this point.

Proof of function:

```python
import pandas as pd
import lancedb


db = lancedb.connect("/tmp/test.db")

test_data = pd.DataFrame(
    {
        "title": ["Hello", "Test Document", "Example", "Data Sample", "Last One"],
        "id": [1, 2, 3, 4, 5],
        "content": [
            "World",
            "This is a test",
            "Another example",
            "More test data",
            "Final entry",
        ],
    }
)

table = db.create_table("documents", data=test_data, exist_ok=True, mode="overwrite")

update_data = pd.DataFrame(
    {
        "title": [
            "Hello, World",
            "Test Document, it's good",
            "Example",
            "Data Sample",
            "Last One",
            "New One",
        ],
        "id": [1, 2, 3, 4, 5, 6],
        "content": [
            "World",
            "This is a test",
            "Another example",
            "More test data",
            "Final entry",
            "New content",
        ],
    }
)

stats = (
    table.merge_insert(on="id")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(update_data)
)

print(stats)
```

returns

```
{'num_inserted_rows': 1, 'num_updated_rows': 5, 'num_deleted_rows': 0}
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Merge-insert operations now return detailed statistics, including
counts of inserted, updated, and deleted rows.
- **Bug Fixes**
- Tests updated to validate returned merge-insert statistics for
accuracy.
- **Documentation**
- Method documentation improved to reflect new return values and clarify
merge operation results.
- Added documentation for the new `MergeStats` interface detailing
operation statistics.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2025-05-01 10:00:20 -07:00
Andrew C. Oliver
5deb26bc8b fix: prevent embedded objects from returning null in all of their fields (#2355)
metadata{filename=xyz} filename would be there structurally, but ALWAYS
null.

I didn't include this as a file but it may be useful for understanding
the problem for people searching on this issue so I'm including it here
as documentation. Before this patch any field that is more than 1 deep
is accepted but returns null values for subfields when queried.

```js
const lancedb = require('@lancedb/lancedb');

// Debug logger
function debug(message, data) {
  console.log(`[TEST] ${message}`, data !== undefined ? data : '');
}

// Log when our unwrapArrowObject is called
const kParent = Symbol.for("parent");
const kRowIndex = Symbol.for("rowIndex");

// Override console.log for our test
const originalConsoleLog = console.log;
console.log = function() {
  // Filter out noisy logs
  if (arguments[0] && typeof arguments[0] === 'string' && arguments[0].includes('[INFO] [LanceDB]')) {
    originalConsoleLog.apply(console, arguments);
  }
  originalConsoleLog.apply(console, arguments);
};

async function main() {
  debug('Starting test...');
  
  // Connect to the database
  debug('Connecting to database...');
  const db = await lancedb.connect('./.lancedb');
  
  // Try to open an existing table, or create a new one if it doesn't exist
  let table;
  try {
    table = await db.openTable('test_nested_fields');
    debug('Opened existing table');
  } catch (e) {
    debug('Creating new table...');
    
    // Create test data with nested metadata structure
    const data = [
      {
        id: 'test1',
        vector: [1, 2, 3],
        metadata: {
          filePath: "/path/to/file1.ts",
          startLine: 10,
          endLine: 20,
          text: "function test() { return true; }"
        }
      },
      {
        id: 'test2',
        vector: [4, 5, 6],
        metadata: {
          filePath: "/path/to/file2.ts",
          startLine: 30,
          endLine: 40,
          text: "function test2() { return false; }"
        }
      }
    ];
    
    debug('Data to be inserted:', JSON.stringify(data, null, 2));
    
    // Create the table
    table = await db.createTable('test_nested_fields', data);
    debug('Table created successfully');
  }
  
  // Query the table and get results
  debug('Querying table...');
  const results = await table.search([1, 2, 3]).limit(10).toArray();
  
  // Log the results
  debug('Number of results:', results.length);
  
  if (results.length > 0) {
    const firstResult = results[0];
    debug('First result properties:', Object.keys(firstResult));
    
    // Check if metadata is accessible and what properties it has
    if (firstResult.metadata) {
      debug('Metadata properties:', Object.keys(firstResult.metadata));
      debug('Metadata filePath:', firstResult.metadata.filePath);
      debug('Metadata startLine:', firstResult.metadata.startLine);
      
      // Destructure to see if that helps
      const { filePath, startLine, endLine, text } = firstResult.metadata;
      debug('Destructured values:', { filePath, startLine, endLine, text });
      
      // Check if it's a proxy object
      debug('Result is proxy?', Object.getPrototypeOf(firstResult) === Object.prototype ? false : true);
      debug('Metadata is proxy?', Object.getPrototypeOf(firstResult.metadata) === Object.prototype ? false : true);
    } else {
      debug('Metadata is not accessible!');
    }
  }
  
  // Close the database
  await db.close();
}

main().catch(e => {
  console.error('Error:', e);
}); 
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **Bug Fixes**
- Improved handling of nested struct fields to ensure accurate
preservation of values during serialization and deserialization.
- Enhanced robustness when accessing nested object properties, reducing
errors with missing or null values.

- **Tests**
- Added tests to verify correct handling of nested struct fields through
serialization and deserialization.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2025-05-01 09:38:55 -07:00
Lance Release
3cc670ac38 Updating package-lock.json 2025-04-29 23:21:19 +00:00
Lance Release
4ade3e31e2 Updating package-lock.json 2025-04-29 22:19:46 +00:00
Lance Release
a222d2cd91 Updating package-lock.json 2025-04-29 22:19:30 +00:00
Lance Release
508e621f3d Bump version: 0.19.1-beta.0 → 0.19.1-beta.1 2025-04-29 22:19:14 +00:00
Lance Release
a1a0472f3f Bump version: 0.22.1-beta.0 → 0.22.1-beta.1 2025-04-29 22:18:53 +00:00
Wyatt Alt
3425a6d339 feat: upgrade lance to v0.27.0-beta.2 (#2364)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Chores**
- Updated dependencies for related components to use the latest version
from a specific repository source. No changes to features or public
functionality.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-29 14:59:56 -07:00
Ryan Green
af54e0ce06 feat: add table stats API (#2363)
* Add a new "table stats" API to expose basic table and fragment
statistics with local and remote table implementations

### Questions
* This is using `calculate_data_stats` to determine total bytes in the
table. This seems like a potentially expensive operation - are there any
concerns about performance for large datasets?

### Notes
* bytes_on_disk seems to be stored at the column level but there does
not seem to be a way to easily calculate total bytes per fragment. This
may need to be added in lance before we can support fragment size
(bytes) statistics.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Added a method to retrieve comprehensive table statistics, including
total rows, index counts, storage size, and detailed fragment size
metrics such as minimum, maximum, mean, and percentiles.
- Enabled fetching of table statistics from remote sources through
asynchronous requests.
- Extended table interfaces across Python, Rust, and Node.js to support
synchronous and asynchronous retrieval of table statistics.
- **Tests**
- Introduced tests to verify the accuracy of the new table statistics
feature for both populated and empty tables.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-29 15:19:08 -02:30
Lance Release
089905fe8f Updating package-lock.json 2025-04-28 19:13:36 +00:00
Lance Release
554939e5d2 Updating package-lock.json 2025-04-28 17:20:58 +00:00
Lance Release
7a13814922 Updating package-lock.json 2025-04-28 17:20:42 +00:00
Lance Release
e9f25f6a12 Bump version: 0.19.0 → 0.19.1-beta.0 2025-04-28 17:20:26 +00:00
Lance Release
419a433244 Bump version: 0.22.0 → 0.22.1-beta.0 2025-04-28 17:20:10 +00:00
LuQQiu
a9311c4dc0 feat: add list/create/delete/update/checkout tag API (#2353)
add the tag related API to list existing tags, attach tag to a version,
update the tag version, delete tag, get the version of the tag, and
checkout the version that the tag bounded to.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced table version tagging, allowing users to create, update,
delete, and list human-readable tags for specific table versions.
  - Enabled checking out a table by either version number or tag name.
- Added new interfaces for tag management in both Python and Node.js
APIs, supporting synchronous and asynchronous workflows.

- **Bug Fixes**
  - None.

- **Documentation**
- Updated documentation to describe the new tagging features, including
usage examples.

- **Tests**
- Added comprehensive tests for tag creation, updating, deletion,
listing, and version checkout by tag in both Python and Node.js
environments.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-28 10:04:46 -07:00
LuQQiu
178bcf9c90 fix: hybrid search explain plan analyze plan (#2360)
Fix hybrid search explain plan analyze plan API

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Added options to view the execution plan and analyze the runtime
performance of hybrid queries.
- **Refactor**
- Improved internal handling of query setup for better modularity and
maintainability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-27 18:39:43 -07:00
Lance Release
b9be092cb1 Updating package-lock.json 2025-04-25 22:05:57 +00:00
Lance Release
e8c0c52315 Updating package-lock.json 2025-04-25 21:17:03 +00:00
Lance Release
a60fa0d3b7 Updating package-lock.json 2025-04-25 21:16:48 +00:00
Lance Release
726d629b9b Bump version: 0.19.0-beta.12 → 0.19.0 2025-04-25 21:16:30 +00:00
Lance Release
b493f56dee Bump version: 0.19.0-beta.11 → 0.19.0-beta.12 2025-04-25 21:16:25 +00:00
Lance Release
a8b5ad7e74 Bump version: 0.22.0-beta.12 → 0.22.0 2025-04-25 21:16:07 +00:00
Lance Release
f8f6264883 Bump version: 0.22.0-beta.11 → 0.22.0-beta.12 2025-04-25 21:16:07 +00:00
Will Jones
d8517117f1 feat: upgrade Lance to v0.26.0 (#2359)
Upstream changelog:
https://github.com/lancedb/lance/releases/tag/v0.26.0

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Chores**
- Updated dependency management to use published crate versions for
improved reliability and maintainability.
- Added a temporary workaround for build issues by pinning a specific
version of a dependency.
- **Refactor**
- Improved resource management and concurrency by updating internal
ownership models for object storage components.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-25 13:59:12 -07:00
Lance Release
ab66dd5ed2 Updating package-lock.json 2025-04-25 06:04:06 +00:00
Lance Release
cbb9a7877c Updating package-lock.json 2025-04-25 05:02:47 +00:00
Lance Release
b7fc223535 Updating package-lock.json 2025-04-25 05:02:32 +00:00
Lance Release
1fdaf7a1a4 Bump version: 0.19.0-beta.10 → 0.19.0-beta.11 2025-04-25 05:02:16 +00:00
Lance Release
d11819c90c Bump version: 0.22.0-beta.10 → 0.22.0-beta.11 2025-04-25 05:01:57 +00:00
BubbleCal
9b902272f1 fix: sync hybrid search ignores the distance range params (#2356)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for distance range filtering in hybrid vector queries,
allowing users to specify lower and upper bounds for search results.

- **Tests**
- Introduced new tests to validate distance range filtering and
reranking in both synchronous and asynchronous hybrid query scenarios.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-04-25 13:01:22 +08:00
Will Jones
8c0622fa2c fix: remote limit to avoid "Limit must be non-negative" (#2354)
To workaround this issue: https://github.com/lancedb/lancedb/issues/2211

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
- Improved handling of large query parameters to prevent potential
overflow issues when using the "k" parameter in queries.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-24 15:04:06 -07:00
Philip Meier
2191f948c3 fix: add missing pydantic model config compat (#2316)
Fixes #2315.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Refactor**
- Enhanced query processing to maintain smooth functionality across
different dependency versions, ensuring improved stability and
performance.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-22 14:46:10 -07:00
Will Jones
acc3b03004 ci: fix docs deploy (#2351)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Chores**
- Improved CI workflow for documentation builds by optimizing Rust build
settings and updating the runner environment.
  - Fixed a typo in a workflow step name.
- Streamlined caching steps to reduce redundancy and improve efficiency.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-22 13:55:34 -07:00
Lance Release
7f091b8c8e Updating package-lock.json 2025-04-22 19:16:43 +00:00
Lance Release
c19bdd9a24 Updating package-lock.json 2025-04-22 18:24:16 +00:00
Lance Release
dad0ff5cd2 Updating package-lock.json 2025-04-22 18:23:59 +00:00
Lance Release
a705621067 Bump version: 0.19.0-beta.9 → 0.19.0-beta.10 2025-04-22 18:23:39 +00:00
Lance Release
39614fdb7d Bump version: 0.22.0-beta.9 → 0.22.0-beta.10 2025-04-22 18:23:17 +00:00
Ryan Green
96d534d4bc feat: add retries to remote client for requests with stream bodies (#2349)
Closes https://github.com/lancedb/lancedb/issues/2307
* Adds retries to remote operations with stream bodies (add,
merge_insert)
* Change default retryable status codes to 409, 429, 500, 502, 503, 504
* Don't retry add or merge_insert operations on 5xx responses

Notes:
* Supporting retries on stream bodies means we have to buffer the body
into memory so it can be cloned on retry. This will impact memory use
patterns for the remote client. This buffering can be disabled by
disabling retries (i.e. setting retries to 0 in RetryConfig)
* It does not seem that retry config can be specified by env vars as the
documentation suggests. I added a follow-up issue
[here](https://github.com/lancedb/lancedb/issues/2350)



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Enhanced retry support for remote requests with configurable limits
and exponential backoff with jitter.
- Added robust retry logic for streaming data uploads, enabling retries
with buffered data to ensure reliability.

- **Bug Fixes**
- Improved error handling and retry behavior for HTTP status codes 409
and 504.

- **Refactor**
- Centralized and modularized HTTP request sending and retry logic
across remote database and table operations.
  - Streamlined request ID management for improved traceability.
- Simplified error message construction in index waiting functionality.

- **Tests**
  - Added a test verifying merge-insert retries on HTTP 409 responses.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-22 15:40:44 -02:30
Lance Release
5051d30d09 Updating package-lock.json 2025-04-21 23:55:43 +00:00
Lance Release
db853c4041 Updating package-lock.json 2025-04-21 22:50:56 +00:00
Lance Release
76d1d22bdc Updating package-lock.json 2025-04-21 22:50:40 +00:00
Lance Release
d8746c61c6 Bump version: 0.19.0-beta.8 → 0.19.0-beta.9 2025-04-21 22:50:20 +00:00
Lance Release
1a66df2627 Bump version: 0.22.0-beta.8 → 0.22.0-beta.9 2025-04-21 22:49:59 +00:00
Will Jones
44670076c1 fix: move timeout to avoid retries (#2347)
I added a timeout to query execution options in
https://github.com/lancedb/lancedb/pull/2288. However, this was send to
the request timeout, but the retry implementation is unaware of this
timeout. So once the query timed out, a retry would be triggered.
Instead, this PR changes it so the timeout happens outside the retry
loop.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes**
- Improved query timeout handling to provide clearer error messages and
more reliable cancellation if a query takes too long to complete.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 14:27:04 -07:00
Will Jones
92f0b16e46 fix(python): make sure pandas is optional (#2346)
Fixes #2344


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Tests**
- Updated tests to use PyArrow Tables instead of pandas DataFrames where
possible, reducing reliance on pandas.
- Tests that require pandas are now automatically skipped if pandas is
not installed.
- **Chores**
- Improved workflow to uninstall both pylance and pandas in a specific
test step.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 13:42:13 -07:00
Eileen Noonan
1620ba3508 docs: make table.update() nodejs guide consistent with API documentation (#2334)
The docs in the Guide here do not match the [API reference]
(https://lancedb.github.io/lancedb/js/classes/Table/#updateopts) for the
nodejs client.

I am writing an Elixir wrapper over the typescript library (Rust
forthcoming!) and confirmed in testing that the API reference is correct
vs the Guide.

Following the Guide docs, the error I got was:

"lance error: Invalid user input: Schema error: No field named bar.
Valid fields are foo. For a query of:

await table.update({foo: "buzz"}, { where: "foo = 'bar'"});
Over a table with a schema of just {foo: Utf8}.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Documentation**
- Reformatted a code snippet in the guide to enhance readability by
splitting it into multiple lines for improved clarity.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 08:38:16 -07:00
Ryan Green
3ae90dde80 feat: add new table API to wait for async indexing (#2338)
* Add new wait_for_index() table operation that polls until indices are
created/fully indexed
* Add an optional wait timeout parameter to all create_index operations
* Python and NodeJS interfaces

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Added optional waiting for index creation completion with configurable
timeout.
- Introduced methods to poll and wait for indices to be fully built
across sync and async tables.
  - Extended index creation APIs to accept a wait timeout parameter.
- **Bug Fixes**
- Added a new timeout error variant for improved error reporting on
index operations.
- **Tests**
- Added tests covering successful index readiness waiting, timeout
scenarios, and missing index cases.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 08:41:21 -02:30
Magnus
4f07fea6df feat: add ColPali embedding support with MultiVector type (#2170)
This PR adds ColPali support with ColPaliEmbeddings class (tagged
"colpali") using ColQwen2.5 for multi-vector text/image embeddings. Also
added MultiVector Pydantic type to handle the vector lists.

I've added some integration test for the embedding model and some unit
test for the new Pydantic type. Could be a template for other ColPali
variants as well. or until transformers🤗 starts supporting it.


Still `TODO`:

- [ ] Documentation
- [ ] Add an example

_Could also allow Image as query, but didn't work well when testing it._

[ColPali-Engine](https://github.com/illuin-tech/colpali) version:
0.3.9.dev17+g3faee24

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced support for ColPali-based multimodal multi-vector
embeddings for both text and images.
- Added a new embedding class for generating multi-vector embeddings,
configurable for various model and processing options.
- Added a new Pydantic type for multi-vector embeddings, supporting
validation and schema generation for lists of fixed-dimension vectors.

- **Bug Fixes**
- Ensured proper asynchronous index creation in query tests for improved
reliability.

- **Tests**
- Added integration tests for ColPali embeddings, including
text-to-image search and validation of multi-vector fields.
- Added comprehensive tests for the new multi-vector Pydantic type,
covering schema, validation, and default value behavior.

- **Chores**
  - Updated optional dependencies to include the ColPali engine.
  - Added utility to check for availability of flash attention support.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 11:47:37 +08:00
Lance Release
3d7d82cf86 Updating package-lock.json 2025-04-17 23:13:37 +00:00
Lance Release
edc4e40a7b Updating package-lock.json 2025-04-17 22:16:36 +00:00
Lance Release
ca3806a02f Updating package-lock.json 2025-04-17 22:16:20 +00:00
Lance Release
35cff12e31 Bump version: 0.19.0-beta.7 → 0.19.0-beta.8 2025-04-17 22:16:02 +00:00
79 changed files with 4160 additions and 669 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.19.0-beta.7"
current_version = "0.19.1-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -18,17 +18,24 @@ concurrency:
group: "pages"
cancel-in-progress: true
env:
# This reduces the disk space needed for the build
RUSTFLAGS: "-C debuginfo=0"
# according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html
# CI builds are faster with incremental disabled.
CARGO_INCREMENTAL: "0"
jobs:
# Single deploy job since we're just deploying
build:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: buildjet-8vcpu-ubuntu-2204
runs-on: ubuntu-24.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install dependecies needed for ubuntu
- name: Install dependencies needed for ubuntu
run: |
sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
@@ -38,6 +45,7 @@ jobs:
python-version: "3.10"
cache: "pip"
cache-dependency-path: "docs/requirements.txt"
- uses: Swatinem/rust-cache@v2
- name: Build Python
working-directory: python
run: |
@@ -49,7 +57,6 @@ jobs:
node-version: 20
cache: 'npm'
cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2
- name: Install node dependencies
working-directory: node
run: |

View File

@@ -136,9 +136,9 @@ jobs:
- uses: ./.github/workflows/run_tests
with:
integration: true
- name: Test without pylance
- name: Test without pylance or pandas
run: |
pip uninstall -y pylance
pip uninstall -y pylance pandas
pytest -vv python/tests/test_table.py
# Make sure wheels are not included in the Rust cache
- name: Delete wheels
@@ -228,6 +228,7 @@ jobs:
- name: Install lancedb
run: |
pip install "pydantic<2"
pip install pyarrow==16
pip install --extra-index-url https://pypi.fury.io/lancedb/ -e .[tests]
pip install tantivy
- name: Run tests

427
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -21,16 +21,14 @@ categories = ["database-implementations"]
rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.26.0", "features" = [
"dynamodb",
], tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-io = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-index = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-linalg = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-table = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-testing = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-datafusion = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance-encoding = { version = "=0.26.0", tag = "v0.26.0-beta.1", git = "https://github.com/lancedb/lance" }
lance = { "version" = "=0.27.0", "features" = ["dynamodb"], tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-io = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-index = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-linalg = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-table = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-testing = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-datafusion = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
lance-encoding = { version = "=0.27.0", tag = "v0.27.0-beta.2", git="https://github.com/lancedb/lance.git" }
# Note that this one does not include pyarrow
arrow = { version = "54.1", optional = false }
arrow-array = "54.1"

View File

@@ -105,7 +105,8 @@ markdown_extensions:
nav:
- Home:
- LanceDB: index.md
- 🏃🏼‍♂️ Quick start: basic.md
- 👉 Quickstart: quickstart.md
- 🏃🏼‍♂️ Basic Usage: basic.md
- 📚 Concepts:
- Vector search: concepts/vector_search.md
- Indexing:
@@ -237,7 +238,9 @@ nav:
- 👾 JavaScript (lancedb): js/globals.md
- 🦀 Rust: https://docs.rs/lancedb/latest/lancedb/
- Quick start: basic.md
- Getting Started:
- Quickstart: quickstart.md
- Basic Usage: basic.md
- Concepts:
- Vector search: concepts/vector_search.md
- Indexing:

View File

@@ -1,4 +1,4 @@
# Quick start
# Basic Usage
!!! info "LanceDB can be run in a number of ways:"

View File

@@ -765,7 +765,10 @@ This can be used to update zero to all rows depending on how many rows match the
];
const tbl = await db.createTable("my_table", data)
await tbl.update({vector: [10, 10]}, { where: "x = 2"})
await tbl.update({
values: { vector: [10, 10] },
where: "x = 2"
});
```
=== "vectordb (deprecated)"
@@ -784,7 +787,10 @@ This can be used to update zero to all rows depending on how many rows match the
];
const tbl = await db.createTable("my_table", data)
await tbl.update({ where: "x = 2", values: {vector: [10, 10]} })
await tbl.update({
where: "x = 2",
values: { vector: [10, 10] }
});
```
#### Updating using a sql query

View File

@@ -33,20 +33,20 @@ Construct a MergeInsertBuilder. __Internal use only.__
### execute()
```ts
execute(data): Promise<void>
execute(data): Promise<MergeStats>
```
Executes the merge insert operation
Nothing is returned but the `Table` is updated
#### Parameters
* **data**: [`Data`](../type-aliases/Data.md)
#### Returns
`Promise`&lt;`void`&gt;
`Promise`&lt;[`MergeStats`](../interfaces/MergeStats.md)&gt;
Statistics about the merge operation: counts of inserted, updated, and deleted rows
***

View File

@@ -117,8 +117,8 @@ wish to return to standard mode, call `checkoutLatest`.
#### Parameters
* **version**: `number`
The version to checkout
* **version**: `string` \| `number`
The version to checkout, could be version number or tag
#### Returns
@@ -615,6 +615,50 @@ of the given query
***
### stats()
```ts
abstract stats(): Promise<TableStatistics>
```
Returns table and fragment statistics
#### Returns
`Promise`&lt;[`TableStatistics`](../interfaces/TableStatistics.md)&gt;
The table and fragment statistics
***
### tags()
```ts
abstract tags(): Promise<Tags>
```
Get a tags manager for this table.
Tags allow you to label specific versions of a table with a human-readable name.
The returned tags manager can be used to list, create, update, or delete tags.
#### Returns
`Promise`&lt;[`Tags`](Tags.md)&gt;
A tags manager for this table
#### Example
```typescript
const tagsManager = await table.tags();
await tagsManager.create("v1", 1);
const tags = await tagsManager.list();
console.log(tags); // { "v1": { version: 1, manifestSize: ... } }
```
***
### toArrow()
```ts
@@ -753,3 +797,26 @@ Retrieve the version of the table
#### Returns
`Promise`&lt;`number`&gt;
***
### waitForIndex()
```ts
abstract waitForIndex(indexNames, timeoutSeconds): Promise<void>
```
Waits for asynchronous indexing to complete on the table.
#### Parameters
* **indexNames**: `string`[]
The name of the indices to wait for
* **timeoutSeconds**: `number`
The number of seconds to wait before timing out
This will raise an error if the indices are not created and fully indexed within the timeout.
#### Returns
`Promise`&lt;`void`&gt;

View File

@@ -0,0 +1,35 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / TagContents
# Class: TagContents
## Constructors
### new TagContents()
```ts
new TagContents(): TagContents
```
#### Returns
[`TagContents`](TagContents.md)
## Properties
### manifestSize
```ts
manifestSize: number;
```
***
### version
```ts
version: number;
```

View File

@@ -0,0 +1,99 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / Tags
# Class: Tags
## Constructors
### new Tags()
```ts
new Tags(): Tags
```
#### Returns
[`Tags`](Tags.md)
## Methods
### create()
```ts
create(tag, version): Promise<void>
```
#### Parameters
* **tag**: `string`
* **version**: `number`
#### Returns
`Promise`&lt;`void`&gt;
***
### delete()
```ts
delete(tag): Promise<void>
```
#### Parameters
* **tag**: `string`
#### Returns
`Promise`&lt;`void`&gt;
***
### getVersion()
```ts
getVersion(tag): Promise<number>
```
#### Parameters
* **tag**: `string`
#### Returns
`Promise`&lt;`number`&gt;
***
### list()
```ts
list(): Promise<Record<string, TagContents>>
```
#### Returns
`Promise`&lt;`Record`&lt;`string`, [`TagContents`](TagContents.md)&gt;&gt;
***
### update()
```ts
update(tag, version): Promise<void>
```
#### Parameters
* **tag**: `string`
* **version**: `number`
#### Returns
`Promise`&lt;`void`&gt;

View File

@@ -27,6 +27,8 @@
- [QueryBase](classes/QueryBase.md)
- [RecordBatchIterator](classes/RecordBatchIterator.md)
- [Table](classes/Table.md)
- [TagContents](classes/TagContents.md)
- [Tags](classes/Tags.md)
- [VectorColumnOptions](classes/VectorColumnOptions.md)
- [VectorQuery](classes/VectorQuery.md)
@@ -40,6 +42,8 @@
- [ConnectionOptions](interfaces/ConnectionOptions.md)
- [CreateTableOptions](interfaces/CreateTableOptions.md)
- [ExecutableQuery](interfaces/ExecutableQuery.md)
- [FragmentStatistics](interfaces/FragmentStatistics.md)
- [FragmentSummaryStats](interfaces/FragmentSummaryStats.md)
- [FtsOptions](interfaces/FtsOptions.md)
- [FullTextQuery](interfaces/FullTextQuery.md)
- [FullTextSearchOptions](interfaces/FullTextSearchOptions.md)
@@ -50,6 +54,7 @@
- [IndexStatistics](interfaces/IndexStatistics.md)
- [IvfFlatOptions](interfaces/IvfFlatOptions.md)
- [IvfPqOptions](interfaces/IvfPqOptions.md)
- [MergeStats](interfaces/MergeStats.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
- [OptimizeStats](interfaces/OptimizeStats.md)
@@ -57,6 +62,7 @@
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [TableStatistics](interfaces/TableStatistics.md)
- [TimeoutConfig](interfaces/TimeoutConfig.md)
- [UpdateOptions](interfaces/UpdateOptions.md)
- [Version](interfaces/Version.md)

View File

@@ -0,0 +1,37 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / FragmentStatistics
# Interface: FragmentStatistics
## Properties
### lengths
```ts
lengths: FragmentSummaryStats;
```
Statistics on the number of rows in the table fragments
***
### numFragments
```ts
numFragments: number;
```
The number of fragments in the table
***
### numSmallFragments
```ts
numSmallFragments: number;
```
The number of uncompacted fragments in the table

View File

@@ -0,0 +1,77 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / FragmentSummaryStats
# Interface: FragmentSummaryStats
## Properties
### max
```ts
max: number;
```
The number of rows in the fragment with the most rows
***
### mean
```ts
mean: number;
```
The mean number of rows in the fragments
***
### min
```ts
min: number;
```
The number of rows in the fragment with the fewest rows
***
### p25
```ts
p25: number;
```
The 25th percentile of number of rows in the fragments
***
### p50
```ts
p50: number;
```
The 50th percentile of number of rows in the fragments
***
### p75
```ts
p75: number;
```
The 75th percentile of number of rows in the fragments
***
### p99
```ts
p99: number;
```
The 99th percentile of number of rows in the fragments

View File

@@ -39,3 +39,11 @@ and the same name, then an error will be returned. This is true even if
that index is out of date.
The default is true
***
### waitTimeoutSeconds?
```ts
optional waitTimeoutSeconds: number;
```

View File

@@ -0,0 +1,31 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / MergeStats
# Interface: MergeStats
## Properties
### numDeletedRows
```ts
numDeletedRows: bigint;
```
***
### numInsertedRows
```ts
numInsertedRows: bigint;
```
***
### numUpdatedRows
```ts
numUpdatedRows: bigint;
```

View File

@@ -0,0 +1,47 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / TableStatistics
# Interface: TableStatistics
## Properties
### fragmentStats
```ts
fragmentStats: FragmentStatistics;
```
Statistics on table fragments
***
### numIndices
```ts
numIndices: number;
```
The number of indices in the table
***
### numRows
```ts
numRows: number;
```
The number of rows in the table
***
### totalBytes
```ts
totalBytes: number;
```
The total number of bytes in the table

101
docs/src/quickstart.md Normal file
View File

@@ -0,0 +1,101 @@
# Getting Started with LanceDB: A Minimal Vector Search Tutorial
Let's set up a LanceDB database, insert vector data, and perform a simple vector search. We'll use simple character classes like "knight" and "rogue" to illustrate semantic relevance.
## 1. Install Dependencies
Before starting, make sure you have the necessary packages:
```bash
pip install lancedb pandas numpy
```
## 2. Import Required Libraries
```python
import lancedb
import pandas as pd
import numpy as np
```
## 3. Connect to LanceDB
You can use a local directory to store your database:
```python
db = lancedb.connect("./lancedb")
```
## 4. Create Sample Data
Add sample text data and corresponding 4D vectors:
```python
data = pd.DataFrame([
{"id": "1", "vector": [1.0, 0.0, 0.0, 0.0], "text": "knight"},
{"id": "2", "vector": [0.9, 0.1, 0.0, 0.0], "text": "warrior"},
{"id": "3", "vector": [0.0, 1.0, 0.0, 0.0], "text": "rogue"},
{"id": "4", "vector": [0.0, 0.9, 0.1, 0.0], "text": "thief"},
{"id": "5", "vector": [0.5, 0.5, 0.0, 0.0], "text": "ranger"},
])
```
## 5. Create a Table in LanceDB
```python
table = db.create_table("rpg_classes", data=data, mode="overwrite")
```
Let's see how the table looks:
```python
print(data)
```
| id | vector | text |
|----|--------|------|
| 1 | [1.0, 0.0, 0.0, 0.0] | knight |
| 2 | [0.9, 0.1, 0.0, 0.0] | warrior |
| 3 | [0.0, 1.0, 0.0, 0.0] | rogue |
| 4 | [0.0, 0.9, 0.1, 0.0] | thief |
| 5 | [0.5, 0.5, 0.0, 0.0] | ranger |
## 6. Perform a Vector Search
Search for the most similar character classes to our query vector:
```python
# Query as if we are searching for "rogue"
results = table.search([0.95, 0.05, 0.0, 0.0]).limit(3).to_df()
print(results)
```
This will return the top 3 closest classes to the vector, effectively showing how LanceDB can be used for semantic search.
| id | vector | text | _distance |
|------|------------------------|----------|-----------|
| 3 | [0.0, 1.0, 0.0, 0.0] | rogue | 0.00 |
| 4 | [0.0, 0.9, 0.1, 0.0] | thief | 0.02 |
| 5 | [0.5, 0.5, 0.0, 0.0] | ranger | 0.50 |
Let's try searching for "knight"
```python
query_vector = [1.0, 0.0, 0.0, 0.0]
results = table.search(query_vector).limit(3).to_pandas()
print(results)
```
| id | vector | text | _distance |
|------|------------------------|----------|-----------|
| 1 | [1.0, 0.0, 0.0, 0.0] | knight | 0.00 |
| 2 | [0.9, 0.1, 0.0, 0.0] | warrior | 0.02 |
| 5 | [0.5, 0.5, 0.0, 0.0] | ranger | 0.50 |
## Next Steps
That's it - you just conducted vector search!
For more beginner tips, check out the [Basic Usage](basic.md) guide.

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.7</version>
<version>0.19.1-beta.1</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.19.0-beta.7</version>
<version>0.19.1-beta.1</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>

44
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"cpu": [
"x64",
"arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.7",
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.7"
"@lancedb/vectordb-darwin-arm64": "0.19.1-beta.1",
"@lancedb/vectordb-darwin-x64": "0.19.1-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.1-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.19.1-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.19.1-beta.1"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -327,9 +327,9 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.0-beta.7.tgz",
"integrity": "sha512-HpbVKw4Vs+mPv7uPwaK7ilJlGrGdjOrNlC2mSkMCj0OlEwGRVcEcrSyijI7LXQH7ybEgNnDhSds5TuzBV26SGg==",
"version": "0.19.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.19.1-beta.1.tgz",
"integrity": "sha512-Epvel0pF5TM6MtIWQ2KhqezqSSHTL3Wr7a2rGAwz6X/XY23i6DbMPpPs0HyeIDzDrhxNfE3cz3S+SiCA6xpR0g==",
"cpu": [
"arm64"
],
@@ -340,9 +340,9 @@
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.0-beta.7.tgz",
"integrity": "sha512-x3X7nqIYVZtxaa0uZUk/M99vKvDinZ5G0+8k2NqZ696YXGWKGyRxR6k8ZzKYCoCTSuYXnBftgKoIlwJGtNt8Bw==",
"version": "0.19.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.19.1-beta.1.tgz",
"integrity": "sha512-hOiUSlIoISbiXytp46hToi/r6sF5pImAsfbzCsIq8ExDV4TPa8fjbhcIT80vxxOwc2mpSSK4HsVJYod95RSbEQ==",
"cpu": [
"x64"
],
@@ -353,9 +353,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.0-beta.7.tgz",
"integrity": "sha512-Vwj0HI3+b4NgXKf+5+W/GfLBCGoQMBGM47vA/ts1dpe/PxraOQYPDv67I5kbXkCQKwhal7b0iZx/PbMu0JZPyw==",
"version": "0.19.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.19.1-beta.1.tgz",
"integrity": "sha512-/1JhGVDEngwrlM8o2TNW8G6nJ9U/VgHKAORmj/cTA7O30helJIoo9jfvUAUy+vZ4VoEwRXQbMI+gaYTg0l3MTg==",
"cpu": [
"arm64"
],
@@ -366,9 +366,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.0-beta.7.tgz",
"integrity": "sha512-Dx2B6UWQei9D7Rt+MgHWqPTYtEK2w3EgsNb5ENEWUTZxH7lD/CV7Sw0JMK5LDG209fFcpXFerveF6J8ZC8uGBQ==",
"version": "0.19.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.19.1-beta.1.tgz",
"integrity": "sha512-zNRGSSUt8nTJMmll4NdxhQjwxR8Rezq3T4dsRoiDts5ienMam5HFjYiZ3FkDZQo16rgq2BcbFuH1G8u1chywlg==",
"cpu": [
"x64"
],
@@ -379,9 +379,9 @@
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.19.0-beta.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.0-beta.7.tgz",
"integrity": "sha512-F5LZGa+gkUH1TgsWZWLLAMejwXFIWdash7+85ip4k2M0ThyqLF/dtlldOvteUEd5+flxihGjHg6TUtnSY8XBFA==",
"version": "0.19.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.19.1-beta.1.tgz",
"integrity": "sha512-yV550AJGlsIFdm1KoHQPJ1TZx121ZXCIdebBtBZj3wOObIhyB/i0kZAtGvwjkmr7EYyfzt1EHZzbjSGVdehIAA==",
"cpu": [
"x64"
],

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"description": " Serverless, low-latency vector database for AI applications",
"private": false,
"main": "dist/index.js",
@@ -89,10 +89,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-x64": "0.19.0-beta.7",
"@lancedb/vectordb-darwin-arm64": "0.19.0-beta.7",
"@lancedb/vectordb-linux-x64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.0-beta.7",
"@lancedb/vectordb-win32-x64-msvc": "0.19.0-beta.7"
"@lancedb/vectordb-darwin-x64": "0.19.1-beta.1",
"@lancedb/vectordb-darwin-arm64": "0.19.1-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.19.1-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.19.1-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.19.1-beta.1"
}
}

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.19.0-beta.7"
version = "0.19.1-beta.1"
license.workspace = true
description.workspace = true
repository.workspace = true
@@ -28,6 +28,9 @@ napi-derive = "2.16.4"
lzma-sys = { version = "*", features = ["static"] }
log.workspace = true
# Workaround for build failure until we can fix it.
aws-lc-sys = "=0.28.0"
[build-dependencies]
napi-build = "2.1"

View File

@@ -374,6 +374,71 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
expect(table2.numRows).toBe(4);
expect(table2.schema).toEqual(schema);
});
it("should correctly retain values in nested struct fields", async function () {
// Define test data with nested struct
const testData = [
{
id: "doc1",
vector: [1, 2, 3],
metadata: {
filePath: "/path/to/file1.ts",
startLine: 10,
endLine: 20,
text: "function test() { return true; }",
},
},
{
id: "doc2",
vector: [4, 5, 6],
metadata: {
filePath: "/path/to/file2.ts",
startLine: 30,
endLine: 40,
text: "function test2() { return false; }",
},
},
];
// Create Arrow table from the data
const table = makeArrowTable(testData);
// Verify schema has the nested struct fields
const metadataField = table.schema.fields.find(
(f) => f.name === "metadata",
);
expect(metadataField).toBeDefined();
// biome-ignore lint/suspicious/noExplicitAny: accessing fields in different Arrow versions
const childNames = metadataField?.type.children.map((c: any) => c.name);
expect(childNames).toEqual([
"filePath",
"startLine",
"endLine",
"text",
]);
// Convert to buffer and back (simulating storage and retrieval)
const buf = await fromTableToBuffer(table);
const retrievedTable = tableFromIPC(buf);
// Verify the retrieved table has the same structure
const rows = [];
for (let i = 0; i < retrievedTable.numRows; i++) {
rows.push(retrievedTable.get(i));
}
// Check values in the first row
const firstRow = rows[0];
expect(firstRow.id).toBe("doc1");
expect(firstRow.vector.toJSON()).toEqual([1, 2, 3]);
// Verify metadata values are preserved (this is where the bug is)
expect(firstRow.metadata).toBeDefined();
expect(firstRow.metadata.filePath).toBe("/path/to/file1.ts");
expect(firstRow.metadata.startLine).toBe(10);
expect(firstRow.metadata.endLine).toBe(20);
expect(firstRow.metadata.text).toBe("function test() { return true; }");
});
});
class DummyEmbedding extends EmbeddingFunction<string> {

View File

@@ -71,6 +71,29 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
await expect(table.countRows()).resolves.toBe(3);
});
it("should show table stats", async () => {
await table.add([{ id: 1 }, { id: 2 }]);
await table.add([{ id: 1 }]);
await expect(table.stats()).resolves.toEqual({
fragmentStats: {
lengths: {
max: 2,
mean: 1,
min: 1,
p25: 1,
p50: 2,
p75: 2,
p99: 2,
},
numFragments: 2,
numSmallFragments: 2,
},
numIndices: 0,
numRows: 3,
totalBytes: 24,
});
});
it("should overwrite data if asked", async () => {
await table.add([{ id: 1 }, { id: 2 }]);
await table.add([{ id: 1 }], { mode: "overwrite" });
@@ -315,11 +338,16 @@ describe("merge insert", () => {
{ a: 3, b: "y" },
{ a: 4, b: "z" },
];
await table
const stats = await table
.mergeInsert("a")
.whenMatchedUpdateAll()
.whenNotMatchedInsertAll()
.execute(newData);
expect(stats.numInsertedRows).toBe(1n);
expect(stats.numUpdatedRows).toBe(2n);
expect(stats.numDeletedRows).toBe(0n);
const expected = [
{ a: 1, b: "a" },
{ a: 2, b: "x" },
@@ -507,6 +535,15 @@ describe("When creating an index", () => {
expect(indices2.length).toBe(0);
});
it("should wait for index readiness", async () => {
// Create an index and then wait for it to be ready
await tbl.createIndex("vec");
const indices = await tbl.listIndices();
expect(indices.length).toBeGreaterThan(0);
const idxName = indices[0].name;
await expect(tbl.waitForIndex([idxName], 5)).resolves.toBeUndefined();
});
it("should search with distance range", async () => {
await tbl.createIndex("vec");
@@ -824,6 +861,7 @@ describe("When creating an index", () => {
// Only build index over v1
await tbl.createIndex("vec", {
config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }),
waitTimeoutSeconds: 30,
});
const rst = await tbl
@@ -1168,6 +1206,73 @@ describe("when dealing with versioning", () => {
});
});
describe("when dealing with tags", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
});
it("can manage tags", async () => {
const conn = await connect(tmpDir.name, {
readConsistencyInterval: 0,
});
const table = await conn.createTable("my_table", [
{ id: 1n, vector: [0.1, 0.2] },
]);
expect(await table.version()).toBe(1);
await table.add([{ id: 2n, vector: [0.3, 0.4] }]);
expect(await table.version()).toBe(2);
const tagsManager = await table.tags();
const initialTags = await tagsManager.list();
expect(Object.keys(initialTags).length).toBe(0);
const tag1 = "tag1";
await tagsManager.create(tag1, 1);
expect(await tagsManager.getVersion(tag1)).toBe(1);
const tagsAfterFirst = await tagsManager.list();
expect(Object.keys(tagsAfterFirst).length).toBe(1);
expect(tagsAfterFirst).toHaveProperty(tag1);
expect(tagsAfterFirst[tag1].version).toBe(1);
await tagsManager.create("tag2", 2);
expect(await tagsManager.getVersion("tag2")).toBe(2);
const tagsAfterSecond = await tagsManager.list();
expect(Object.keys(tagsAfterSecond).length).toBe(2);
expect(tagsAfterSecond).toHaveProperty(tag1);
expect(tagsAfterSecond[tag1].version).toBe(1);
expect(tagsAfterSecond).toHaveProperty("tag2");
expect(tagsAfterSecond["tag2"].version).toBe(2);
await table.add([{ id: 3n, vector: [0.5, 0.6] }]);
await tagsManager.update(tag1, 3);
expect(await tagsManager.getVersion(tag1)).toBe(3);
await tagsManager.delete("tag2");
const tagsAfterDelete = await tagsManager.list();
expect(Object.keys(tagsAfterDelete).length).toBe(1);
expect(tagsAfterDelete).toHaveProperty(tag1);
expect(tagsAfterDelete[tag1].version).toBe(3);
await table.add([{ id: 4n, vector: [0.7, 0.8] }]);
expect(await table.version()).toBe(4);
await table.checkout(tag1);
expect(await table.version()).toBe(3);
await table.checkoutLatest();
expect(await table.version()).toBe(4);
});
});
describe("when optimizing a dataset", () => {
let tmpDir: tmp.DirResult;
let table: Table;

View File

@@ -639,8 +639,9 @@ function transposeData(
): Vector {
if (field.type instanceof Struct) {
const childFields = field.type.children;
const fullPath = [...path, field.name];
const childVectors = childFields.map((child) => {
return transposeData(data, child, [...path, child.name]);
return transposeData(data, child, fullPath);
});
const structData = makeData({
type: field.type,
@@ -652,7 +653,14 @@ function transposeData(
const values = data.map((datum) => {
let current: unknown = datum;
for (const key of valuesPath) {
if (isObject(current) && Object.hasOwn(current, key)) {
if (current == null) {
return null;
}
if (
isObject(current) &&
(Object.hasOwn(current, key) || key in current)
) {
current = current[key];
} else {
return null;

View File

@@ -23,6 +23,12 @@ export {
OptimizeStats,
CompactionStats,
RemovalStats,
TableStatistics,
FragmentStatistics,
FragmentSummaryStats,
Tags,
TagContents,
MergeStats,
} from "./native.js";
export {

View File

@@ -681,4 +681,6 @@ export interface IndexOptions {
* The default is true
*/
replace?: boolean;
waitTimeoutSeconds?: number;
}

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { Data, Schema, fromDataToBuffer } from "./arrow";
import { NativeMergeInsertBuilder } from "./native";
import { MergeStats, NativeMergeInsertBuilder } from "./native";
/** A builder used to create and run a merge insert operation */
export class MergeInsertBuilder {
@@ -73,9 +73,9 @@ export class MergeInsertBuilder {
/**
* Executes the merge insert operation
*
* Nothing is returned but the `Table` is updated
* @returns Statistics about the merge operation: counts of inserted, updated, and deleted rows
*/
async execute(data: Data): Promise<void> {
async execute(data: Data): Promise<MergeStats> {
let schema: Schema;
if (this.#schema instanceof Promise) {
schema = await this.#schema;
@@ -84,6 +84,6 @@ export class MergeInsertBuilder {
schema = this.#schema;
}
const buffer = await fromDataToBuffer(data, undefined, schema);
await this.#native.execute(buffer);
return await this.#native.execute(buffer);
}
}

View File

@@ -20,6 +20,8 @@ import {
IndexConfig,
IndexStatistics,
OptimizeStats,
TableStatistics,
Tags,
Table as _NativeTable,
} from "./native";
import {
@@ -246,6 +248,19 @@ export abstract class Table {
*/
abstract prewarmIndex(name: string): Promise<void>;
/**
* Waits for asynchronous indexing to complete on the table.
*
* @param indexNames The name of the indices to wait for
* @param timeoutSeconds The number of seconds to wait before timing out
*
* This will raise an error if the indices are not created and fully indexed within the timeout.
*/
abstract waitForIndex(
indexNames: string[],
timeoutSeconds: number,
): Promise<void>;
/**
* Create a {@link Query} Builder.
*
@@ -361,7 +376,7 @@ export abstract class Table {
*
* Calling this method will set the table into time-travel mode. If you
* wish to return to standard mode, call `checkoutLatest`.
* @param {number} version The version to checkout
* @param {number | string} version The version to checkout, could be version number or tag
* @example
* ```typescript
* import * as lancedb from "@lancedb/lancedb"
@@ -377,7 +392,8 @@ export abstract class Table {
* console.log(await table.version()); // 2
* ```
*/
abstract checkout(version: number): Promise<void>;
abstract checkout(version: number | string): Promise<void>;
/**
* Checkout the latest version of the table. _This is an in-place operation._
*
@@ -391,6 +407,23 @@ export abstract class Table {
*/
abstract listVersions(): Promise<Version[]>;
/**
* Get a tags manager for this table.
*
* Tags allow you to label specific versions of a table with a human-readable name.
* The returned tags manager can be used to list, create, update, or delete tags.
*
* @returns {Tags} A tags manager for this table
* @example
* ```typescript
* const tagsManager = await table.tags();
* await tagsManager.create("v1", 1);
* const tags = await tagsManager.list();
* console.log(tags); // { "v1": { version: 1, manifestSize: ... } }
* ```
*/
abstract tags(): Promise<Tags>;
/**
* Restore the table to the currently checked out version
*
@@ -450,6 +483,13 @@ export abstract class Table {
* Use {@link Table.listIndices} to find the names of the indices.
*/
abstract indexStats(name: string): Promise<IndexStatistics | undefined>;
/** Returns table and fragment statistics
*
* @returns {TableStatistics} The table and fragment statistics
*
*/
abstract stats(): Promise<TableStatistics>;
}
export class LocalTable extends Table {
@@ -569,7 +609,12 @@ export class LocalTable extends Table {
// Bit of a hack to get around the fact that TS has no package-scope.
// biome-ignore lint/suspicious/noExplicitAny: skip
const nativeIndex = (options?.config as any)?.inner;
await this.inner.createIndex(nativeIndex, column, options?.replace);
await this.inner.createIndex(
nativeIndex,
column,
options?.replace,
options?.waitTimeoutSeconds,
);
}
async dropIndex(name: string): Promise<void> {
@@ -580,6 +625,13 @@ export class LocalTable extends Table {
await this.inner.prewarmIndex(name);
}
async waitForIndex(
indexNames: string[],
timeoutSeconds: number,
): Promise<void> {
await this.inner.waitForIndex(indexNames, timeoutSeconds);
}
query(): Query {
return new Query(this.inner);
}
@@ -674,8 +726,11 @@ export class LocalTable extends Table {
return await this.inner.version();
}
async checkout(version: number): Promise<void> {
await this.inner.checkout(version);
async checkout(version: number | string): Promise<void> {
if (typeof version === "string") {
return this.inner.checkoutTag(version);
}
return this.inner.checkout(version);
}
async checkoutLatest(): Promise<void> {
@@ -694,6 +749,10 @@ export class LocalTable extends Table {
await this.inner.restore();
}
async tags(): Promise<Tags> {
return await this.inner.tags();
}
async optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats> {
let cleanupOlderThanMs;
if (
@@ -724,6 +783,11 @@ export class LocalTable extends Table {
}
return stats;
}
async stats(): Promise<TableStatistics> {
return await this.inner.stats();
}
mergeInsert(on: string | string[]): MergeInsertBuilder {
on = Array.isArray(on) ? on : [on];
return new MergeInsertBuilder(this.inner.mergeInsert(on), this.schema());

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.19.0-beta.7",
"version": "0.19.1-beta.1",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -37,7 +37,7 @@ impl NativeMergeInsertBuilder {
}
#[napi(catch_unwind)]
pub async fn execute(&self, buf: Buffer) -> napi::Result<()> {
pub async fn execute(&self, buf: Buffer) -> napi::Result<MergeStats> {
let data = ipc_file_to_batches(buf.to_vec())
.and_then(IntoArrow::into_arrow)
.map_err(|e| {
@@ -46,12 +46,14 @@ impl NativeMergeInsertBuilder {
let this = self.clone();
this.inner.execute(data).await.map_err(|e| {
let stats = this.inner.execute(data).await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to execute merge insert: {}",
convert_error(&e)
))
})
})?;
Ok(stats.into())
}
}
@@ -60,3 +62,20 @@ impl From<MergeInsertBuilder> for NativeMergeInsertBuilder {
Self { inner }
}
}
#[napi(object)]
pub struct MergeStats {
pub num_inserted_rows: BigInt,
pub num_updated_rows: BigInt,
pub num_deleted_rows: BigInt,
}
impl From<lancedb::table::MergeStats> for MergeStats {
fn from(stats: lancedb::table::MergeStats) -> Self {
Self {
num_inserted_rows: stats.num_inserted_rows.into(),
num_updated_rows: stats.num_updated_rows.into(),
num_deleted_rows: stats.num_deleted_rows.into(),
}
}
}

View File

@@ -111,6 +111,7 @@ impl Table {
index: Option<&Index>,
column: String,
replace: Option<bool>,
wait_timeout_s: Option<i64>,
) -> napi::Result<()> {
let lancedb_index = if let Some(index) = index {
index.consume()?
@@ -121,6 +122,10 @@ impl Table {
if let Some(replace) = replace {
builder = builder.replace(replace);
}
if let Some(timeout) = wait_timeout_s {
builder =
builder.wait_timeout(std::time::Duration::from_secs(timeout.try_into().unwrap()));
}
builder.execute().await.default_error()
}
@@ -140,6 +145,24 @@ impl Table {
.default_error()
}
#[napi(catch_unwind)]
pub async fn wait_for_index(&self, index_names: Vec<String>, timeout_s: i64) -> Result<()> {
let timeout = std::time::Duration::from_secs(timeout_s.try_into().unwrap());
let index_names: Vec<&str> = index_names.iter().map(|s| s.as_str()).collect();
let slice: &[&str] = &index_names;
self.inner_ref()?
.wait_for_index(slice, timeout)
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn stats(&self) -> Result<TableStatistics> {
let stats = self.inner_ref()?.stats().await.default_error()?;
Ok(stats.into())
}
#[napi(catch_unwind)]
pub async fn update(
&self,
@@ -232,6 +255,14 @@ impl Table {
.default_error()
}
#[napi(catch_unwind)]
pub async fn checkout_tag(&self, tag: String) -> napi::Result<()> {
self.inner_ref()?
.checkout_tag(tag.as_str())
.await
.default_error()
}
#[napi(catch_unwind)]
pub async fn checkout_latest(&self) -> napi::Result<()> {
self.inner_ref()?.checkout_latest().await.default_error()
@@ -264,6 +295,13 @@ impl Table {
self.inner_ref()?.restore().await.default_error()
}
#[napi(catch_unwind)]
pub async fn tags(&self) -> napi::Result<Tags> {
Ok(Tags {
inner: self.inner_ref()?.clone(),
})
}
#[napi(catch_unwind)]
pub async fn optimize(
&self,
@@ -523,9 +561,158 @@ impl From<lancedb::index::IndexStatistics> for IndexStatistics {
}
}
#[napi(object)]
pub struct TableStatistics {
/// The total number of bytes in the table
pub total_bytes: i64,
/// The number of rows in the table
pub num_rows: i64,
/// The number of indices in the table
pub num_indices: i64,
/// Statistics on table fragments
pub fragment_stats: FragmentStatistics,
}
#[napi(object)]
pub struct FragmentStatistics {
/// The number of fragments in the table
pub num_fragments: i64,
/// The number of uncompacted fragments in the table
pub num_small_fragments: i64,
/// Statistics on the number of rows in the table fragments
pub lengths: FragmentSummaryStats,
}
#[napi(object)]
pub struct FragmentSummaryStats {
/// The number of rows in the fragment with the fewest rows
pub min: i64,
/// The number of rows in the fragment with the most rows
pub max: i64,
/// The mean number of rows in the fragments
pub mean: i64,
/// The 25th percentile of number of rows in the fragments
pub p25: i64,
/// The 50th percentile of number of rows in the fragments
pub p50: i64,
/// The 75th percentile of number of rows in the fragments
pub p75: i64,
/// The 99th percentile of number of rows in the fragments
pub p99: i64,
}
impl From<lancedb::table::TableStatistics> for TableStatistics {
fn from(v: lancedb::table::TableStatistics) -> Self {
Self {
total_bytes: v.total_bytes as i64,
num_rows: v.num_rows as i64,
num_indices: v.num_indices as i64,
fragment_stats: FragmentStatistics {
num_fragments: v.fragment_stats.num_fragments as i64,
num_small_fragments: v.fragment_stats.num_small_fragments as i64,
lengths: FragmentSummaryStats {
min: v.fragment_stats.lengths.min as i64,
max: v.fragment_stats.lengths.max as i64,
mean: v.fragment_stats.lengths.mean as i64,
p25: v.fragment_stats.lengths.p25 as i64,
p50: v.fragment_stats.lengths.p50 as i64,
p75: v.fragment_stats.lengths.p75 as i64,
p99: v.fragment_stats.lengths.p99 as i64,
},
},
}
}
}
#[napi(object)]
pub struct Version {
pub version: i64,
pub timestamp: i64,
pub metadata: HashMap<String, String>,
}
#[napi]
pub struct TagContents {
pub version: i64,
pub manifest_size: i64,
}
#[napi]
pub struct Tags {
inner: LanceDbTable,
}
#[napi]
impl Tags {
#[napi]
pub async fn list(&self) -> napi::Result<HashMap<String, TagContents>> {
let rust_tags = self.inner.tags().await.default_error()?;
let tag_list = rust_tags.as_ref().list().await.default_error()?;
let tag_contents = tag_list
.into_iter()
.map(|(k, v)| {
(
k,
TagContents {
version: v.version as i64,
manifest_size: v.manifest_size as i64,
},
)
})
.collect();
Ok(tag_contents)
}
#[napi]
pub async fn get_version(&self, tag: String) -> napi::Result<i64> {
let rust_tags = self.inner.tags().await.default_error()?;
rust_tags
.as_ref()
.get_version(tag.as_str())
.await
.map(|v| v as i64)
.default_error()
}
#[napi]
pub async unsafe fn create(&mut self, tag: String, version: i64) -> napi::Result<()> {
let mut rust_tags = self.inner.tags().await.default_error()?;
rust_tags
.as_mut()
.create(tag.as_str(), version as u64)
.await
.default_error()
}
#[napi]
pub async unsafe fn delete(&mut self, tag: String) -> napi::Result<()> {
let mut rust_tags = self.inner.tags().await.default_error()?;
rust_tags
.as_mut()
.delete(tag.as_str())
.await
.default_error()
}
#[napi]
pub async unsafe fn update(&mut self, tag: String, version: i64) -> napi::Result<()> {
let mut rust_tags = self.inner.tags().await.default_error()?;
rust_tags
.as_mut()
.update(tag.as_str(), version as u64)
.await
.default_error()
}
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.0-beta.8"
current_version = "0.22.1-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.22.0-beta.8"
version = "0.22.1-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -7,7 +7,7 @@ dependencies = [
"numpy",
"overrides>=0.7",
"packaging",
"pyarrow>=14",
"pyarrow>=16",
"pydantic>=1.10",
"tqdm>=4.27.0",
]
@@ -77,6 +77,7 @@ embeddings = [
"pillow",
"open-clip-torch",
"cohere",
"colpali-engine>=0.3.10",
"huggingface_hub",
"InstructorEmbedding",
"google.generativeai",

View File

@@ -1,5 +1,5 @@
from datetime import timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, Literal
from typing import Dict, List, Optional, Tuple, Any, TypedDict, Union, Literal
import pyarrow as pa
@@ -47,7 +47,7 @@ class Table:
): ...
async def list_versions(self) -> List[Dict[str, Any]]: ...
async def version(self) -> int: ...
async def checkout(self, version: int): ...
async def checkout(self, version: Union[int, str]): ...
async def checkout_latest(self): ...
async def restore(self, version: Optional[int] = None): ...
async def list_indices(self) -> list[IndexConfig]: ...
@@ -61,9 +61,18 @@ class Table:
cleanup_since_ms: Optional[int] = None,
delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ...
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ...
class Tags:
async def list(self) -> Dict[str, Tag]: ...
async def get_version(self, tag: str) -> int: ...
async def create(self, tag: str, version: int): ...
async def delete(self, tag: str): ...
async def update(self, tag: str, version: int): ...
class IndexConfig:
index_type: str
columns: List[str]
@@ -195,3 +204,7 @@ class RemovalStats:
class OptimizeStats:
compaction: CompactionStats
prune: RemovalStats
class Tag(TypedDict):
version: int
manifest_size: int

View File

@@ -9,7 +9,7 @@ import numpy as np
import pyarrow as pa
import pyarrow.dataset
from .dependencies import pandas as pd
from .dependencies import _check_for_pandas, pandas as pd
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
@@ -63,7 +63,7 @@ def data_to_reader(
data: DATA, schema: Optional[pa.Schema] = None
) -> pa.RecordBatchReader:
"""Convert various types of input into a RecordBatchReader"""
if pd is not None and isinstance(data, pd.DataFrame):
if _check_for_pandas(data) and isinstance(data, pd.DataFrame):
return pa.Table.from_pandas(data, schema=schema).to_reader()
elif isinstance(data, pa.Table):
return data.to_reader()

View File

@@ -19,3 +19,4 @@ from .imagebind import ImageBindEmbeddings
from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction
from .colpali import ColPaliEmbeddings

View File

@@ -0,0 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import lru_cache
from typing import List, Union, Optional, Any
import numpy as np
import io
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import TEXT, IMAGES, is_flash_attn_2_available
@register("colpali")
class ColPaliEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the ColPali engine for
multimodal multi-vector embeddings.
This embedding function supports ColQwen2.5 models, producing multivector outputs
for both text and image inputs. The output embeddings are lists of vectors, each
vector being 128-dimensional by default, represented as List[List[float]].
Parameters
----------
model_name : str
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
device : str
The device for inference (default "cuda:0").
dtype : str
Data type for model weights (default "bfloat16").
use_token_pooling : bool
Whether to use token pooling to reduce embedding size (default True).
pool_factor : int
Factor to reduce sequence length if token pooling is enabled (default 2).
quantization_config : Optional[BitsAndBytesConfig]
Quantization configuration for the model. (default None, bitsandbytes needed)
batch_size : int
Batch size for processing inputs (default 2).
"""
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
device: str = "auto"
dtype: str = "bfloat16"
use_token_pooling: bool = True
pool_factor: int = 2
quantization_config: Optional[Any] = None
batch_size: int = 2
_model = None
_processor = None
_token_pooler = None
_vector_dim = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
(
self._model,
self._processor,
self._token_pooler,
) = self._load_model(
self.model_name,
self.dtype,
self.device,
self.use_token_pooling,
self.quantization_config,
)
@staticmethod
@lru_cache(maxsize=1)
def _load_model(
model_name: str,
dtype: str,
device: str,
use_token_pooling: bool,
quantization_config: Optional[Any],
):
"""
Initialize and cache the ColPali model, processor, and token pooler.
"""
torch = attempt_import_or_raise("torch", "torch")
transformers = attempt_import_or_raise("transformers", "transformers")
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
if quantization_config is not None:
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
raise ValueError("quantization_config must be a BitsAndBytesConfig")
if dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif dtype == "float16":
torch_dtype = torch.float16
elif dtype == "float64":
torch_dtype = torch.float64
else:
torch_dtype = torch.float32
model = colpali_engine.models.ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
quantization_config=quantization_config
if quantization_config is not None
else None,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
).eval()
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
model_name
)
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
return model, processor, token_pooler
def ndims(self):
"""
Return the dimension of a vector in the multivector output (e.g., 128).
"""
torch = attempt_import_or_raise("torch", "torch")
if self._vector_dim is None:
dummy_query = "test"
batch_queries = self._processor.process_queries([dummy_query]).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
if self.use_token_pooling and self._token_pooler is not None:
query_embeddings = self._token_pooler.pool_embeddings(
query_embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
self._vector_dim = query_embeddings[0].shape[-1]
return self._vector_dim
def _process_embeddings(self, embeddings):
"""
Format model embeddings into List[List[float]].
Use token pooling if enabled.
"""
torch = attempt_import_or_raise("torch", "torch")
if self.use_token_pooling and self._token_pooler is not None:
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if isinstance(embeddings, torch.Tensor):
tensors = embeddings.detach().cpu()
if tensors.dtype == torch.bfloat16:
tensors = tensors.to(torch.float32)
return (
tensors.numpy()
.astype(np.float64 if self.dtype == "float64" else np.float32)
.tolist()
)
return []
def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
"""
Generate embeddings for text input.
"""
torch = attempt_import_or_raise("torch", "torch")
text = self.sanitize_input(text)
all_embeddings = []
for i in range(0, len(text), self.batch_size):
batch_text = text[i : i + self.batch_size]
batch_queries = self._processor.process_queries(batch_text).to(
self._model.device
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
all_embeddings.extend(self._process_embeddings(query_embeddings))
return all_embeddings
def _prepare_images(self, images: IMAGES) -> List:
"""
Convert image inputs to PIL Images.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
requests = attempt_import_or_raise("requests", "requests")
images = self.sanitize_input(images)
pil_images = []
try:
for image in images:
if isinstance(image, str):
if image.startswith(("http://", "https://")):
response = requests.get(image, timeout=10)
response.raise_for_status()
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
else:
with PIL.Image.open(image) as im:
pil_images.append(im.copy())
elif isinstance(image, bytes):
pil_images.append(PIL.Image.open(io.BytesIO(image)))
else:
# Assume it's a PIL Image; will raise if invalid
pil_images.append(image)
except Exception as e:
raise ValueError(f"Failed to process image: {e}")
return pil_images
def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
"""
Generate embeddings for a batch of images.
"""
torch = attempt_import_or_raise("torch", "torch")
pil_images = self._prepare_images(images)
all_embeddings = []
for i in range(0, len(pil_images), self.batch_size):
batch_images = pil_images[i : i + self.batch_size]
batch_images = self._processor.process_images(batch_images).to(
self._model.device
)
with torch.no_grad():
image_embeddings = self._model(**batch_images)
all_embeddings.extend(self._process_embeddings(image_embeddings))
return all_embeddings
def compute_query_embeddings(
self, query: Union[str, IMAGES], *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a single user query (text only).
"""
if not isinstance(query, str):
raise ValueError(
"Query must be a string, image to image search is not supported"
)
return self.generate_text_embeddings([query])
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[List[List[float]]]:
"""
Compute embeddings for a batch of source images.
Parameters
----------
images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
Batch of images (paths, URLs, bytes, or PIL Images).
"""
images = self.sanitize_input(images)
return self.generate_image_embeddings(images)

View File

@@ -18,6 +18,7 @@ import numpy as np
import pyarrow as pa
from ..dependencies import pandas as pd
from ..util import attempt_import_or_raise
# ruff: noqa: PERF203
@@ -275,3 +276,12 @@ def url_retrieve(url: str):
def api_key_not_found_help(provider):
logging.error("Could not find API key for %s", provider)
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
def is_flash_attn_2_available():
try:
attempt_import_or_raise("flash_attn", "flash_attn")
return True
except ImportError:
return False

View File

@@ -152,6 +152,104 @@ def Vector(
return FixedSizeList
def MultiVector(
dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
) -> Type:
"""Pydantic MultiVector Type for multi-vector embeddings.
This type represents a list of vectors, each with the same dimension.
Useful for models that produce multiple embeddings per input, like ColPali.
Parameters
----------
dim : int
The dimension of each vector in the multi-vector.
value_type : pyarrow.DataType, optional
The value type of the vectors, by default pa.float32()
nullable : bool, optional
Whether the multi-vector is nullable, by default it is True.
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import MultiVector
...
>>> class MyModel(pydantic.BaseModel):
... id: int
... text: str
... embeddings: MultiVector(128) # List of 128-dimensional vectors
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("text", pa.utf8(), False),
... pa.field("embeddings", pa.list_(pa.list_(pa.float32(), 128)))
... ])
"""
class MultiVectorList(list, FixedSizeListMixin):
def __repr__(self):
return f"MultiVector(dim={dim})"
@staticmethod
def nullable() -> bool:
return nullable
@staticmethod
def dim() -> int:
return dim
@staticmethod
def value_arrow_type() -> pa.DataType:
return value_type
@staticmethod
def is_multi_vector() -> bool:
return True
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.list_schema(
items_schema=core_schema.list_schema(
min_length=dim,
max_length=dim,
items_schema=core_schema.float_schema(),
),
),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, (list, range)):
raise TypeError("A list of vectors is needed")
for vec in v:
if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
raise TypeError(f"Each vector must be a list of {dim} numbers")
return cls(v)
if PYDANTIC_VERSION.major < 2:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["items"] = {
"type": "array",
"items": {"type": "number"},
"minItems": dim,
"maxItems": dim,
}
return MultiVectorList
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
"""Convert a field with native Python type to Arrow data type.
@@ -206,6 +304,9 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
fields = _pydantic_model_to_fields(tp)
return pa.struct(fields)
if issubclass(tp, FixedSizeListMixin):
if getattr(tp, "is_multi_vector", lambda: False)():
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
# For regular Vector
return pa.list_(tp.value_arrow_type(), tp.dim())
return _py_type_to_arrow_type(tp, field)

View File

@@ -28,6 +28,8 @@ import pyarrow.compute as pc
import pyarrow.fs as pa_fs
import pydantic
from lancedb.pydantic import PYDANTIC_VERSION
from . import __version__
from .arrow import AsyncRecordBatchReader
from .dependencies import pandas as pd
@@ -498,10 +500,14 @@ class Query(pydantic.BaseModel):
)
return query
class Config:
# This tells pydantic to allow custom types (needed for the `vector` query since
# pa.Array wouln't be allowed otherwise)
arbitrary_types_allowed = True
# This tells pydantic to allow custom types (needed for the `vector` query since
# pa.Array wouln't be allowed otherwise)
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
class Config:
arbitrary_types_allowed = True
else:
model_config = {"arbitrary_types_allowed": True}
class LanceQueryBuilder(ABC):
@@ -1586,6 +1592,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._refine_factor = None
self._distance_type = None
self._phrase_query = None
self._lower_bound = None
self._upper_bound = None
def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None):
@@ -1628,47 +1636,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
raise NotImplementedError("to_query_object not yet supported on a hybrid query")
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
)
self._fts_query = LanceFtsQueryBuilder(
self._table, fts_query, fts_columns=self._fts_columns
)
vector_query = self._query_to_vector(
self._table, vector_query, self._vector_column
)
self._vector_query = LanceVectorQueryBuilder(
self._table, vector_query, self._vector_column
)
if self._limit:
self._vector_query.limit(self._limit)
self._fts_query.limit(self._limit)
if self._columns:
self._vector_query.select(self._columns)
self._fts_query.select(self._columns)
if self._where:
self._vector_query.where(self._where, self._postfilter)
self._fts_query.where(self._where, self._postfilter)
if self._with_row_id:
self._vector_query.with_row_id(True)
self._fts_query.with_row_id(True)
if self._phrase_query:
self._fts_query.phrase_query(True)
if self._distance_type:
self._vector_query.metric(self._distance_type)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
self._vector_query.ef(self._ef)
if self._bypass_vector_index:
self._vector_query.bypass_vector_index()
if self._reranker is None:
self._reranker = RRFReranker()
self._create_query_builders()
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(
self._fts_query.with_row_id(True).to_arrow, timeout=timeout
@@ -1991,6 +1959,112 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._bypass_vector_index = True
return self
def explain_plan(self, verbose: Optional[bool] = False) -> str:
"""Return the execution plan for this query.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
>>> query = [100, 100]
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
GlobalLimitExec: skip=0, fetch=10
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
Parameters
----------
verbose : bool, default False
Use a verbose output format.
Returns
-------
plan : str
""" # noqa: E501
self._create_query_builders()
results = ["Vector Search Plan:"]
results.append(
self._table._explain_plan(
self._vector_query.to_query_object(), verbose=verbose
)
)
results.append("FTS Search Plan:")
results.append(
self._table._explain_plan(
self._fts_query.to_query_object(), verbose=verbose
)
)
return "\n".join(results)
def analyze_plan(self):
"""Execute the query and display with runtime metrics.
Returns
-------
plan : str
"""
self._create_query_builders()
results = ["Vector Search Plan:"]
results.append(self._table._analyze_plan(self._vector_query.to_query_object()))
results.append("FTS Search Plan:")
results.append(self._table._analyze_plan(self._fts_query.to_query_object()))
return "\n".join(results)
def _create_query_builders(self):
"""Set up and configure the vector and FTS query builders."""
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
)
self._fts_query = LanceFtsQueryBuilder(
self._table, fts_query, fts_columns=self._fts_columns
)
vector_query = self._query_to_vector(
self._table, vector_query, self._vector_column
)
self._vector_query = LanceVectorQueryBuilder(
self._table, vector_query, self._vector_column
)
# Apply common configurations
if self._limit:
self._vector_query.limit(self._limit)
self._fts_query.limit(self._limit)
if self._columns:
self._vector_query.select(self._columns)
self._fts_query.select(self._columns)
if self._where:
self._vector_query.where(self._where, self._postfilter)
self._fts_query.where(self._where, self._postfilter)
if self._with_row_id:
self._vector_query.with_row_id(True)
self._fts_query.with_row_id(True)
if self._phrase_query:
self._fts_query.phrase_query(True)
if self._distance_type:
self._vector_query.metric(self._distance_type)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
self._vector_query.ef(self._ef)
if self._bypass_vector_index:
self._vector_query.bypass_vector_index()
if self._lower_bound or self._upper_bound:
self._vector_query.distance_range(
lower_bound=self._lower_bound, upper_bound=self._upper_bound
)
if self._reranker is None:
self._reranker = RRFReranker()
class AsyncQueryBase(object):
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):

View File

@@ -18,7 +18,7 @@ from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import AsyncTable, IndexStatistics, Query, Table
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
class RemoteTable(Table):
@@ -54,6 +54,10 @@ class RemoteTable(Table):
"""Get the current version of the table"""
return LOOP.run(self._table.version())
@property
def tags(self) -> Tags:
return Tags(self._table)
@cached_property
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
"""
@@ -81,7 +85,7 @@ class RemoteTable(Table):
"""to_pandas() is not yet supported on LanceDB cloud."""
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
def checkout(self, version: int):
def checkout(self, version: Union[int, str]):
return LOOP.run(self._table.checkout(version))
def checkout_latest(self):
@@ -104,6 +108,7 @@ class RemoteTable(Table):
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
wait_timeout: timedelta = None,
):
"""Creates a scalar index
Parameters
@@ -126,13 +131,18 @@ class RemoteTable(Table):
else:
raise ValueError(f"Unknown index type: {index_type}")
LOOP.run(self._table.create_index(column, config=config, replace=replace))
LOOP.run(
self._table.create_index(
column, config=config, replace=replace, wait_timeout=wait_timeout
)
)
def create_fts_index(
self,
column: str,
*,
replace: bool = False,
wait_timeout: timedelta = None,
with_position: bool = True,
# tokenizer configs:
base_tokenizer: str = "simple",
@@ -153,7 +163,11 @@ class RemoteTable(Table):
remove_stop_words=remove_stop_words,
ascii_folding=ascii_folding,
)
LOOP.run(self._table.create_index(column, config=config, replace=replace))
LOOP.run(
self._table.create_index(
column, config=config, replace=replace, wait_timeout=wait_timeout
)
)
def create_index(
self,
@@ -165,6 +179,7 @@ class RemoteTable(Table):
replace: Optional[bool] = None,
accelerator: Optional[str] = None,
index_type="vector",
wait_timeout: Optional[timedelta] = None,
):
"""Create an index on the table.
Currently, the only parameters that matter are
@@ -236,7 +251,11 @@ class RemoteTable(Table):
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
LOOP.run(self._table.create_index(vector_column_name, config=config))
LOOP.run(
self._table.create_index(
vector_column_name, config=config, wait_timeout=wait_timeout
)
)
def add(
self,
@@ -554,6 +573,14 @@ class RemoteTable(Table):
def drop_index(self, index_name: str):
return LOOP.run(self._table.drop_index(index_name))
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
):
return LOOP.run(self._table.wait_for_index(index_names, timeout))
def stats(self):
return LOOP.run(self._table.stats())
def uses_v2_manifest_paths(self) -> bool:
raise NotImplementedError(
"uses_v2_manifest_paths() is not supported on the LanceDB Cloud"

View File

@@ -77,6 +77,7 @@ if TYPE_CHECKING:
OptimizeStats,
CleanupStats,
CompactionStats,
Tag,
)
from .db import LanceDBConnection
from .index import IndexConfig
@@ -582,6 +583,35 @@ class Table(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def tags(self) -> Tags:
"""Tag management for the table.
Similar to Git, tags are a way to add metadata to a specific version of the
table.
.. warning::
Tagged versions are exempted from the :py:meth:`cleanup_old_versions()`
process.
To remove a version that has been tagged, you must first
:py:meth:`~Tags.delete` the associated tag.
Examples
--------
.. code-block:: python
table = db.open_table("my_table")
table.tags.create("v2-prod-20250203", 10)
tags = table.tags.list()
"""
raise NotImplementedError
@property
@abstractmethod
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
@@ -631,6 +661,7 @@ class Table(ABC):
index_cache_size: Optional[int] = None,
*,
index_type: VectorIndexType = "IVF_PQ",
wait_timeout: Optional[timedelta] = None,
num_bits: int = 8,
max_iterations: int = 50,
sample_rate: int = 256,
@@ -666,6 +697,8 @@ class Table(ABC):
num_bits: int
The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
Only 4 and 8 are supported.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
"""
raise NotImplementedError
@@ -689,6 +722,30 @@ class Table(ABC):
"""
raise NotImplementedError
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
"""
Wait for indexing to complete for the given index names.
This will poll the table until all the indices are fully indexed,
or raise a timeout exception if the timeout is reached.
Parameters
----------
index_names: str
The name of the indices to poll
timeout: timedelta
Timeout to wait for asynchronous indexing. The default is 5 minutes.
"""
raise NotImplementedError
@abstractmethod
def stats(self) -> TableStatistics:
"""
Retrieve table and fragment statistics.
"""
raise NotImplementedError
@abstractmethod
def create_scalar_index(
self,
@@ -696,6 +753,7 @@ class Table(ABC):
*,
replace: bool = True,
index_type: ScalarIndexType = "BTREE",
wait_timeout: Optional[timedelta] = None,
):
"""Create a scalar index on a column.
@@ -708,7 +766,8 @@ class Table(ABC):
Replace the existing index if it exists.
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE"
The type of index to create.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
Examples
--------
@@ -767,6 +826,7 @@ class Table(ABC):
stem: bool = False,
remove_stop_words: bool = False,
ascii_folding: bool = False,
wait_timeout: Optional[timedelta] = None,
):
"""Create a full-text search index on the table.
@@ -822,6 +882,8 @@ class Table(ABC):
ascii_folding : bool, default False
Whether to fold ASCII characters. This converts accented characters to
their ASCII equivalent. For example, "café" would be converted to "cafe".
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
"""
raise NotImplementedError
@@ -900,10 +962,12 @@ class Table(ABC):
>>> table = db.create_table("my_table", data)
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
>>> # Perform a "upsert" operation
>>> table.merge_insert("a") \\
>>> stats = table.merge_insert("a") \\
... .when_matched_update_all() \\
... .when_not_matched_insert_all() \\
... .execute(new_data)
>>> stats
{'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
>>> # The order of new rows is non-deterministic since we use
>>> # a hash-join as part of this operation and so we sort here
>>> table.to_arrow().sort_by("a").to_pandas()
@@ -1329,7 +1393,7 @@ class Table(ABC):
"""
@abstractmethod
def checkout(self, version: int):
def checkout(self, version: Union[int, str]):
"""
Checks out a specific version of the Table
@@ -1344,6 +1408,12 @@ class Table(ABC):
Any operation that modifies the table will fail while the table is in a checked
out state.
Parameters
----------
version: int | str,
The version to check out. A version number (`int`) or a tag
(`str`) can be provided.
To return the table to a normal state use `[Self::checkout_latest]`
"""
@@ -1513,7 +1583,45 @@ class LanceTable(Table):
"""Get the current version of the table"""
return LOOP.run(self._table.version())
def checkout(self, version: int):
@property
def tags(self) -> Tags:
"""Tag management for the table.
Similar to Git, tags are a way to add metadata to a specific version of the
table.
.. warning::
Tagged versions are exempted from the :py:meth:`cleanup_old_versions()`
process.
To remove a version that has been tagged, you must first
:py:meth:`~Tags.delete` the associated tag.
Returns
-------
Tags
The tag manager for managing tags for the table.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table",
... [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.tags.create("v1", table.version)
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> tags = table.tags.list()
>>> print(tags["v1"]["version"])
1
>>> table.checkout("v1")
>>> table.to_pandas()
vector type
0 [1.1, 0.9] vector
"""
return Tags(self._table)
def checkout(self, version: Union[int, str]):
"""Checkout a version of the table. This is an in-place operation.
This allows viewing previous versions of the table. If you wish to
@@ -1525,8 +1633,9 @@ class LanceTable(Table):
Parameters
----------
version : int
The version to checkout.
version: int | str,
The version to check out. A version number (`int`) or a tag
(`str`) can be provided.
Examples
--------
@@ -1771,6 +1880,14 @@ class LanceTable(Table):
"""
return LOOP.run(self._table.prewarm_index(name))
def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
return LOOP.run(self._table.wait_for_index(index_names, timeout))
def stats(self) -> TableStatistics:
return LOOP.run(self._table.stats())
def create_scalar_index(
self,
column: str,
@@ -2374,7 +2491,9 @@ class LanceTable(Table):
on_bad_vectors: OnBadVectorsType,
fill_value: float,
):
LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
return LOOP.run(
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
@deprecation.deprecated(
deprecated_in="0.21.0",
@@ -2964,6 +3083,7 @@ class AsyncTable:
config: Optional[
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
wait_timeout: Optional[timedelta] = None,
):
"""Create an index to speed up queries
@@ -2988,6 +3108,8 @@ class AsyncTable:
For advanced configuration you can specify the type of index you would
like to create. You can also specify index-specific parameters when
creating an index object.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
"""
if config is not None:
if not isinstance(
@@ -2998,7 +3120,9 @@ class AsyncTable:
" Bitmap, LabelList, or FTS"
)
try:
await self._inner.create_index(column, index=config, replace=replace)
await self._inner.create_index(
column, index=config, replace=replace, wait_timeout=wait_timeout
)
except ValueError as e:
if "not support the requested language" in str(e):
supported_langs = ", ".join(lang_mapping.values())
@@ -3043,6 +3167,29 @@ class AsyncTable:
"""
await self._inner.prewarm_index(name)
async def wait_for_index(
self, index_names: Iterable[str], timeout: timedelta = timedelta(seconds=300)
) -> None:
"""
Wait for indexing to complete for the given index names.
This will poll the table until all the indices are fully indexed,
or raise a timeout exception if the timeout is reached.
Parameters
----------
index_names: str
The name of the indices to poll
timeout: timedelta
Timeout to wait for asynchronous indexing. The default is 5 minutes.
"""
await self._inner.wait_for_index(index_names, timeout)
async def stats(self) -> TableStatistics:
"""
Retrieve table and fragment statistics.
"""
return await self._inner.stats()
async def add(
self,
data: DATA,
@@ -3134,10 +3281,12 @@ class AsyncTable:
>>> table = db.create_table("my_table", data)
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
>>> # Perform a "upsert" operation
>>> table.merge_insert("a") \\
>>> stats = table.merge_insert("a") \\
... .when_matched_update_all() \\
... .when_not_matched_insert_all() \\
... .execute(new_data)
>>> stats
{'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
>>> # The order of new rows is non-deterministic since we use
>>> # a hash-join as part of this operation and so we sort here
>>> table.to_arrow().sort_by("a").to_pandas()
@@ -3493,7 +3642,7 @@ class AsyncTable:
)
if isinstance(data, pa.Table):
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
await self._inner.execute_merge_insert(
return await self._inner.execute_merge_insert(
data,
dict(
on=merge._on,
@@ -3694,7 +3843,7 @@ class AsyncTable:
return versions
async def checkout(self, version: int):
async def checkout(self, version: int | str):
"""
Checks out a specific version of the Table
@@ -3709,6 +3858,12 @@ class AsyncTable:
Any operation that modifies the table will fail while the table is in a checked
out state.
Parameters
----------
version: int | str,
The version to check out. A version number (`int`) or a tag
(`str`) can be provided.
To return the table to a normal state use `[Self::checkout_latest]`
"""
try:
@@ -3746,6 +3901,24 @@ class AsyncTable:
"""
await self._inner.restore(version)
@property
def tags(self) -> AsyncTags:
"""Tag management for the dataset.
Similar to Git, tags are a way to add metadata to a specific version of the
dataset.
.. warning::
Tagged versions are exempted from the
:py:meth:`optimize(cleanup_older_than)` process.
To remove a version that has been tagged, you must first
:py:meth:`~Tags.delete` the associated tag.
"""
return AsyncTags(self._inner)
async def optimize(
self,
*,
@@ -3915,3 +4088,217 @@ class IndexStatistics:
# a dictionary instead of a class.
def __getitem__(self, key):
return getattr(self, key)
@dataclass
class TableStatistics:
"""
Statistics about a table and fragments.
Attributes
----------
total_bytes: int
The total number of bytes in the table.
num_rows: int
The total number of rows in the table.
num_indices: int
The total number of indices in the table.
fragment_stats: FragmentStatistics
Statistics about fragments in the table.
"""
total_bytes: int
num_rows: int
num_indices: int
fragment_stats: FragmentStatistics
@dataclass
class FragmentStatistics:
"""
Statistics about fragments.
Attributes
----------
num_fragments: int
The total number of fragments in the table.
num_small_fragments: int
The total number of small fragments in the table.
Small fragments have low row counts and may need to be compacted.
lengths: FragmentSummaryStats
Statistics about the number of rows in the table fragments.
"""
num_fragments: int
num_small_fragments: int
lengths: FragmentSummaryStats
@dataclass
class FragmentSummaryStats:
"""
Statistics about fragments sizes
Attributes
----------
min: int
The number of rows in the fragment with the fewest rows.
max: int
The number of rows in the fragment with the most rows.
mean: int
The mean number of rows in the fragments.
p25: int
The 25th percentile of number of rows in the fragments.
p50: int
The 50th percentile of number of rows in the fragments.
p75: int
The 75th percentile of number of rows in the fragments.
p99: int
The 99th percentile of number of rows in the fragments.
"""
min: int
max: int
mean: int
p25: int
p50: int
p75: int
p99: int
class Tags:
"""
Table tag manager.
"""
def __init__(self, table):
self._table = table
def list(self) -> Dict[str, Tag]:
"""
List all table tags.
Returns
-------
dict[str, Tag]
A dictionary mapping tag names to version numbers.
"""
return LOOP.run(self._table.tags.list())
def get_version(self, tag: str) -> int:
"""
Get the version of a tag.
Parameters
----------
tag: str,
The name of the tag to get the version for.
"""
return LOOP.run(self._table.tags.get_version(tag))
def create(self, tag: str, version: int) -> None:
"""
Create a tag for a given table version.
Parameters
----------
tag: str,
The name of the tag to create. This name must be unique among all tag
names for the table.
version: int,
The table version to tag.
"""
LOOP.run(self._table.tags.create(tag, version))
def delete(self, tag: str) -> None:
"""
Delete tag from the table.
Parameters
----------
tag: str,
The name of the tag to delete.
"""
LOOP.run(self._table.tags.delete(tag))
def update(self, tag: str, version: int) -> None:
"""
Update tag to a new version.
Parameters
----------
tag: str,
The name of the tag to update.
version: int,
The new table version to tag.
"""
LOOP.run(self._table.tags.update(tag, version))
class AsyncTags:
"""
Async table tag manager.
"""
def __init__(self, table):
self._table = table
async def list(self) -> Dict[str, Tag]:
"""
List all table tags.
Returns
-------
dict[str, Tag]
A dictionary mapping tag names to version numbers.
"""
return await self._table.tags.list()
async def get_version(self, tag: str) -> int:
"""
Get the version of a tag.
Parameters
----------
tag: str,
The name of the tag to get the version for.
"""
return await self._table.tags.get_version(tag)
async def create(self, tag: str, version: int) -> None:
"""
Create a tag for a given table version.
Parameters
----------
tag: str,
The name of the tag to create. This name must be unique among all tag
names for the table.
version: int,
The table version to tag.
"""
await self._table.tags.create(tag, version)
async def delete(self, tag: str) -> None:
"""
Delete tag from the table.
Parameters
----------
tag: str,
The name of the tag to delete.
"""
await self._table.tags.delete(tag)
async def update(self, tag: str, version: int) -> None:
"""
Update tag to a new version.
Parameters
----------
tag: str,
The name of the tag to update.
version: int,
The new table version to tag.
"""
await self._table.tags.update(tag, version)

View File

@@ -18,15 +18,19 @@ def test_upsert(mem_db):
{"id": 1, "name": "Bobby"},
{"id": 2, "name": "Charlie"},
]
(
stats = (
table.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(new_users)
)
table.count_rows() # 3
stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0}
# --8<-- [end:upsert_basic]
assert table.count_rows() == 3
assert stats["num_inserted_rows"] == 1
assert stats["num_updated_rows"] == 1
assert stats["num_deleted_rows"] == 0
@pytest.mark.asyncio
@@ -44,15 +48,19 @@ async def test_upsert_async(mem_db_async):
{"id": 1, "name": "Bobby"},
{"id": 2, "name": "Charlie"},
]
await (
stats = await (
table.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(new_users)
)
await table.count_rows() # 3
stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0}
# --8<-- [end:upsert_basic_async]
assert await table.count_rows() == 3
assert stats["num_inserted_rows"] == 1
assert stats["num_updated_rows"] == 1
assert stats["num_deleted_rows"] == 0
def test_insert_if_not_exists(mem_db):
@@ -69,10 +77,16 @@ def test_insert_if_not_exists(mem_db):
{"domain": "google.com", "name": "Google"},
{"domain": "facebook.com", "name": "Facebook"},
]
(table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains))
stats = (
table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)
)
table.count_rows() # 3
stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0}
# --8<-- [end:insert_if_not_exists]
assert table.count_rows() == 3
assert stats["num_inserted_rows"] == 1
assert stats["num_updated_rows"] == 0
assert stats["num_deleted_rows"] == 0
@pytest.mark.asyncio
@@ -90,12 +104,16 @@ async def test_insert_if_not_exists_async(mem_db_async):
{"domain": "google.com", "name": "Google"},
{"domain": "facebook.com", "name": "Facebook"},
]
await (
stats = await (
table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)
)
await table.count_rows() # 3
stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0}
# --8<-- [end:insert_if_not_exists_async]
assert await table.count_rows() == 3
assert stats["num_inserted_rows"] == 1
assert stats["num_updated_rows"] == 0
assert stats["num_deleted_rows"] == 0
def test_replace_range(mem_db):
@@ -113,7 +131,7 @@ def test_replace_range(mem_db):
new_chunks = [
{"doc_id": 1, "chunk_id": 0, "text": "Baz"},
]
(
stats = (
table.merge_insert(["doc_id", "chunk_id"])
.when_matched_update_all()
.when_not_matched_insert_all()
@@ -121,8 +139,12 @@ def test_replace_range(mem_db):
.execute(new_chunks)
)
table.count_rows("doc_id = 1") # 1
stats # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1}
# --8<-- [end:replace_range]
assert table.count_rows("doc_id = 1") == 1
assert stats["num_inserted_rows"] == 0
assert stats["num_updated_rows"] == 1
assert stats["num_deleted_rows"] == 1
@pytest.mark.asyncio
@@ -141,7 +163,7 @@ async def test_replace_range_async(mem_db_async):
new_chunks = [
{"doc_id": 1, "chunk_id": 0, "text": "Baz"},
]
await (
stats = await (
table.merge_insert(["doc_id", "chunk_id"])
.when_matched_update_all()
.when_not_matched_insert_all()
@@ -149,5 +171,9 @@ async def test_replace_range_async(mem_db_async):
.execute(new_chunks)
)
await table.count_rows("doc_id = 1") # 1
stats # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1}
# --8<-- [end:replace_range_async]
assert await table.count_rows("doc_id = 1") == 1
assert stats["num_inserted_rows"] == 0
assert stats["num_updated_rows"] == 1
assert stats["num_deleted_rows"] == 1

View File

@@ -11,7 +11,7 @@ import pandas as pd
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, Vector, MultiVector
import requests
# These are integration tests for embedding functions.
@@ -575,3 +575,67 @@ def test_voyageai_multimodal_embedding_text_function():
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
def test_colpali(tmp_path):
import requests
from lancedb.pydantic import LanceModel
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("colpali").create()
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = (
func.VectorField()
) # Multivector image embeddings
table = db.create_table("media", schema=MediaItems)
texts = [
"a cute cat playing with yarn",
"a puppy in a flower field",
"a red sports car on the highway",
"a vintage bicycle leaning against a wall",
"a plate of delicious pasta",
"fresh fruit salad in a bowl",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# Get images as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
# Test text-to-image search
image_results = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
# Verify multivector dimensions
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
assert len(first_row["image_vectors"][0]) == func.ndims(), (
"Vector dimension mismatch"
)

View File

@@ -4,13 +4,32 @@
import lancedb
from lancedb.query import LanceHybridQueryBuilder
from lancedb.rerankers.rrf import RRFReranker
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.index import FTS
from lancedb.table import AsyncTable
from lancedb.table import AsyncTable, Table
@pytest.fixture
def sync_table(tmpdir_factory) -> Table:
tmp_path = str(tmpdir_factory.mktemp("data"))
db = lancedb.connect(tmp_path)
data = pa.table(
{
"text": pa.array(["a", "b", "cat", "dog"]),
"vector": pa.array(
[[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
type=pa.list_(pa.float32(), list_size=2),
),
}
)
table = db.create_table("test", data)
table.create_fts_index("text", with_position=False, use_tantivy=False)
return table
@pytest_asyncio.fixture
@@ -102,6 +121,42 @@ async def test_async_hybrid_query_default_limit(table: AsyncTable):
assert texts.count("a") == 1
def test_hybrid_query_distance_range(sync_table: Table):
reranker = RRFReranker(return_score="all")
result = (
sync_table.search(query_type="hybrid")
.vector([0.0, 0.4])
.text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
print(result)
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_hybrid_query_distance_range_async(table: AsyncTable):
reranker = RRFReranker(return_score="all")
result = await (
table.query()
.nearest_to([0.0, 0.4])
.nearest_to_text("cat and dog")
.distance_range(lower_bound=0.2, upper_bound=0.5)
.rerank(reranker)
.limit(2)
.to_arrow()
)
assert len(result) == 2
for dist in result["_distance"]:
if dist.is_valid:
assert 0.2 <= dist.as_py() <= 0.5
@pytest.mark.asyncio
async def test_explain_plan(table: AsyncTable):
plan = await (

View File

@@ -9,7 +9,13 @@ from typing import List, Optional, Tuple
import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from lancedb.pydantic import (
PYDANTIC_VERSION,
LanceModel,
Vector,
pydantic_to_schema,
MultiVector,
)
from pydantic import BaseModel
from pydantic import Field
@@ -354,3 +360,55 @@ def test_optional_nested_model():
),
]
)
def test_multi_vector():
class TestModel(pydantic.BaseModel):
vec: MultiVector(8)
schema = pydantic_to_schema(TestModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=[[1.0] * 7])
with pytest.raises(pydantic.ValidationError):
TestModel(vec=[[1.0] * 9])
TestModel(vec=[[1.0] * 8])
TestModel(vec=[[1.0] * 8, [2.0] * 8])
TestModel(vec=[])
def test_multi_vector_nullable():
class NullableModel(pydantic.BaseModel):
vec: MultiVector(16, nullable=False)
schema = pydantic_to_schema(NullableModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
)
class DefaultModel(pydantic.BaseModel):
vec: MultiVector(16)
schema = pydantic_to_schema(DefaultModel)
assert schema == pa.schema(
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
)
def test_multi_vector_in_lance_model():
class TestModel(LanceModel):
id: int
vectors: MultiVector(16) = Field(default=[[0.0] * 16])
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["id", "vectors"]
t = TestModel(id=1)
assert t.vectors == [[0.0] * 16]

View File

@@ -257,7 +257,9 @@ async def test_distance_range_with_new_rows_async():
}
)
table = await conn.create_table("test", data)
table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
await table.create_index(
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)
)
q = [0, 0]
rs = await table.query().nearest_to(q).to_arrow()

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import re
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
@@ -235,6 +235,10 @@ def test_table_add_in_threadpool():
def test_table_create_indices():
def handler(request):
index_stats = dict(
index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
)
if request.path == "/v1/table/test/create_index/":
request.send_response(200)
request.end_headers()
@@ -258,6 +262,47 @@ def test_table_create_indices():
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/list/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
indexes=[
{
"index_name": "id_idx",
"columns": ["id"],
},
{
"index_name": "text_idx",
"columns": ["text"],
},
{
"index_name": "vector_idx",
"columns": ["vector"],
},
]
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/id_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/text_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/vector_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
request.wfile.write(payload.encode())
elif "/drop/" in request.path:
request.send_response(200)
request.end_headers()
@@ -269,14 +314,125 @@ def test_table_create_indices():
# Parameters are well-tested through local and async tests.
# This is a smoke-test.
table = db.create_table("test", [{"id": 1}])
table.create_scalar_index("id")
table.create_fts_index("text")
table.create_scalar_index("vector")
table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
table.create_index(
vector_column_name="vector", wait_timeout=timedelta(seconds=10)
)
table.wait_for_index(["id_idx"], timedelta(seconds=2))
table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
table.drop_index("vector_idx")
table.drop_index("id_idx")
table.drop_index("text_idx")
def test_table_wait_for_index_timeout():
def handler(request):
index_stats = dict(
index_type="BTREE", num_indexed_rows=1000, num_unindexed_rows=1
)
if request.path == "/v1/table/test/create/?mode=create":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
version=1,
schema=dict(
fields=[
dict(name="id", type={"type": "int64"}, nullable=False),
]
),
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/list/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
indexes=[
{
"index_name": "id_idx",
"columns": ["id"],
},
]
)
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/index/id_idx/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(index_stats)
print(f"{index_stats=}")
request.wfile.write(payload.encode())
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
with pytest.raises(
RuntimeError,
match=re.escape(
'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
),
):
table.wait_for_index(["id_idx"], timedelta(seconds=1))
def test_stats():
stats = {
"total_bytes": 38,
"num_rows": 2,
"num_indices": 0,
"fragment_stats": {
"num_fragments": 1,
"num_small_fragments": 1,
"lengths": {
"min": 2,
"max": 2,
"mean": 2,
"p25": 2,
"p50": 2,
"p75": 2,
"p99": 2,
},
},
}
def handler(request):
if request.path == "/v1/table/test/create/?mode=create":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/stats/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(stats)
request.wfile.write(payload.encode())
else:
print(request.path)
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
res = table.stats()
print(f"{res=}")
assert res == stats
@contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
def handler(request):

View File

@@ -9,9 +9,9 @@ from typing import List
from unittest.mock import patch
import lancedb
from lancedb.dependencies import _PANDAS_AVAILABLE
from lancedb.index import HnswPq, HnswSq, IvfPq
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset
@@ -138,13 +138,16 @@ def test_create_table(mem_db: DBConnection):
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
]
df = pd.DataFrame(rows)
pa_table = pa.Table.from_pandas(df, schema=schema)
pa_table = pa.Table.from_pylist(rows, schema=schema)
data = [
("Rows", rows),
("pd_DataFrame", df),
("pa_Table", pa_table),
]
if _PANDAS_AVAILABLE:
import pandas as pd
df = pd.DataFrame(rows)
data.append(("pd_DataFrame", df))
for name, d in data:
tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow()
@@ -296,7 +299,7 @@ def test_add_subschema(mem_db: DBConnection):
data = {"price": 10.0, "item": "foo"}
table.add([data])
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
data = pa.Table.from_pydict({"price": [2.0], "vector": [[3.1, 4.1]]})
table.add(data)
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
table.add([data])
@@ -405,6 +408,7 @@ def test_add_nullability(mem_db: DBConnection):
def test_add_pydantic_model(mem_db: DBConnection):
pytest.importorskip("pandas")
# https://github.com/lancedb/lancedb/issues/562
class Metadata(BaseModel):
@@ -473,10 +477,10 @@ def test_polars(mem_db: DBConnection):
table = mem_db.create_table("test", data=pl.DataFrame(data))
assert len(table) == 2
result = table.to_pandas()
assert np.allclose(result["vector"].tolist(), data["vector"])
assert result["item"].tolist() == data["item"]
assert np.allclose(result["price"].tolist(), data["price"])
result = table.to_arrow()
assert np.allclose(result["vector"].to_pylist(), data["vector"])
assert result["item"].to_pylist() == data["item"]
assert np.allclose(result["price"].to_pylist(), data["price"])
schema = pa.schema(
[
@@ -525,6 +529,113 @@ def test_versioning(mem_db: DBConnection):
assert len(table) == 2
def test_tags(mem_db: DBConnection):
table = mem_db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
table.tags.create("tag1", 1)
tags = table.tags.list()
assert "tag1" in tags
assert tags["tag1"]["version"] == 1
table.add(
data=[
{"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
],
)
table.tags.create("tag2", 2)
tags = table.tags.list()
assert "tag1" in tags
assert "tag2" in tags
assert tags["tag1"]["version"] == 1
assert tags["tag2"]["version"] == 2
table.tags.delete("tag2")
table.tags.update("tag1", 2)
tags = table.tags.list()
assert "tag1" in tags
assert tags["tag1"]["version"] == 2
table.tags.update("tag1", 1)
tags = table.tags.list()
assert "tag1" in tags
assert tags["tag1"]["version"] == 1
table.checkout("tag1")
assert table.version == 1
assert table.count_rows() == 2
table.tags.create("tag2", 2)
table.checkout("tag2")
assert table.version == 2
assert table.count_rows() == 3
table.checkout_latest()
table.add(
data=[
{"vector": [12.0, 13.0], "item": "baz", "price": 40.0},
],
)
@pytest.mark.asyncio
async def test_async_tags(mem_db_async: AsyncConnection):
table = await mem_db_async.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
await table.tags.create("tag1", 1)
tags = await table.tags.list()
assert "tag1" in tags
assert tags["tag1"]["version"] == 1
await table.add(
data=[
{"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
],
)
await table.tags.create("tag2", 2)
tags = await table.tags.list()
assert "tag1" in tags
assert "tag2" in tags
assert tags["tag1"]["version"] == 1
assert tags["tag2"]["version"] == 2
await table.tags.delete("tag2")
await table.tags.update("tag1", 2)
tags = await table.tags.list()
assert "tag1" in tags
assert tags["tag1"]["version"] == 2
await table.tags.update("tag1", 1)
tags = await table.tags.list()
assert "tag1" in tags
assert tags["tag1"]["version"] == 1
await table.checkout("tag1")
assert await table.version() == 1
assert await table.count_rows() == 2
await table.tags.create("tag2", 2)
await table.checkout("tag2")
assert await table.version() == 2
assert await table.count_rows() == 3
await table.checkout_latest()
await table.add(
data=[
{"vector": [12.0, 13.0], "item": "baz", "price": 40.0},
],
)
@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_method(mock_create_index, mem_db: DBConnection):
table = mem_db.create_table(
@@ -688,7 +799,7 @@ def test_delete(mem_db: DBConnection):
assert len(table.list_versions()) == 2
assert table.version == 2
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
assert table.to_arrow()["id"].to_pylist() == [1]
def test_update(mem_db: DBConnection):
@@ -852,6 +963,7 @@ def test_merge_insert(mem_db: DBConnection):
ids=["pa.Table", "pd.DataFrame", "rows"],
)
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
pytest.importorskip("pandas")
initial_data = pa.table(
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
)
@@ -948,7 +1060,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
func = MockTextEmbeddingFunction.create()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
df = pa.table({"text": texts, "vector": func.compute_source_embeddings(texts)})
conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func
@@ -973,7 +1085,7 @@ def test_create_f16_table(mem_db: DBConnection):
text: str
vector: Vector(32, value_type=pa.float16())
df = pd.DataFrame(
df = pa.table(
{
"text": [f"s-{i}" for i in range(512)],
"vector": [np.random.randn(32).astype(np.float16) for _ in range(512)],
@@ -986,7 +1098,7 @@ def test_create_f16_table(mem_db: DBConnection):
table.add(df)
table.create_index(num_partitions=2, num_sub_vectors=2)
query = df.vector.iloc[2]
query = df["vector"][2].as_py()
expected = table.search(query).limit(2).to_arrow()
assert "s-2" in expected["text"].to_pylist()
@@ -1002,7 +1114,7 @@ def test_add_with_embedding_function(mem_db: DBConnection):
table = mem_db.create_table("my_table", schema=MyTable)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
df = pa.table({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
@@ -1033,14 +1145,14 @@ def test_multiple_vector_columns(mem_db: DBConnection):
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
df = pa.Table.from_pylist(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
result1 = table.search(q, vector_column_name="vector1").limit(1).to_arrow()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_arrow()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
assert result1["text"][0] != result2["text"][0]
def test_create_scalar_index(mem_db: DBConnection):
@@ -1078,22 +1190,22 @@ def test_empty_query(mem_db: DBConnection):
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
val = df.id.iloc[0]
df = table.search().select(["id"]).where("text='bar'").limit(1).to_arrow()
val = df["id"][0].as_py()
assert val == 1
table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
df = table.search().select(["id"]).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).to_arrow()
assert df.num_rows == 100
# None is the same as default
df = table.search().select(["id"]).limit(None).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).limit(None).to_arrow()
assert df.num_rows == 100
# invalid limist is the same as None, wihch is the same as default
df = table.search().select(["id"]).limit(-1).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).limit(-1).to_arrow()
assert df.num_rows == 100
# valid limit should work
df = table.search().select(["id"]).limit(42).to_pandas()
assert len(df) == 42
df = table.search().select(["id"]).limit(42).to_arrow()
assert df.num_rows == 42
def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
@@ -1112,14 +1224,14 @@ def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
{"vector_col": v1, "text": "foo"},
{"vector_col": v2, "text": "bar"},
]
df = pd.DataFrame(data)
df = pa.Table.from_pylist(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
result2 = table.search(q).limit(1).to_pandas()
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_arrow()
result2 = table.search(q).limit(1).to_arrow()
assert result1["text"].iloc[0] == result2["text"].iloc[0]
assert result1["text"][0].as_py() == result2["text"][0].as_py()
def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
@@ -1139,12 +1251,12 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
df = pa.Table.from_pylist(data)
table.add(df)
q = np.random.randn(10)
with pytest.raises(ValueError):
table.search(q).limit(1).to_pandas()
table.search(q).limit(1).to_arrow()
def test_compact_cleanup(tmp_db: DBConnection):
@@ -1583,3 +1695,31 @@ def test_replace_field_metadata(tmp_path):
schema = table.schema
field = schema[0].metadata
assert field == {b"foo": b"bar"}
def test_stats(mem_db: DBConnection):
table = mem_db.create_table(
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
assert len(table) == 2
stats = table.stats()
print(f"{stats=}")
assert stats == {
"total_bytes": 38,
"num_rows": 2,
"num_indices": 0,
"fragment_stats": {
"num_fragments": 1,
"num_small_fragments": 1,
"lengths": {
"min": 2,
"max": 2,
"mean": 2,
"p25": 2,
"p50": 2,
"p75": 2,
"p99": 2,
},
},
}

View File

@@ -652,6 +652,11 @@ impl HybridQuery {
self.inner_vec.bypass_vector_index();
}
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner_vec.distance_range(lower_bound, upper_bound);
}
pub fn to_vector_query(&mut self) -> PyResult<VectorQuery> {
Ok(VectorQuery {
inner: self.inner_vec.inner.clone(),

View File

@@ -2,6 +2,11 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{collections::HashMap, sync::Arc};
use crate::{
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::Query,
};
use arrow::{
datatypes::{DataType, Schema},
ffi_stream::ArrowArrayStreamReader,
@@ -12,19 +17,13 @@ use lancedb::table::{
Table as LanceDbTable,
};
use pyo3::{
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
exceptions::{PyIOError, PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods, PyInt, PyString},
Bound, FromPyObject, PyAny, PyObject, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
use crate::{
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::Query,
};
/// Statistics about a compaction operation.
#[pyclass(get_all)]
#[derive(Clone, Debug)]
@@ -177,15 +176,19 @@ impl Table {
})
}
#[pyo3(signature = (column, index=None, replace=None))]
#[pyo3(signature = (column, index=None, replace=None, wait_timeout=None))]
pub fn create_index<'a>(
self_: PyRef<'a, Self>,
column: String,
index: Option<Bound<'_, PyAny>>,
replace: Option<bool>,
wait_timeout: Option<Bound<'_, PyAny>>,
) -> PyResult<Bound<'a, PyAny>> {
let index = extract_index_params(&index)?;
let mut op = self_.inner_ref()?.create_index(&[column], index);
let timeout = wait_timeout.map(|t| t.extract::<std::time::Duration>().unwrap());
let mut op = self_
.inner_ref()?
.create_index_with_timeout(&[column], index, timeout);
if let Some(replace) = replace {
op = op.replace(replace);
}
@@ -204,6 +207,26 @@ impl Table {
})
}
pub fn wait_for_index<'a>(
self_: PyRef<'a, Self>,
index_names: Vec<String>,
timeout: Bound<'_, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
let inner = self_.inner_ref()?.clone();
let timeout = timeout.extract::<std::time::Duration>()?;
future_into_py(self_.py(), async move {
let index_refs = index_names
.iter()
.map(String::as_str)
.collect::<Vec<&str>>();
inner
.wait_for_index(&index_refs, timeout)
.await
.infer_error()?;
Ok(())
})
}
pub fn prewarm_index(self_: PyRef<'_, Self>, index_name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
@@ -256,6 +279,40 @@ impl Table {
})
}
pub fn stats(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let stats = inner.stats().await.infer_error()?;
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("total_bytes", stats.total_bytes)?;
dict.set_item("num_rows", stats.num_rows)?;
dict.set_item("num_indices", stats.num_indices)?;
let fragment_stats = PyDict::new(py);
fragment_stats.set_item("num_fragments", stats.fragment_stats.num_fragments)?;
fragment_stats.set_item(
"num_small_fragments",
stats.fragment_stats.num_small_fragments,
)?;
let fragment_lengths = PyDict::new(py);
fragment_lengths.set_item("min", stats.fragment_stats.lengths.min)?;
fragment_lengths.set_item("max", stats.fragment_stats.lengths.max)?;
fragment_lengths.set_item("mean", stats.fragment_stats.lengths.mean)?;
fragment_lengths.set_item("p25", stats.fragment_stats.lengths.p25)?;
fragment_lengths.set_item("p50", stats.fragment_stats.lengths.p50)?;
fragment_lengths.set_item("p75", stats.fragment_stats.lengths.p75)?;
fragment_lengths.set_item("p99", stats.fragment_stats.lengths.p99)?;
fragment_stats.set_item("lengths", fragment_lengths)?;
dict.set_item("fragment_stats", fragment_stats)?;
Ok(Some(dict.unbind()))
})
})
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),
@@ -298,10 +355,26 @@ impl Table {
})
}
pub fn checkout(self_: PyRef<'_, Self>, version: u64) -> PyResult<Bound<'_, PyAny>> {
pub fn checkout(self_: PyRef<'_, Self>, version: PyObject) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.checkout(version).await.infer_error()
let py = self_.py();
let (is_int, int_value, string_value) = if let Ok(i) = version.downcast_bound::<PyInt>(py) {
let num: u64 = i.extract()?;
(true, num, String::new())
} else if let Ok(s) = version.downcast_bound::<PyString>(py) {
let str_value = s.to_string();
(false, 0, str_value)
} else {
return Err(PyIOError::new_err(
"version must be an integer or a string.",
));
};
future_into_py(py, async move {
if is_int {
inner.checkout(int_value).await.infer_error()
} else {
inner.checkout_tag(&string_value).await.infer_error()
}
})
}
@@ -328,6 +401,11 @@ impl Table {
Query::new(self.inner_ref().unwrap().query())
}
#[getter]
pub fn tags(&self) -> PyResult<Tags> {
Ok(Tags::new(self.inner_ref()?.clone()))
}
/// Optimize the on-disk data by compacting and pruning old data, for better performance.
#[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None, retrain=None))]
pub fn optimize(
@@ -411,8 +489,14 @@ impl Table {
}
future_into_py(self_.py(), async move {
builder.execute(Box::new(batches)).await.infer_error()?;
Ok(())
let stats = builder.execute(Box::new(batches)).await.infer_error()?;
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
dict.set_item("num_updated_rows", stats.num_updated_rows)?;
dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
Ok(dict.unbind())
})
})
}
@@ -562,3 +646,72 @@ pub struct MergeInsertParams {
when_not_matched_by_source_delete: bool,
when_not_matched_by_source_condition: Option<String>,
}
#[pyclass]
pub struct Tags {
inner: LanceDbTable,
}
impl Tags {
pub fn new(table: LanceDbTable) -> Self {
Self { inner: table }
}
}
#[pymethods]
impl Tags {
pub fn list(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let tags = inner.tags().await.infer_error()?;
let res = tags.list().await.infer_error()?;
Python::with_gil(|py| {
let py_dict = PyDict::new(py);
for (key, contents) in res {
let value_dict = PyDict::new(py);
value_dict.set_item("version", contents.version)?;
value_dict.set_item("manifest_size", contents.manifest_size)?;
py_dict.set_item(key, value_dict)?;
}
Ok(py_dict.unbind())
})
})
}
pub fn get_version(self_: PyRef<'_, Self>, tag: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let tags = inner.tags().await.infer_error()?;
let res = tags.get_version(tag.as_str()).await.infer_error()?;
Ok(res)
})
}
pub fn create(self_: PyRef<Self>, tag: String, version: u64) -> PyResult<Bound<PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let mut tags = inner.tags().await.infer_error()?;
tags.create(tag.as_str(), version).await.infer_error()?;
Ok(())
})
}
pub fn delete(self_: PyRef<Self>, tag: String) -> PyResult<Bound<PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let mut tags = inner.tags().await.infer_error()?;
tags.delete(tag.as_str()).await.infer_error()?;
Ok(())
})
}
pub fn update(self_: PyRef<Self>, tag: String, version: u64) -> PyResult<Bound<PyAny>> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let mut tags = inner.tags().await.infer_error()?;
tags.update(tag.as_str(), version).await.infer_error()?;
Ok(())
})
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.19.0-beta.7"
version = "0.19.1-beta.1"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.19.0-beta.7"
version = "0.19.1-beta.1"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -81,7 +81,7 @@ impl ListingCatalogOptionsBuilder {
/// [`crate::database::listing::ListingDatabase`]
#[derive(Debug)]
pub struct ListingCatalog {
object_store: ObjectStore,
object_store: Arc<ObjectStore>,
uri: String,
@@ -105,7 +105,7 @@ impl ListingCatalog {
}
async fn open_path(path: &str) -> Result<Self> {
let (object_store, base_path) = ObjectStore::from_path(path).unwrap();
let (object_store, base_path) = ObjectStore::from_uri(path).await.unwrap();
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
}

View File

@@ -201,7 +201,7 @@ impl ListingDatabaseOptionsBuilder {
/// We will have two tables named `table1` and `table2`.
#[derive(Debug)]
pub struct ListingDatabase {
object_store: ObjectStore,
object_store: Arc<ObjectStore>,
query_string: Option<String>,
pub(crate) uri: String,

View File

@@ -35,6 +35,8 @@ pub enum Error {
Schema { message: String },
#[snafu(display("Runtime error: {message}"))]
Runtime { message: String },
#[snafu(display("Timeout error: {message}"))]
Timeout { message: String },
// 3rd party / external errors
#[snafu(display("object_store error: {source}"))]

View File

@@ -1,11 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::Arc;
use scalar::FtsIndexBuilder;
use serde::Deserialize;
use serde_with::skip_serializing_none;
use std::sync::Arc;
use std::time::Duration;
use vector::IvfFlatIndexBuilder;
use crate::{table::BaseTable, DistanceType, Error, Result};
@@ -17,6 +17,7 @@ use self::{
pub mod scalar;
pub mod vector;
pub mod waiter;
/// Supported index types.
#[derive(Debug, Clone)]
@@ -69,6 +70,7 @@ pub struct IndexBuilder {
pub(crate) index: Index,
pub(crate) columns: Vec<String>,
pub(crate) replace: bool,
pub(crate) wait_timeout: Option<Duration>,
}
impl IndexBuilder {
@@ -78,6 +80,7 @@ impl IndexBuilder {
index,
columns,
replace: true,
wait_timeout: None,
}
}
@@ -91,6 +94,15 @@ impl IndexBuilder {
self
}
/// Duration of time to wait for asynchronous indexing to complete. If not set,
/// `create_index()` will not wait.
///
/// This is not supported for `NativeTable` since indexing is synchronous.
pub fn wait_timeout(mut self, d: Duration) -> Self {
self.wait_timeout = Some(d);
self
}
pub async fn execute(self) -> Result<()> {
self.parent.clone().create_index(self).await
}

View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::error::Result;
use crate::table::BaseTable;
use crate::Error;
use log::debug;
use std::time::{Duration, Instant};
use tokio::time::sleep;
const DEFAULT_SLEEP_MS: u64 = 1000;
const MAX_WAIT: Duration = Duration::from_secs(2 * 60 * 60);
/// Poll the table using list_indices() and index_stats() until all of the indices have 0 un-indexed rows.
/// Will return Error::Timeout if the columns are not fully indexed within the timeout.
pub async fn wait_for_index(
table: &dyn BaseTable,
index_names: &[&str],
timeout: Duration,
) -> Result<()> {
if timeout > MAX_WAIT {
return Err(Error::InvalidInput {
message: format!("timeout must be less than {:?}", MAX_WAIT),
});
}
let start = Instant::now();
let mut remaining = index_names.to_vec();
// poll via list_indices() and index_stats() until all indices are created and fully indexed
while start.elapsed() < timeout {
let mut completed = vec![];
let indices = table.list_indices().await?;
for &idx in &remaining {
if !indices.iter().any(|i| i.name == *idx) {
debug!("still waiting for new index '{}'", idx);
continue;
}
let stats = table.index_stats(idx.as_ref()).await?;
match stats {
None => {
debug!("still waiting for new index '{}'", idx);
continue;
}
Some(s) => {
if s.num_unindexed_rows == 0 {
// note: this may never stabilize under constant writes.
// we should later replace this with a status/job model
completed.push(idx);
debug!(
"fully indexed '{}'. indexed rows: {}",
idx, s.num_indexed_rows
);
} else {
debug!(
"still waiting for index '{}'. unindexed rows: {}",
idx, s.num_unindexed_rows
);
}
}
}
}
remaining.retain(|idx| !completed.contains(idx));
if remaining.is_empty() {
return Ok(());
}
sleep(Duration::from_millis(DEFAULT_SLEEP_MS)).await;
}
// debug log index diagnostics
for &r in &remaining {
let stats = table.index_stats(r.as_ref()).await?;
match stats {
Some(s) => debug!(
"index '{}' not fully indexed after {:?}. stats: {:?}",
r, timeout, s
),
None => debug!("index '{}' not found after {:?}", r, timeout),
}
}
Err(Error::Timeout {
message: format!(
"timed out waiting for indices: {:?} after {:?}",
remaining, timeout
),
})
}

View File

@@ -8,6 +8,7 @@
pub(crate) mod client;
pub(crate) mod db;
mod retry;
pub(crate) mod table;
pub(crate) mod util;

View File

@@ -1,17 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use http::HeaderName;
use log::debug;
use reqwest::{
header::{HeaderMap, HeaderValue},
Request, RequestBuilder, Response,
Body, Request, RequestBuilder, Response,
};
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
use crate::error::{Error, Result};
use crate::remote::db::RemoteOptions;
use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
@@ -118,41 +118,14 @@ pub struct RetryConfig {
/// You can also set the `LANCE_CLIENT_RETRY_STATUSES` environment variable
/// to set this value. Use a comma-separated list of integer values.
///
/// The default is 429, 500, 502, 503.
/// Note that write operations will never be retried on 5xx errors as this may
/// result in duplicated writes.
///
/// The default is 409, 429, 500, 502, 503, 504.
pub statuses: Option<Vec<u16>>,
// TODO: should we allow customizing methods?
}
#[derive(Debug, Clone)]
struct ResolvedRetryConfig {
retries: u8,
connect_retries: u8,
read_retries: u8,
backoff_factor: f32,
backoff_jitter: f32,
statuses: Vec<reqwest::StatusCode>,
}
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
type Error = Error;
fn try_from(retry_config: RetryConfig) -> Result<Self> {
Ok(Self {
retries: retry_config.retries.unwrap_or(3),
connect_retries: retry_config.connect_retries.unwrap_or(3),
read_retries: retry_config.read_retries.unwrap_or(3),
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
statuses: retry_config
.statuses
.unwrap_or_else(|| vec![429, 500, 502, 503])
.into_iter()
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
.collect(),
})
}
}
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients
@@ -160,8 +133,8 @@ impl TryFrom<RetryConfig> for ResolvedRetryConfig {
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
client: reqwest::Client,
host: String,
retry_config: ResolvedRetryConfig,
sender: S,
pub(crate) retry_config: ResolvedRetryConfig,
pub(crate) sender: S,
}
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
@@ -375,74 +348,69 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
self.client.post(full_uri)
}
pub async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> {
let (client, request) = req.build_split();
let mut request = request.unwrap();
let request_id = self.extract_request_id(&mut request);
self.log_request(&request, &request_id);
// Set a request id.
// TODO: allow the user to supply this, through middleware?
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
let header = HeaderValue::from_str(&request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
request_id
};
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
if with_retry {
self.send_with_retry_impl(client, request, request_id).await
} else {
let response = self
.sender
.send(&client, request)
.await
.err_to_http(request_id.clone())?;
debug!(
"Received response for request_id={}: {:?}",
request_id, &response
);
Ok((request_id, response))
}
let response = self
.sender
.send(&client, request)
.await
.err_to_http(request_id.clone())?;
debug!(
"Received response for request_id={}: {:?}",
request_id, &response
);
Ok((request_id, response))
}
async fn send_with_retry_impl(
/// Send the request using retries configured in the RetryConfig.
/// If retry_5xx is false, 5xx requests will not be retried regardless of the statuses configured
/// in the RetryConfig.
/// Since this requires arrow serialization, this is implemented here instead of in RestfulLanceDbClient
pub async fn send_with_retry(
&self,
client: reqwest::Client,
req: Request,
request_id: String,
req_builder: RequestBuilder,
mut make_body: Option<Box<dyn FnMut() -> Result<Body> + Send + 'static>>,
retry_5xx: bool,
) -> Result<(String, Response)> {
let mut retry_counter = RetryCounter::new(&self.retry_config, request_id);
let retry_config = &self.retry_config;
let non_5xx_statuses = retry_config
.statuses
.iter()
.filter(|s| !s.is_server_error())
.cloned()
.collect::<Vec<_>>();
// clone and build the request to extract the request id
let tmp_req = req_builder.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let (_, r) = tmp_req.build_split();
let mut r = r.unwrap();
let request_id = self.extract_request_id(&mut r);
let mut retry_counter = RetryCounter::new(retry_config, request_id.clone());
loop {
// This only works if the request body is not a stream. If it is
// a stream, we can't use the retry path. We would need to implement
// an outer retry.
let request = req.try_clone().ok_or_else(|| Error::Runtime {
let mut req_builder = req_builder.try_clone().ok_or_else(|| Error::Runtime {
message: "Attempted to retry a request that cannot be cloned".to_string(),
})?;
let response = self
.sender
.send(&client, request)
.await
.map(|r| (r.status(), r));
// set the streaming body on the request builder after clone
if let Some(body_gen) = make_body.as_mut() {
let body = body_gen()?;
req_builder = req_builder.body(body);
}
let (c, request) = req_builder.build_split();
let mut request = request.unwrap();
self.set_request_id(&mut request, &request_id.clone());
self.log_request(&request, &request_id);
let response = self.sender.send(&c, request).await.map(|r| (r.status(), r));
match response {
Ok((status, response)) if status.is_success() => {
debug!(
@@ -451,7 +419,10 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
);
return Ok((retry_counter.request_id, response));
}
Ok((status, response)) if self.retry_config.statuses.contains(&status) => {
Ok((status, response))
if (retry_5xx && retry_config.statuses.contains(&status))
|| non_5xx_statuses.contains(&status) =>
{
let source = self
.check_response(&retry_counter.request_id, response)
.await
@@ -480,6 +451,47 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
}
fn log_request(&self, request: &Request, request_id: &String) {
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
}
/// Extract the request ID from the request headers.
/// If the request ID header is not set, this will generate a new one and set
/// it on the request headers
pub fn extract_request_id(&self, request: &mut Request) -> String {
// Set a request id.
// TODO: allow the user to supply this, through middleware?
let request_id = if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) {
request_id.to_str().unwrap().to_string()
} else {
let request_id = uuid::Uuid::new_v4().to_string();
self.set_request_id(request, &request_id);
request_id
};
request_id
}
/// Set the request ID header
pub fn set_request_id(&self, request: &mut Request, request_id: &str) {
let header = HeaderValue::from_str(request_id).unwrap();
request.headers_mut().insert(REQUEST_ID_HEADER, header);
}
pub async fn check_response(&self, request_id: &str, response: Response) -> Result<Response> {
// Try to get the response text, but if that fails, just return the status code
let status = response.status();
@@ -501,91 +513,6 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
}
}
struct RetryCounter<'a> {
request_failures: u8,
connect_failures: u8,
read_failures: u8,
config: &'a ResolvedRetryConfig,
request_id: String,
}
impl<'a> RetryCounter<'a> {
fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(())
}
}
fn increment_request_failures(&mut self, source: crate::Error) -> Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_connect_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn increment_read_failures(&mut self, source: reqwest::Error) -> Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
pub trait RequestResultExt {
type Output;
fn err_to_http(self, request_id: String) -> Result<Self::Output>;

View File

@@ -255,7 +255,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
if let Some(start_after) = request.start_after {
req = req.query(&[("page_token", start_after)]);
}
let (request_id, rsp) = self.client.send(req, true).await?;
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
let rsp = self.client.check_response(&request_id, rsp).await?;
let version = parse_server_version(&request_id, &rsp)?;
let tables = rsp
@@ -302,7 +302,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
.body(data_buffer)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
let (request_id, rsp) = self.client.send(req, false).await?;
let (request_id, rsp) = self.client.send(req).await?;
if rsp.status() == StatusCode::BAD_REQUEST {
let body = rsp.text().await.err_to_http(request_id.clone())?;
@@ -362,7 +362,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
let req = self
.client
.post(&format!("/v1/table/{}/describe/", request.name));
let (request_id, rsp) = self.client.send(req, true).await?;
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
if rsp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: request.name });
}
@@ -383,7 +383,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
.client
.post(&format!("/v1/table/{}/rename/", current_name));
let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
let (request_id, resp) = self.client.send(req, false).await?;
let (request_id, resp) = self.client.send(req).await?;
self.client.check_response(&request_id, resp).await?;
let table = self.table_cache.remove(current_name).await;
if let Some(table) = table {
@@ -394,7 +394,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
async fn drop_table(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
let (request_id, resp) = self.client.send(req, true).await?;
let (request_id, resp) = self.client.send(req).await?;
self.client.check_response(&request_id, resp).await?;
self.table_cache.remove(name).await;
Ok(())

View File

@@ -0,0 +1,122 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use crate::remote::RetryConfig;
use crate::Error;
use log::debug;
use std::time::Duration;
pub struct RetryCounter<'a> {
pub request_failures: u8,
pub connect_failures: u8,
pub read_failures: u8,
pub config: &'a ResolvedRetryConfig,
pub request_id: String,
}
impl<'a> RetryCounter<'a> {
pub(crate) fn new(config: &'a ResolvedRetryConfig, request_id: String) -> Self {
Self {
request_failures: 0,
connect_failures: 0,
read_failures: 0,
config,
request_id,
}
}
fn check_out_of_retries(
&self,
source: Box<dyn std::error::Error + Send + Sync>,
status_code: Option<reqwest::StatusCode>,
) -> crate::Result<()> {
if self.request_failures >= self.config.retries
|| self.connect_failures >= self.config.connect_retries
|| self.read_failures >= self.config.read_retries
{
Err(Error::Retry {
request_id: self.request_id.clone(),
request_failures: self.request_failures,
max_request_failures: self.config.retries,
connect_failures: self.connect_failures,
max_connect_failures: self.config.connect_retries,
read_failures: self.read_failures,
max_read_failures: self.config.read_retries,
source,
status_code,
})
} else {
Ok(())
}
}
pub fn increment_request_failures(&mut self, source: crate::Error) -> crate::Result<()> {
self.request_failures += 1;
let status_code = if let crate::Error::Http { status_code, .. } = &source {
*status_code
} else {
None
};
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn increment_connect_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.connect_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn increment_read_failures(&mut self, source: reqwest::Error) -> crate::Result<()> {
self.read_failures += 1;
let status_code = source.status();
self.check_out_of_retries(Box::new(source), status_code)
}
pub fn next_sleep_time(&self) -> Duration {
let backoff = self.config.backoff_factor * (2.0f32.powi(self.request_failures as i32));
let jitter = rand::random::<f32>() * self.config.backoff_jitter;
let sleep_time = Duration::from_secs_f32(backoff + jitter);
debug!(
"Retrying request {:?} ({}/{} connect, {}/{} read, {}/{} read) in {:?}",
self.request_id,
self.connect_failures,
self.config.connect_retries,
self.request_failures,
self.config.retries,
self.read_failures,
self.config.read_retries,
sleep_time
);
sleep_time
}
}
#[derive(Debug, Clone)]
pub struct ResolvedRetryConfig {
pub retries: u8,
pub connect_retries: u8,
pub read_retries: u8,
pub backoff_factor: f32,
pub backoff_jitter: f32,
pub statuses: Vec<reqwest::StatusCode>,
}
impl TryFrom<RetryConfig> for ResolvedRetryConfig {
type Error = Error;
fn try_from(retry_config: RetryConfig) -> crate::Result<Self> {
Ok(Self {
retries: retry_config.retries.unwrap_or(3),
connect_retries: retry_config.connect_retries.unwrap_or(3),
read_retries: retry_config.read_retries.unwrap_or(3),
backoff_factor: retry_config.backoff_factor.unwrap_or(0.25),
backoff_jitter: retry_config.backoff_jitter.unwrap_or(0.25),
statuses: retry_config
.statuses
.unwrap_or_else(|| vec![409, 429, 500, 502, 503, 504])
.into_iter()
.map(|status| reqwest::StatusCode::from_u16(status).unwrap())
.collect(),
})
}
}

View File

@@ -1,17 +1,14 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::table::Tags;
use crate::table::{AddDataMode, AnyQuery, Filter, TableStatistics};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error, Table};
use arrow_array::RecordBatchReader;
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
@@ -22,12 +19,24 @@ use futures::TryStreamExt;
use http::header::CONTENT_TYPE;
use http::{HeaderName, StatusCode};
use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::refs::TagContents;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
use lance_datafusion::exec::{execute_plan, OneShotExec};
use reqwest::{RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::RwLock;
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
use crate::index::waiter::wait_for_index;
use crate::{
connection::NoData,
error::Result,
@@ -38,14 +47,141 @@ use crate::{
TableDefinition, UpdateBuilder,
},
};
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE;
use lance::dataset::MergeStats;
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
pub struct RemoteTags<'a, S: HttpSend = Sender> {
inner: &'a RemoteTable<S>,
}
#[async_trait]
impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
async fn list(&self) -> Result<HashMap<String, TagContents>> {
let request = self
.inner
.client
.post(&format!("/v1/table/{}/tags/list/", self.inner.name));
let (request_id, response) = self.inner.send(request, true).await?;
let response = self
.inner
.check_table_response(&request_id, response)
.await?;
match response.text().await {
Ok(body) => {
// Explicitly tell serde_json what type we want to deserialize into
let tags_map: HashMap<String, TagContents> =
serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!("Failed to parse tags list: {}", e).into(),
request_id,
status_code: None,
})?;
Ok(tags_map)
}
Err(err) => {
let status_code = err.status();
Err(Error::Http {
source: Box::new(err),
request_id,
status_code,
})
}
}
}
async fn get_version(&self, tag: &str) -> Result<u64> {
let request = self
.inner
.client
.post(&format!("/v1/table/{}/tags/version/", self.inner.name))
.json(&serde_json::json!({ "tag": tag }));
let (request_id, response) = self.inner.send(request, true).await?;
let response = self
.inner
.check_table_response(&request_id, response)
.await?;
match response.text().await {
Ok(body) => {
let value: serde_json::Value =
serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!("Failed to parse tag version: {}", e).into(),
request_id: request_id.clone(),
status_code: None,
})?;
value
.get("version")
.and_then(|v| v.as_u64())
.ok_or_else(|| Error::Http {
source: format!("Invalid tag version response: {}", body).into(),
request_id,
status_code: None,
})
}
Err(err) => {
let status_code = err.status();
Err(Error::Http {
source: Box::new(err),
request_id,
status_code,
})
}
}
}
async fn create(&mut self, tag: &str, version: u64) -> Result<()> {
let request = self
.inner
.client
.post(&format!("/v1/table/{}/tags/create/", self.inner.name))
.json(&serde_json::json!({
"tag": tag,
"version": version
}));
let (request_id, response) = self.inner.send(request, true).await?;
self.inner
.check_table_response(&request_id, response)
.await?;
Ok(())
}
async fn delete(&mut self, tag: &str) -> Result<()> {
let request = self
.inner
.client
.post(&format!("/v1/table/{}/tags/delete/", self.inner.name))
.json(&serde_json::json!({ "tag": tag }));
let (request_id, response) = self.inner.send(request, true).await?;
self.inner
.check_table_response(&request_id, response)
.await?;
Ok(())
}
async fn update(&mut self, tag: &str, version: u64) -> Result<()> {
let request = self
.inner
.client
.post(&format!("/v1/table/{}/tags/update/", self.inner.name))
.json(&serde_json::json!({
"tag": tag,
"version": version
}));
let (request_id, response) = self.inner.send(request, true).await?;
self.inner
.check_table_response(&request_id, response)
.await?;
Ok(())
}
}
#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
#[allow(dead_code)]
@@ -83,7 +219,7 @@ impl<S: HttpSend> RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -127,6 +263,61 @@ impl<S: HttpSend> RemoteTable<S> {
Ok(reqwest::Body::wrap_stream(body_stream))
}
/// Buffer the reader into memory
async fn buffer_reader<R: RecordBatchReader + ?Sized>(
reader: &mut R,
) -> Result<(SchemaRef, Vec<RecordBatch>)> {
let schema = reader.schema();
let mut batches = Vec::new();
for batch in reader {
batches.push(batch?);
}
Ok((schema, batches))
}
/// Create a new RecordBatchReader from buffered data
fn make_reader(schema: SchemaRef, batches: Vec<RecordBatch>) -> impl RecordBatchReader {
let iter = batches.into_iter().map(Ok);
RecordBatchIterator::new(iter, schema)
}
async fn send(&self, req: RequestBuilder, with_retry: bool) -> Result<(String, Response)> {
let res = if with_retry {
self.client.send_with_retry(req, None, true).await?
} else {
self.client.send(req).await?
};
Ok(res)
}
/// Send the request with streaming body.
/// This will use retries if with_retry is set and the number of configured retries is > 0.
/// If retries are enabled, the stream will be buffered into memory.
async fn send_streaming(
&self,
req: RequestBuilder,
mut data: Box<dyn RecordBatchReader + Send>,
with_retry: bool,
) -> Result<(String, Response)> {
if !with_retry || self.client.retry_config.retries == 0 {
let body = Self::reader_as_body(data)?;
return self.client.send(req.body(body)).await;
}
// to support retries, buffer into memory and clone the batches on each retry
let (schema, batches) = Self::buffer_reader(&mut *data).await?;
let make_body = Box::new(move || {
let reader = Self::make_reader(schema.clone(), batches.clone());
Self::reader_as_body(Box::new(reader))
});
let res = self
.client
.send_with_retry(req, Some(make_body), false)
.await?;
Ok(res)
}
async fn check_table_response(
&self,
request_id: &str,
@@ -168,7 +359,8 @@ impl<S: HttpSend> RemoteTable<S> {
}
// Server requires k.
let limit = params.limit.unwrap_or(usize::MAX);
// use isize::MAX as usize to avoid overflow: https://github.com/lancedb/lancedb/issues/2211
let limit = params.limit.unwrap_or(isize::MAX as usize);
body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
if let Some(filter) = &params.filter {
@@ -339,8 +531,6 @@ impl<S: HttpSend> RemoteTable<S> {
let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));
if let Some(timeout) = options.timeout {
// Client side timeout
request = request.timeout(timeout);
// Also send to server, so it can abort the query if it takes too long.
// (If it doesn't fit into u64, it's not worth sending anyways.)
if let Ok(timeout_ms) = u64::try_from(timeout.as_millis()) {
@@ -355,11 +545,29 @@ impl<S: HttpSend> RemoteTable<S> {
.collect();
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
let (request_id, response) = self.send(req, true).await?;
self.read_arrow_stream(&request_id, response).await
});
let streams = futures::future::try_join_all(futures).await?;
Ok(streams)
let streams = futures::future::try_join_all(futures);
if let Some(timeout) = options.timeout {
let timeout_future = tokio::time::sleep(timeout);
tokio::pin!(timeout_future);
tokio::pin!(streams);
tokio::select! {
_ = &mut timeout_future => {
Err(Error::Other {
message: format!("Query timeout after {} ms", timeout.as_millis()),
source: None,
})
}
result = &mut streams => {
Ok(result?)
}
}
} else {
Ok(streams.await?)
}
}
async fn prepare_query_bodies(&self, query: &AnyQuery) -> Result<Vec<serde_json::Value>> {
@@ -455,7 +663,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
self.checkout_latest().await?;
Ok(())
@@ -465,7 +673,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let request = self
.client
.post(&format!("/v1/table/{}/version/list/", self.name));
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(Deserialize)]
@@ -511,7 +719,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
request = request.json(&body);
}
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -529,12 +737,10 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
self.check_mutable().await?;
let body = Self::reader_as_body(data)?;
let mut request = self
.client
.post(&format!("/v1/table/{}/insert/", self.name))
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
match add.mode {
AddDataMode::Append => {}
@@ -543,8 +749,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
}
}
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send_streaming(request, data, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
@@ -612,7 +817,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.collect::<Vec<_>>();
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
let (request_id, response) = self.send(req, true).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
@@ -654,7 +859,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.collect();
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
let (request_id, response) = self.send(req, true).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
@@ -696,7 +901,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
"predicate": update.filter,
}));
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
@@ -710,7 +915,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/delete/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -796,32 +1001,55 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let request = request.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
if let Some(wait_timeout) = index.wait_timeout {
let name = format!("{}_idx", column);
self.wait_for_index(&[&name], wait_timeout).await?;
}
Ok(())
}
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(&self, index_names: &[&str], timeout: Duration) -> Result<()> {
wait_for_index(self, index_names, timeout).await
}
async fn merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
) -> Result<MergeStats> {
self.check_mutable().await?;
let query = MergeInsertRequest::try_from(params)?;
let body = Self::reader_as_body(new_data)?;
let request = self
.client
.post(&format!("/v1/table/{}/merge_insert/", self.name))
.query(&query)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send_streaming(request, new_data, true).await?;
// TODO: server can response with these stats in response body.
// We should test that we can handle both empty response from old server
// and response with stats from new server.
self.check_table_response(&request_id, response).await?;
Ok(MergeStats::default())
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
Ok(Box::new(RemoteTags { inner: self }))
}
async fn checkout_tag(&self, tag: &str) -> Result<()> {
let tags = self.tags().await?;
let version = tags.get_version(tag).await?;
let mut write_guard = self.version.write().await;
*write_guard = Some(version);
Ok(())
}
async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
@@ -852,7 +1080,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/add_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?; // todo:
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -891,7 +1119,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/alter_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -903,7 +1131,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/drop_columns/", self.name))
.json(&body);
let (request_id, response) = self.client.send(request, false).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -917,7 +1145,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
#[derive(Deserialize)]
@@ -974,7 +1202,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
@@ -998,7 +1226,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
"/v1/table/{}/index/{}/drop/",
self.name, index_name
));
let (request_id, response) = self.client.send(request, true).await?;
let (request_id, response) = self.send(request, true).await?;
self.check_table_response(&request_id, response).await?;
Ok(())
}
@@ -1017,6 +1245,20 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
fn dataset_uri(&self) -> &str {
"NOT_SUPPORTED"
}
async fn stats(&self) -> Result<TableStatistics> {
let request = self.client.post(&format!("/v1/table/{}/stats/", self.name));
let (request_id, response) = self.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
let stats = serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!("Failed to parse table statistics: {}", e).into(),
request_id,
status_code: None,
})?;
Ok(stats)
}
}
#[derive(Serialize)]
@@ -1109,7 +1351,12 @@ mod tests {
Box::pin(table.count_rows(None).map_ok(|_| ())),
Box::pin(table.update().column("a", "a + 1").execute().map_ok(|_| ())),
Box::pin(table.add(example_data()).execute().map_ok(|_| ())),
Box::pin(table.merge_insert(&["test"]).execute(example_data())),
Box::pin(
table
.merge_insert(&["test"])
.execute(example_data())
.map_ok(|_| ()),
),
Box::pin(table.delete("false")),
Box::pin(table.add_columns(
NewColumnTransform::SqlExpressions(vec![("x".into(), "y".into())]),
@@ -1459,6 +1706,42 @@ mod tests {
assert_eq!(&body, &expected_body);
}
#[tokio::test]
async fn test_merge_insert_retries_on_409() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));
// Default parameters
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/");
let params = request.url().query_pairs().collect::<HashMap<_, _>>();
assert_eq!(params["on"], "some_col");
assert_eq!(params["when_matched_update_all"], "false");
assert_eq!(params["when_not_matched_insert_all"], "false");
assert_eq!(params["when_not_matched_by_source_delete"], "false");
assert!(!params.contains_key("when_matched_update_all_filt"));
assert!(!params.contains_key("when_not_matched_by_source_delete_filt"));
http::Response::builder().status(409).body("").unwrap()
});
let e = table
.merge_insert(&["some_col"])
.execute(data)
.await
.unwrap_err();
assert!(e.to_string().contains("Hit retry limit"));
}
#[tokio::test]
async fn test_delete() {
let table = Table::new_with_handler("my_table", |request| {
@@ -1500,7 +1783,7 @@ mod tests {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"k": usize::MAX,
"k": isize::MAX as usize,
"prefilter": true,
"vector": [], // Empty vector means no vector query.
"version": null,
@@ -2416,4 +2699,88 @@ mod tests {
});
table.drop_index("my_index").await.unwrap();
}
#[tokio::test]
async fn test_wait_for_index() {
let table = _make_table_with_indices(0);
table
.wait_for_index(&["vector_idx", "my_idx"], Duration::from_secs(1))
.await
.unwrap();
}
#[tokio::test]
async fn test_wait_for_index_timeout() {
let table = _make_table_with_indices(100);
let e = table
.wait_for_index(&["vector_idx", "my_idx"], Duration::from_secs(1))
.await
.unwrap_err();
assert_eq!(
e.to_string(),
"Timeout error: timed out waiting for indices: [\"vector_idx\", \"my_idx\"] after 1s"
);
}
#[tokio::test]
async fn test_wait_for_index_timeout_never_created() {
let table = _make_table_with_indices(0);
let e = table
.wait_for_index(&["doesnt_exist_idx"], Duration::from_secs(1))
.await
.unwrap_err();
assert_eq!(
e.to_string(),
"Timeout error: timed out waiting for indices: [\"doesnt_exist_idx\"] after 1s"
);
}
fn _make_table_with_indices(unindexed_rows: usize) -> Table {
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
let response_body = match request.url().path() {
"/v1/table/my_table/index/list/" => {
serde_json::json!({
"indexes": [
{
"index_name": "vector_idx",
"index_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["vector"],
"index_status": "done",
},
{
"index_name": "my_idx",
"index_uuid": "34255f64-5717-4562-b3fc-2c963f66afa6",
"columns": ["my_column"],
"index_status": "done",
},
]
})
}
"/v1/table/my_table/index/vector_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": unindexed_rows,
"index_type": "IVF_PQ",
"distance_type": "l2"
})
}
"/v1/table/my_table/index/my_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": unindexed_rows,
"index_type": "LABEL_LIST"
})
}
_path => {
serde_json::json!(None::<String>)
}
};
let body = serde_json::to_string(&response_body).unwrap();
let status = if body == "null" { 404 } else { 200 };
http::Response::builder().status(status).body(body).unwrap()
});
table
}
}

View File

@@ -3,10 +3,6 @@
//! LanceDB Table APIs
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
@@ -24,6 +20,7 @@ use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
use lance::dataset::scanner::Scanner;
pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::MergeStats;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
pub use lance::dataset::Version;
@@ -45,6 +42,10 @@ use lance_table::format::Manifest;
use lance_table::io::commit::ManifestNamingScheme;
use log::info;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::format;
use std::path::Path;
use std::sync::Arc;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
@@ -78,10 +79,15 @@ pub mod datafusion;
pub(crate) mod dataset;
pub mod merge;
use crate::index::waiter::wait_for_index;
pub use chrono::Duration;
use futures::future::join_all;
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::refs::{TagContents, Tags as LanceTags};
pub use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::statistics::DatasetStatisticsExt;
pub use lance_index::optimize::OptimizeOptions;
use serde_with::skip_serializing_none;
/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -400,6 +406,24 @@ pub enum AnyQuery {
VectorQuery(VectorQueryRequest),
}
#[async_trait]
pub trait Tags: Send + Sync {
/// List the tags of the table.
async fn list(&self) -> Result<HashMap<String, TagContents>>;
/// Get the version of the table referenced by a tag.
async fn get_version(&self, tag: &str) -> Result<u64>;
/// Create a new tag for the given version of the table.
async fn create(&mut self, tag: &str, version: u64) -> Result<()>;
/// Delete a tag from the table.
async fn delete(&mut self, tag: &str) -> Result<()>;
/// Update an existing tag to point to a new version of the table.
async fn update(&mut self, tag: &str, version: u64) -> Result<()>;
}
/// A trait for anything "table-like". This is used for both native tables (which target
/// Lance datasets) and remote tables (which target LanceDB cloud)
///
@@ -464,7 +488,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()>;
) -> Result<MergeStats>;
/// Gets the table tag manager.
async fn tags(&self) -> Result<Box<dyn Tags + '_>>;
/// Optimize the dataset.
async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
/// Add columns to the table.
@@ -481,6 +507,9 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn version(&self) -> Result<u64>;
/// Checkout a specific version of the table.
async fn checkout(&self, version: u64) -> Result<()>;
/// Checkout a table version referenced by a tag.
/// Tags provide a human-readable way to reference specific versions of the table.
async fn checkout_tag(&self, tag: &str) -> Result<()>;
/// Checkout the latest version of the table.
async fn checkout_latest(&self) -> Result<()>;
/// Restore the table to the currently checked out version.
@@ -491,6 +520,15 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI
fn dataset_uri(&self) -> &str;
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()>;
/// Get statistics on the table
async fn stats(&self) -> Result<TableStatistics>;
}
/// A Table is a collection of strong typed Rows.
@@ -769,6 +807,28 @@ impl Table {
)
}
/// See [Table::create_index]
/// For remote tables, this allows an optional wait_timeout to poll until asynchronous indexing is complete
pub fn create_index_with_timeout(
&self,
columns: &[impl AsRef<str>],
index: Index,
wait_timeout: Option<std::time::Duration>,
) -> IndexBuilder {
let mut builder = IndexBuilder::new(
self.inner.clone(),
columns
.iter()
.map(|val| val.as_ref().to_string())
.collect::<Vec<_>>(),
index,
);
if let Some(timeout) = wait_timeout {
builder = builder.wait_timeout(timeout);
}
builder
}
/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
@@ -1028,6 +1088,24 @@ impl Table {
self.inner.checkout(version).await
}
/// Checks out a specific version of the Table by tag
///
/// Any read operation on the table will now access the data at the version referenced by the tag.
/// As a consequence, calling this method will disable any read consistency interval
/// that was previously set.
///
/// This is a read-only operation that turns the table into a sort of "view"
/// or "detached head". Other table instances will not be affected. To make the change
/// permanent you can use the `[Self::restore]` method.
///
/// Any operation that modifies the table will fail while the table is in a checked
/// out state.
///
/// To return the table to a normal state use `[Self::checkout_latest]`
pub async fn checkout_tag(&self, tag: &str) -> Result<()> {
self.inner.checkout_tag(tag).await
}
/// Ensures the table is pointing at the latest version
///
/// This can be used to manually update a table when the read_consistency_interval is None
@@ -1104,6 +1182,21 @@ impl Table {
self.inner.prewarm_index(name).await
}
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
pub async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()> {
self.inner.wait_for_index(index_names, timeout).await
}
/// Get the tags manager.
pub async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
self.inner.tags().await
}
// Take many execution plans and map them into a single plan that adds
// a query_index column and unions them.
pub(crate) fn multi_vector_plan(
@@ -1154,6 +1247,40 @@ impl Table {
.unwrap();
Ok(Arc::new(repartitioned))
}
/// Retrieve statistics on the table
pub async fn stats(&self) -> Result<TableStatistics> {
self.inner.stats().await
}
}
pub struct NativeTags {
inner: LanceTags,
}
#[async_trait]
impl Tags for NativeTags {
async fn list(&self) -> Result<HashMap<String, TagContents>> {
Ok(self.inner.list().await?)
}
async fn get_version(&self, tag: &str) -> Result<u64> {
Ok(self.inner.get_version(tag).await?)
}
async fn create(&mut self, tag: &str, version: u64) -> Result<()> {
self.inner.create(tag, version).await?;
Ok(())
}
async fn delete(&mut self, tag: &str) -> Result<()> {
self.inner.delete(tag).await?;
Ok(())
}
async fn update(&mut self, tag: &str, version: u64) -> Result<()> {
self.inner.update(tag, version).await?;
Ok(())
}
}
impl From<NativeTable> for Table {
@@ -1900,6 +2027,10 @@ impl BaseTable for NativeTable {
self.dataset.as_time_travel(version).await
}
async fn checkout_tag(&self, tag: &str) -> Result<()> {
self.dataset.as_time_travel(tag).await
}
async fn checkout_latest(&self) -> Result<()> {
self.dataset
.as_latest(self.read_consistency_interval)
@@ -2237,7 +2368,7 @@ impl BaseTable for NativeTable {
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
) -> Result<MergeStats> {
let dataset = Arc::new(self.dataset.get().await?.clone());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
match (
@@ -2264,9 +2395,9 @@ impl BaseTable for NativeTable {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
let job = builder.try_build()?;
let (new_dataset, _stats) = job.execute_reader(new_data).await?;
let (new_dataset, stats) = job.execute_reader(new_data).await?;
self.dataset.set_latest(new_dataset.as_ref().clone()).await;
Ok(())
Ok(stats)
}
/// Delete rows from the table
@@ -2275,6 +2406,14 @@ impl BaseTable for NativeTable {
Ok(())
}
async fn tags(&self) -> Result<Box<dyn Tags + '_>> {
let dataset = self.dataset.get().await?;
Ok(Box::new(NativeTags {
inner: dataset.tags.clone(),
}))
}
async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats> {
let mut stats = OptimizeStats {
compaction: None,
@@ -2430,6 +2569,118 @@ impl BaseTable for NativeTable {
loss,
}))
}
/// Poll until the columns are fully indexed. Will return Error::Timeout if the columns
/// are not fully indexed within the timeout.
async fn wait_for_index(
&self,
index_names: &[&str],
timeout: std::time::Duration,
) -> Result<()> {
wait_for_index(self, index_names, timeout).await
}
async fn stats(&self) -> Result<TableStatistics> {
let num_rows = self.count_rows(None).await?;
let num_indices = self.list_indices().await?.len();
let ds = self.dataset.get().await?;
let ds_clone = (*ds).clone();
let ds_stats = Arc::new(ds_clone).calculate_data_stats().await?;
let total_bytes = ds_stats.fields.iter().map(|f| f.bytes_on_disk).sum::<u64>() as usize;
let frags = ds.get_fragments();
let mut sorted_sizes = join_all(
frags
.iter()
.map(|frag| async move { frag.physical_rows().await.unwrap_or(0) }),
)
.await;
sorted_sizes.sort();
let small_frag_threshold = 100000;
let num_fragments = sorted_sizes.len();
let num_small_fragments = sorted_sizes
.iter()
.filter(|&&size| size < small_frag_threshold)
.count();
let p25 = *sorted_sizes.get(num_fragments / 4).unwrap_or(&0);
let p50 = *sorted_sizes.get(num_fragments / 2).unwrap_or(&0);
let p75 = *sorted_sizes.get(num_fragments * 3 / 4).unwrap_or(&0);
let p99 = *sorted_sizes.get(num_fragments * 99 / 100).unwrap_or(&0);
let min = sorted_sizes.first().copied().unwrap_or(0);
let max = sorted_sizes.last().copied().unwrap_or(0);
let mean = if num_fragments == 0 {
0
} else {
sorted_sizes.iter().copied().sum::<usize>() / num_fragments
};
let frag_stats = FragmentStatistics {
num_fragments,
num_small_fragments,
lengths: FragmentSummaryStats {
min,
max,
mean,
p25,
p50,
p75,
p99,
},
};
let stats = TableStatistics {
total_bytes,
num_rows,
num_indices,
fragment_stats: frag_stats,
};
Ok(stats)
}
}
#[skip_serializing_none]
#[derive(Debug, Deserialize, PartialEq)]
pub struct TableStatistics {
/// The total number of bytes in the table
pub total_bytes: usize,
/// The number of rows in the table
pub num_rows: usize,
/// The number of indices in the table
pub num_indices: usize,
/// Statistics on table fragments
pub fragment_stats: FragmentStatistics,
}
#[skip_serializing_none]
#[derive(Debug, Deserialize, PartialEq)]
pub struct FragmentStatistics {
/// The number of fragments in the table
pub num_fragments: usize,
/// The number of uncompacted fragments in the table
pub num_small_fragments: usize,
/// Statistics on the number of rows in the table fragments
pub lengths: FragmentSummaryStats,
// todo: add size statistics
// /// Statistics on the number of bytes in the table fragments
// sizes: FragmentStats,
}
#[skip_serializing_none]
#[derive(Debug, Deserialize, PartialEq)]
pub struct FragmentSummaryStats {
pub min: usize,
pub max: usize,
pub mean: usize,
pub p25: usize,
pub p50: usize,
pub p75: usize,
pub p99: usize,
}
#[cfg(test)]
@@ -3031,6 +3282,60 @@ mod tests {
)
}
#[tokio::test]
async fn test_tags() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let conn = ConnectBuilder::new(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let table = conn
.create_table("my_table", some_sample_data())
.execute()
.await
.unwrap();
assert_eq!(table.version().await.unwrap(), 1);
table.add(some_sample_data()).execute().await.unwrap();
assert_eq!(table.version().await.unwrap(), 2);
let mut tags_manager = table.tags().await.unwrap();
let tags = tags_manager.list().await.unwrap();
assert!(tags.is_empty(), "Tags should be empty initially");
let tag1 = "tag1";
tags_manager.create(tag1, 1).await.unwrap();
assert_eq!(tags_manager.get_version(tag1).await.unwrap(), 1);
let tags = tags_manager.list().await.unwrap();
assert_eq!(tags.len(), 1);
assert!(tags.contains_key(tag1));
assert_eq!(tags.get(tag1).unwrap().version, 1);
tags_manager.create("tag2", 2).await.unwrap();
assert_eq!(tags_manager.get_version("tag2").await.unwrap(), 2);
let tags = tags_manager.list().await.unwrap();
assert_eq!(tags.len(), 2);
assert!(tags.contains_key(tag1));
assert_eq!(tags.get(tag1).unwrap().version, 1);
assert!(tags.contains_key("tag2"));
assert_eq!(tags.get("tag2").unwrap().version, 2);
// Test update and delete
table.add(some_sample_data()).execute().await.unwrap();
tags_manager.update(tag1, 3).await.unwrap();
assert_eq!(tags_manager.get_version(tag1).await.unwrap(), 3);
tags_manager.delete("tag2").await.unwrap();
let tags = tags_manager.list().await.unwrap();
assert_eq!(tags.len(), 1);
assert!(tags.contains_key(tag1));
assert_eq!(tags.get(tag1).unwrap().version, 3);
// Test checkout tag
table.add(some_sample_data()).execute().await.unwrap();
assert_eq!(table.version().await.unwrap(), 4);
table.checkout_tag(tag1).await.unwrap();
assert_eq!(table.version().await.unwrap(), 3);
table.checkout_latest().await.unwrap();
assert_eq!(table.version().await.unwrap(), 4);
}
#[tokio::test]
async fn test_create_index() {
use arrow_array::RecordBatch;
@@ -3213,7 +3518,10 @@ mod tests {
.execute()
.await
.unwrap();
table
.wait_for_index(&["embeddings_idx"], Duration::from_millis(10))
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
@@ -3281,7 +3589,10 @@ mod tests {
.execute()
.await
.unwrap();
table
.wait_for_index(&["i_idx"], Duration::from_millis(10))
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
@@ -3747,4 +4058,108 @@ mod tests {
Some(&"test_field_val1".to_string())
);
}
#[tokio::test]
pub async fn test_stats() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let conn = ConnectBuilder::new(uri).execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("foo", DataType::Int32, true),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..100)),
Arc::new(Int32Array::from_iter_values(0..100)),
],
)
.unwrap();
let table = conn
.create_table(
"test_stats",
RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()),
)
.execute()
.await
.unwrap();
for _ in 0..10 {
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..15)),
Arc::new(Int32Array::from_iter_values(0..15)),
],
)
.unwrap();
table
.add(RecordBatchIterator::new(
vec![Ok(batch.clone())],
batch.schema(),
))
.execute()
.await
.unwrap();
}
let empty_table = conn
.create_table(
"test_stats_empty",
RecordBatchIterator::new(vec![], batch.schema()),
)
.execute()
.await
.unwrap();
let res = table.stats().await.unwrap();
println!("{:#?}", res);
assert_eq!(
res,
TableStatistics {
num_rows: 250,
num_indices: 0,
total_bytes: 2000,
fragment_stats: FragmentStatistics {
num_fragments: 11,
num_small_fragments: 11,
lengths: FragmentSummaryStats {
min: 15,
max: 100,
mean: 22,
p25: 15,
p50: 15,
p75: 15,
p99: 100,
},
},
}
);
let res = empty_table.stats().await.unwrap();
println!("{:#?}", res);
assert_eq!(
res,
TableStatistics {
num_rows: 0,
num_indices: 0,
total_bytes: 0,
fragment_stats: FragmentStatistics {
num_fragments: 0,
num_small_fragments: 0,
lengths: FragmentSummaryStats {
min: 0,
max: 0,
mean: 0,
p25: 0,
p50: 0,
p75: 0,
p99: 0,
},
},
}
)
}
}

View File

@@ -7,7 +7,7 @@ use std::{
time::{self, Duration, Instant},
};
use lance::Dataset;
use lance::{dataset::refs, Dataset};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::error::Result;
@@ -83,19 +83,32 @@ impl DatasetRef {
}
}
async fn as_time_travel(&mut self, target_version: u64) -> Result<()> {
async fn as_time_travel(&mut self, target_version: impl Into<refs::Ref>) -> Result<()> {
let target_ref = target_version.into();
match self {
Self::Latest { dataset, .. } => {
let new_dataset = dataset.checkout_version(target_ref.clone()).await?;
let version_value = new_dataset.version().version;
*self = Self::TimeTravel {
dataset: dataset.checkout_version(target_version).await?,
version: target_version,
dataset: new_dataset,
version: version_value,
};
}
Self::TimeTravel { dataset, version } => {
if *version != target_version {
let should_checkout = match &target_ref {
refs::Ref::Version(target_ver) => version != target_ver,
refs::Ref::Tag(_) => true, // Always checkout for tags
};
if should_checkout {
let new_dataset = dataset.checkout_version(target_ref).await?;
let version_value = new_dataset.version().version;
*self = Self::TimeTravel {
dataset: dataset.checkout_version(target_version).await?,
version: target_version,
dataset: new_dataset,
version: version_value,
};
}
}
@@ -175,7 +188,7 @@ impl DatasetConsistencyWrapper {
write_guard.as_latest(read_consistency_interval).await
}
pub async fn as_time_travel(&self, target_version: u64) -> Result<()> {
pub async fn as_time_travel(&self, target_version: impl Into<refs::Ref>) -> Result<()> {
self.0.write().await.as_time_travel(target_version).await
}

View File

@@ -4,6 +4,7 @@
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use lance::dataset::MergeStats;
use crate::Result;
@@ -86,8 +87,9 @@ impl MergeInsertBuilder {
/// Executes the merge insert operation
///
/// Nothing is returned but the [`super::Table`] is updated
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
/// Returns statistics about the merge operation including the number of rows
/// inserted, updated, and deleted.
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<MergeStats> {
self.table.clone().merge_insert(self, new_data).await
}
}