Compare commits

...

76 Commits

Author SHA1 Message Date
Lance Release
c0097c5f0a Bump version: 0.21.0-beta.2 → 0.21.0 2025-03-10 23:12:56 +00:00
Lance Release
c199708e64 Bump version: 0.21.0-beta.1 → 0.21.0-beta.2 2025-03-10 23:12:56 +00:00
Weston Pace
4a47150ae7 feat: upgrade to lance 0.24.1 (#2199) 2025-03-10 15:18:37 -07:00
Wyatt Alt
f86b20a564 fix: delete tables from DDB on drop_all_tables (#2194)
Prior to this commit, issuing drop_all_tables on a listing database with
an external manifest store would delete physical tables but leave
references behind in the manifest store. The table drop would succeed,
but subsequent creation of a table with the same name would fail with a
conflict.

With this patch, the external manifest store is updated to account for
the dropped tables so that dropped table names can be reused.
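For illustration, a minimal Python sketch of the behavior this fixes. It assumes a listing database backed by the S3 + DynamoDB commit store (the `s3+ddb://` URI form); the bucket and DynamoDB table names are placeholders:

```python
import lancedb
import pyarrow as pa

# Hypothetical listing database whose manifests live in an external
# DynamoDB manifest store.
db = lancedb.connect("s3+ddb://my-bucket/db?ddbTableName=my-manifest-store")

schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
db.create_table("items", schema=schema)

# Before this fix: the physical table was deleted, but its entry in the
# manifest store was left behind.
db.drop_all_tables()

# ...so re-creating a table with the same name failed with a conflict.
# With this patch, the manifest entries are removed as well and the name
# can be reused.
db.create_table("items", schema=schema)
```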
2025-03-10 15:00:53 -07:00
msu-reevo
cc81f3e1a5 fix(python): typing (#2167)
@wjones127 is there a standard way you guys set up your virtualenv? I can
either relist all the dependencies in the pyright pre-commit section, specify
a venv, or require the user to be in the virtual environment when they run
git commit. If the venv location were standardized or a Python manager like
`uv` were used, it would be easier to avoid duplicating the pyright
dependency list.

Per your suggestion, I added all the passing files to the `includes` section
in `pyproject.toml`.

For ruff, I upgraded the version and removed "TCH", which doesn't exist as
an option.

I added a `pyright_report.csv` which contains a list of all files sorted
by pyright errors ascending as a todo list to work on.

I fixed about 30 issues in `table.py` stemming from plain `str` values being
passed into methods that require one of a set of string `Literal`s, by
extracting those literals into `types.py`.

Can you verify in the rust bridge that the schema should be a property
and not a method here? If it's a method, then there's another place in
the code where `inner.schema` should be `inner.schema()`.
``` python
class RecordBatchStream:
    @property
    def schema(self) -> pa.Schema: ...
```

Also, unless the `_lancedb.pyi` file is wrong, there is no `__anext__` on
`_inner` when it's not an `AsyncGenerator`; only `next` is defined:
``` python
    async def __anext__(self) -> pa.RecordBatch:
        return await self._inner.__anext__()
        if isinstance(self._inner, AsyncGenerator):
            batch = await self._inner.__anext__()
        else:
            batch = await self._inner.next()
        if batch is None:
            raise StopAsyncIteration
        return batch
```
In the else branch, `_inner` is a `RecordBatchStream`:
```python
class RecordBatchStream:
    @property
    def schema(self) -> pa.Schema: ...
    async def next(self) -> Optional[pa.RecordBatch]: ...
```
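
A minimal sketch of the dispatch being discussed, assuming `_inner` is either an `AsyncGenerator` (which exposes `__anext__`) or a `RecordBatchStream` (which, per the stub above, only exposes an async `next()`); note that the early `return` in the snippet above would make the branch below it unreachable. The wrapper class name here is illustrative:

```python
from collections.abc import AsyncGenerator

import pyarrow as pa


class AsyncBatchIterator:  # illustrative name, not the actual class
    def __init__(self, inner) -> None:
        # inner: AsyncGenerator[pa.RecordBatch, None] | RecordBatchStream
        self._inner = inner

    def __aiter__(self):
        return self

    async def __anext__(self) -> pa.RecordBatch:
        if isinstance(self._inner, AsyncGenerator):
            batch = await self._inner.__anext__()
        else:
            # RecordBatchStream only defines an async next() in _lancedb.pyi
            batch = await self._inner.next()
        if batch is None:
            raise StopAsyncIteration
        return batch
```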

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2025-03-10 09:01:23 -07:00
Weston Pace
bc49c4db82 feat: respect datafusion's batch size when running as a table provider (#2187)
Datafusion makes the batch size available as part of the `SessionState`.
We should use that to set the `max_batch_length` property in the
`QueryExecutionOptions`.
2025-03-07 05:53:36 -08:00
Weston Pace
d2eec46f17 feat: add support for streaming input to create_table (#2175)
This PR makes it possible to create a table using an asynchronous stream
of input data. Currently only a synchronous iterator is supported. There
are a number of follow-ups not yet tackled:

* Support for embedding functions (the embedding functions wrapper needs
to be re-written to be async, should be an easy lift)
* Support for async input into the remote table (the make_ipc_batch
needs to change to accept async input, leaving undone for now because I
think we want to support actual streaming uploads into the remote table
soon)
* Support for async input into the add function (pretty essential, but
it is a fairly distinct code path, so saving for a different PR)
2025-03-06 11:55:00 -08:00
Lance Release
51437bc228 Bump version: 0.21.0-beta.0 → 0.21.0-beta.1 2025-03-06 19:23:06 +00:00
Bert
fa53cfcfd2 feat: support modifying field metadata in lancedb python (#2178) 2025-03-04 16:58:46 -05:00
vinoyang
374fe0ad95 feat(rust): introduce Catalog trait and implement ListingCatalog (#2148)
Co-authored-by: Weston Pace <weston.pace@gmail.com>
2025-03-03 20:22:24 -08:00
BubbleCal
35e5b84ba9 chore: upgrade lance to 0.24.0-beta.1 (#2171)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-03-03 12:32:12 +08:00
Lei Xu
7c12d497b0 ci: bump python to 3.12 in GHA (#2169) 2025-03-01 17:24:02 -08:00
ayao227
dfe4ba8dad chore: add reo integration (#2149)
This PR adds reo integration to the lancedb documentation website.
2025-02-28 07:51:34 -08:00
Weston Pace
fa1b9ad5bd fix: don't use with_schema to remove schema metadata (#2162)
It seems that `RecordBatch::with_schema` is unable to remove schema
metadata from a batch. It fails with the error `target schema is not
superset of current schema`.

I'm not sure how the `test_metadata_erased` test is passing. Strangely,
the metadata was not present by the time the batch arrived at the
metadata eraser. I think maybe the schema metadata is only present in
the batch if there is a filter.

I've created a new unit test that makes sure the metadata is erased when a
filter is present as well.
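
As a rough analogy in Python (not the Rust code this PR touches), pyarrow provides a dedicated helper for stripping schema metadata from a batch rather than casting it to a new schema:

```python
import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3])],
    schema=pa.schema([pa.field("x", pa.int64())], metadata={"origin": "test"}),
)

# Passing None removes all schema-level metadata from the batch.
stripped = batch.replace_schema_metadata(None)
assert stripped.schema.metadata is None
```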
2025-02-27 10:24:00 -08:00
BubbleCal
8877eb020d feat: record the server version for remote table (#2147)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-02-27 15:55:59 +08:00
Will Jones
01e4291d21 feat(python): drop hard dependency on pylance (#2156)
Closes #1793
2025-02-26 15:53:45 -08:00
Lance Release
ab3ea76ad1 Updating package-lock.json 2025-02-26 21:23:39 +00:00
Lance Release
728ef8657d Updating package-lock.json 2025-02-26 20:11:37 +00:00
Lance Release
0b13901a16 Updating package-lock.json 2025-02-26 20:11:22 +00:00
Lance Release
84b110e0ef Bump version: 0.17.0 → 0.18.0-beta.0 2025-02-26 20:11:07 +00:00
Lance Release
e1836e54e3 Bump version: 0.20.0 → 0.21.0-beta.0 2025-02-26 20:10:54 +00:00
Weston Pace
4ba5326880 feat: reapply upgrade lance to v0.23.3-beta.1 (#2157)
This reverts commit 2f0c5baea2.

---------

Co-authored-by: Lu Qiu <luqiujob@gmail.com>
2025-02-26 11:44:11 -08:00
Lance Release
b036a69300 Updating package-lock.json 2025-02-26 19:32:22 +00:00
Will Jones
5b12a47119 feat!: revert query limit to be unbounded for scans (#2151)
In earlier PRs (#1886, #1191) we made the default limit 10 regardless of
the query type. This was confusing for users and in many cases a
breaking change. Users would have queries that used to return all
results, but instead only returned the first 10, causing silent bugs.

Part of the cause was consistency: the Python sync API seems to have
always had a limit of 10, while newer APIs (Python async and Nodejs)
didn't.

This PR sets the default limit only for searches (vector search, FTS),
while letting scans (even with filters) be unbounded. It does this
consistently for all SDKs.

Fixes #1983
Fixes #1852
Fixes #2141
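
To make the new defaults concrete, a small sketch against the Python sync API (table contents are illustrative):

```python
import lancedb

db = lancedb.connect("data/sample-lancedb")
tbl = db.create_table(
    "items", [{"id": i, "vector": [float(i), 0.0]} for i in range(100)]
)

# Scans, even with a filter, are unbounded again: this returns all 50
# matching rows instead of silently stopping at 10.
rows = tbl.search().where("id >= 50").to_list()

# Searches (vector, FTS) keep a default limit of 10; use .limit(n) to change it.
nearest = tbl.search([0.0, 0.0]).to_list()
assert len(nearest) == 10
```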
2025-02-26 10:32:14 -08:00
Lance Release
769d483e50 Updating package-lock.json 2025-02-26 18:16:59 +00:00
Lance Release
9ecb11fe5a Updating package-lock.json 2025-02-26 18:16:42 +00:00
Lance Release
22bd8329f3 Bump version: 0.17.0-beta.0 → 0.17.0 2025-02-26 18:16:07 +00:00
Lance Release
a736fad149 Bump version: 0.16.1-beta.3 → 0.17.0-beta.0 2025-02-26 18:16:01 +00:00
Lance Release
072adc41aa Bump version: 0.20.0-beta.0 → 0.20.0 2025-02-26 18:15:23 +00:00
Lance Release
c6f25ef1f0 Bump version: 0.19.1-beta.3 → 0.20.0-beta.0 2025-02-26 18:15:23 +00:00
Weston Pace
2f0c5baea2 Revert "chore: upgrade lance to v0.23.3-beta.1 (#2153)"
This reverts commit a63dd66d41.
2025-02-26 10:14:29 -08:00
BubbleCal
a63dd66d41 chore: upgrade lance to v0.23.3-beta.1 (#2153)
This fixes a bug in SQ; see https://github.com/lancedb/lance/pull/3476
for more details.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
Co-authored-by: Lu Qiu <luqiujob@gmail.com>
2025-02-26 09:52:28 -08:00
Weston Pace
d6b3ccb37b feat: upgrade lance to 0.23.2 (#2152)
This also changes the pylance pin from `==0.23.2` to `~=0.23.2`, which allows
the pylance dependency to float a little. The pylance dependency is not used
for much anymore, so it should be tolerant of patch changes.
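
For reference, `~=0.23.2` is a PEP 440 compatible-release specifier: patch releases may float, but a minor bump is excluded. A quick check with the `packaging` library:

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=0.23.2")
assert "0.23.2" in spec      # the pinned version is still allowed
assert "0.23.9" in spec      # newer patch releases are allowed
assert "0.24.0" not in spec  # minor bumps are not
```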
2025-02-26 09:02:51 -08:00
Weston Pace
c4f99e82e5 feat: push filters down into DF table provider (#2128) 2025-02-25 14:46:28 -08:00
andrew-pienso
979a2d3d9d docs: fixes is_open docstring on AsyncTable (#2150) 2025-02-25 09:11:25 -08:00
Will Jones
7ac5f74c80 feat!: add variable store to embeddings registry (#2112)
BREAKING CHANGE: embedding function implementations in Node now need to
call `resolveVariables()` in their constructors and should **not**
implement `toJSON()`.

This tries to address the handling of secrets. In Node, they are
currently lost. In Python, they are currently leaked into the table
schema metadata.

This PR introduces an in-memory variable store on the function registry.
It also allows embedding function definitions to label certain config
values as "sensitive", and the preprocessing logic will raise an error
if users try to pass in hard-coded values.

Closes #2110
Closes #521
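
A hedged Python sketch of the idea; the registry's `set_var()` and the `$var:` reference form are documented in the doc changes further down in this diff, while the key names and values here are placeholders:

```python
from lancedb.embeddings import get_registry

registry = get_registry()

# Variables live only in the registry for the lifetime of the process; they
# are never written into the table's schema metadata, so they must be set
# again when the table is opened from another process.
registry.set_var("openai_api_key", "sk-...")  # secret (placeholder value)
registry.set_var("device", "cpu")             # runtime-only setting

# Embedding configs reference the variable instead of a raw value (assuming
# the function accepts an api_key option). Passing a hard-coded value for a
# key the function marks as sensitive raises an error rather than leaking it
# into the schema metadata.
func = registry.get("openai").create(api_key="$var:openai_api_key")
```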

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
2025-02-24 15:52:19 -08:00
Will Jones
ecdee4d2b1 feat(python): add search() method to async API (#2049)
Reviving #1966.

Closes #1938

The `search()` method can apply embeddings for the user. This simplifies
hybrid search, so instead of writing:

```python
vector_query = embeddings.compute_query_embeddings("flower moon")[0]
await (
    async_tbl.query()
    .nearest_to(vector_query)
    .nearest_to_text("flower moon")
    .to_pandas()
)
```

You can write:

```python
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
```

Unfortunately, we had to do a double-await here because `search()` needs
to be async. This is because it often needs to do IO to retrieve and run
an embedding function.
2025-02-24 14:19:25 -08:00
BubbleCal
f391ed828a fix: remote table doesn't apply the prefilter flag for FTS (#2145) 2025-02-24 21:37:43 +08:00
BubbleCal
a99a450f2b fix: flat FTS panic with prefilter and update lance (#2144)
This is fixed in lance, so this upgrades lance to 0.23.2-beta1.
2025-02-24 14:34:00 +08:00
Lei Xu
6fa1f37506 docs: improve pydantic integration docs (#2136)
Addresses usage mistakes reported in
https://github.com/lancedb/lancedb/issues/2135.

* Add an example of how to use `LanceModel` and the `Vector` field type (see
the sketch below)
* Add a test for the pydantic doc
* Fix the example to use `LanceModel` directly instead of calling
`MyModel.to_arrow_schema()`
* Add a cross-reference link to the pydantic doc site
* Configure mkdocs to watch code changes in the python directory.
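
A minimal sketch of the pattern the updated page documents (table name and fields are illustrative):

```python
import lancedb
from lancedb.pydantic import LanceModel, Vector


class Words(LanceModel):
    text: str
    vector: Vector(2)


db = lancedb.connect("data/sample-lancedb")
# The LanceModel subclass is used as the schema directly; there is no need
# to call Words.to_arrow_schema() by hand.
table = db.create_table("words", schema=Words)
table.add([Words(text="hello world", vector=[0.1, 0.2])])

hit = table.search([0.1, 0.2]).limit(1).to_pydantic(Words)[0]
print(hit.text)
```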
2025-02-21 12:48:37 -08:00
BubbleCal
544382df5e fix: handle batch queries in a single request (#2139) 2025-02-21 13:23:39 +08:00
BubbleCal
784f00ef6d chore: update Cargo.lock (#2137) 2025-02-21 12:27:10 +08:00
Lance Release
96d7446f70 Updating package-lock.json 2025-02-20 04:51:26 +00:00
Lance Release
99ea78fb55 Updating package-lock.json 2025-02-20 03:38:44 +00:00
Lance Release
8eef4cdc28 Updating package-lock.json 2025-02-20 03:38:27 +00:00
Lance Release
0f102f02c3 Bump version: 0.16.1-beta.2 → 0.16.1-beta.3 2025-02-20 03:38:01 +00:00
Lance Release
a33a0670f6 Bump version: 0.19.1-beta.2 → 0.19.1-beta.3 2025-02-20 03:37:27 +00:00
BubbleCal
14c9ff46d1 feat: support multivector on remote table (#2045)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-02-20 11:34:51 +08:00
Lei Xu
1865f7decf fix: support optional nested pydantic model (#2130)
Closes #2129
2025-02-17 20:43:13 -08:00
BubbleCal
a608621476 test: query with dist range and new rows (#2126)
we found a bug that flat KNN plan node's stats is not in right order as
fields in schema, it would cause an error if querying with distance
range and new unindexed rows.

we've fixed this in lance so add this test for verifying it works

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-02-17 12:57:45 +08:00
BubbleCal
00514999ff feat: upgrade lance to 0.23.1-beta.4 (#2121)
This also upgrades object_store to 0.11.0 and snafu to 0.8.

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-02-16 14:53:26 +08:00
Lance Release
b3b597fef6 Updating package-lock.json 2025-02-13 04:40:10 +00:00
Lance Release
bf17144591 Updating package-lock.json 2025-02-13 04:39:54 +00:00
Lance Release
09e110525f Bump version: 0.16.1-beta.1 → 0.16.1-beta.2 2025-02-13 04:39:38 +00:00
Lance Release
40f0dbb64d Bump version: 0.19.1-beta.1 → 0.19.1-beta.2 2025-02-13 04:39:19 +00:00
BubbleCal
3b19e96ae7 fix: panic when field id doesn't equal the field index (#2116)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-02-13 12:38:35 +08:00
Will Jones
78a17ad54c chore: improve dev instructions for Python (#2088)
Closes #2042
2025-02-12 14:08:52 -08:00
Lance Release
a8e6b491e2 Updating package-lock.json 2025-02-11 22:05:54 +00:00
Lance Release
cea541ca46 Updating package-lock.json 2025-02-11 20:56:22 +00:00
Lance Release
873ffc1042 Updating package-lock.json 2025-02-11 20:56:05 +00:00
Lance Release
83273ad997 Bump version: 0.16.1-beta.0 → 0.16.1-beta.1 2025-02-11 20:55:43 +00:00
Lance Release
d18d63c69d Bump version: 0.19.1-beta.0 → 0.19.1-beta.1 2025-02-11 20:55:23 +00:00
LuQQiu
c3e865e8d0 fix: fix index out of bound in load indices (#2108)
panicked at 'index out of bounds: the len is 24 but the index is
25':Lancedb/rust/lancedb/src/index/vector.rs:26\n

Calling load_indices() on the old manifest while using the newer manifest to
get column names could result in an index-out-of-bounds error if some columns
were removed in the new version.
This change reduces the possibility of the out-of-bounds access but does not
fully remove it.
It would be better if lance could directly provide the column name info, so
no extra calls are needed to get column names, but that would require
modifying the public APIs.
2025-02-11 12:54:11 -08:00
Weston Pace
a7755cb313 docs: standardize node example prints (#2080)
Minor cleanup to help debug future CI failures
2025-02-11 08:26:29 -08:00
BubbleCal
3490f3456f chore: upgrade lance to 0.23.1-beta.2 (#2109) 2025-02-11 23:57:56 +08:00
Lance Release
0a1d0693e1 Updating package-lock.json 2025-02-07 20:06:22 +00:00
Lance Release
fd330b4b4b Updating package-lock.json 2025-02-07 19:28:01 +00:00
Lance Release
d4e9fc08e0 Updating package-lock.json 2025-02-07 19:27:44 +00:00
Lance Release
3626f2f5e1 Bump version: 0.16.0 → 0.16.1-beta.0 2025-02-07 19:27:26 +00:00
Lance Release
e64712cfa5 Bump version: 0.19.0 → 0.19.1-beta.0 2025-02-07 19:27:07 +00:00
Wyatt Alt
3e3118f85c feat: update lance dependency to 0.23.1-beta.1 (#2102) 2025-02-07 10:56:01 -08:00
Lance Release
592598a333 Updating package-lock.json 2025-02-07 18:50:53 +00:00
Lance Release
5ad21341c9 Updating package-lock.json 2025-02-07 17:34:04 +00:00
Lance Release
6e08caa091 Updating package-lock.json 2025-02-07 17:33:48 +00:00
Lance Release
7e259d8b0f Bump version: 0.16.0-beta.0 → 0.16.0 2025-02-07 17:33:13 +00:00
Lance Release
e84f747464 Bump version: 0.15.1-beta.3 → 0.16.0-beta.0 2025-02-07 17:33:08 +00:00
100 changed files with 4024 additions and 1191 deletions

View File

@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.15.1-beta.3"
+current_version = "0.18.0-beta.0"
 parse = """(?x)
 (?P<major>0|[1-9]\\d*)\\.
 (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -33,13 +33,14 @@ jobs:
python-version: "3.12" python-version: "3.12"
- name: Install ruff - name: Install ruff
run: | run: |
pip install ruff==0.8.4 pip install ruff==0.9.9
- name: Format check - name: Format check
run: ruff format --check . run: ruff format --check .
- name: Lint - name: Lint
run: ruff check . run: ruff check .
doctest:
name: "Doctest" type-check:
name: "Type Check"
timeout-minutes: 30 timeout-minutes: 30
runs-on: "ubuntu-22.04" runs-on: "ubuntu-22.04"
defaults: defaults:
@@ -54,7 +55,36 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: "3.11" python-version: "3.12"
- name: Install protobuf compiler
run: |
sudo apt update
sudo apt install -y protobuf-compiler
pip install toml
- name: Install dependencies
run: |
python ../ci/parse_requirements.py pyproject.toml --extras dev,tests,embeddings > requirements.txt
pip install -r requirements.txt
- name: Run pyright
run: pyright
doctest:
name: "Doctest"
timeout-minutes: 30
runs-on: "ubuntu-24.04"
defaults:
run:
shell: bash
working-directory: python
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: "pip" cache: "pip"
- name: Install protobuf - name: Install protobuf
run: | run: |
@@ -75,8 +105,8 @@ jobs:
timeout-minutes: 30 timeout-minutes: 30
strategy: strategy:
matrix: matrix:
python-minor-version: ["9", "11"] python-minor-version: ["9", "12"]
runs-on: "ubuntu-22.04" runs-on: "ubuntu-24.04"
defaults: defaults:
run: run:
shell: bash shell: bash
@@ -127,7 +157,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: "3.11" python-version: "3.12"
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
with: with:
workspaces: python workspaces: python
@@ -157,7 +187,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: "3.11" python-version: "3.12"
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
with: with:
workspaces: python workspaces: python
@@ -168,7 +198,7 @@ jobs:
run: rm -rf target/wheels run: rm -rf target/wheels
pydantic1x: pydantic1x:
timeout-minutes: 30 timeout-minutes: 30
runs-on: "ubuntu-22.04" runs-on: "ubuntu-24.04"
defaults: defaults:
run: run:
shell: bash shell: bash

View File

@@ -61,7 +61,12 @@ jobs:
CXX: clang++ CXX: clang++
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
# Remote cargo.lock to force a fresh build # Building without a lock file often requires the latest Rust version since downstream
# dependencies may have updated their minimum Rust version.
- uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: "stable"
# Remove cargo.lock to force a fresh build
- name: Remove Cargo.lock - name: Remove Cargo.lock
run: rm -f Cargo.lock run: rm -f Cargo.lock
- uses: rui314/setup-mold@v1 - uses: rui314/setup-mold@v1
@@ -179,15 +184,17 @@ jobs:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies (part 1)
run: | run: |
set -e set -e
apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed apk add protobuf-dev curl clang lld llvm19 grep npm bash msitools sed
- name: Install rust
curl --proto '=https' --tlsv1.3 -sSf https://raw.githubusercontent.com/rust-lang/rustup/refs/heads/master/rustup-init.sh | sh -s -- -y uses: actions-rust-lang/setup-rust-toolchain@v1
source $HOME/.cargo/env with:
rustup target add aarch64-pc-windows-msvc target: aarch64-pc-windows-msvc
- name: Install dependencies (part 2)
run: |
set -e
mkdir -p sysroot mkdir -p sysroot
cd sysroot cd sysroot
sh ../ci/sysroot-aarch64-pc-windows-msvc.sh sh ../ci/sysroot-aarch64-pc-windows-msvc.sh
@@ -259,7 +266,7 @@ jobs:
- name: Install Rust - name: Install Rust
run: | run: |
Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe Invoke-WebRequest https://win.rustup.rs/x86_64 -OutFile rustup-init.exe
.\rustup-init.exe -y --default-host aarch64-pc-windows-msvc .\rustup-init.exe -y --default-host aarch64-pc-windows-msvc --default-toolchain 1.83.0
shell: powershell shell: powershell
- name: Add Rust to PATH - name: Add Rust to PATH
run: | run: |

View File

@@ -1,21 +1,27 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v3.2.0
     hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.8.4
+    rev: v0.9.9
     hooks:
       - id: ruff
+  # - repo: https://github.com/RobertCraigie/pyright-python
+  #   rev: v1.1.395
+  #   hooks:
+  #     - id: pyright
+  #       args: ["--project", "python"]
+  #       additional_dependencies: [pyarrow-stubs]
   - repo: local
     hooks:
       - id: local-biome-check
        name: biome check
        entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
        language: system
        types: [text]
        files: "nodejs/.*"
        exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.*

Cargo.lock (generated, 827 lines changed)

File diff suppressed because it is too large.

View File

@@ -21,33 +21,30 @@ categories = ["database-implementations"]
 rust-version = "1.78.0"
 
 [workspace.dependencies]
-lance = { "version" = "=0.23.0", "features" = [
-    "dynamodb",
-]}
-lance-io = "=0.23.0"
-lance-index = "=0.23.0"
-lance-linalg = "=0.23.0"
-lance-table = "=0.23.0"
-lance-testing = "=0.23.0"
-lance-datafusion = "=0.23.0"
-lance-encoding = "=0.23.0"
+lance = { "version" = "=0.24.1", "features" = ["dynamodb"] }
+lance-io = { version = "=0.24.1" }
+lance-index = { version = "=0.24.1" }
+lance-linalg = { version = "=0.24.1" }
+lance-table = { version = "=0.24.1" }
+lance-testing = { version = "=0.24.1" }
+lance-datafusion = { version = "=0.24.1" }
+lance-encoding = { version = "=0.24.1" }
 # Note that this one does not include pyarrow
-arrow = { version = "53.2", optional = false }
-arrow-array = "53.2"
-arrow-data = "53.2"
-arrow-ipc = "53.2"
-arrow-ord = "53.2"
-arrow-schema = "53.2"
-arrow-arith = "53.2"
-arrow-cast = "53.2"
+arrow = { version = "54.1", optional = false }
+arrow-array = "54.1"
+arrow-data = "54.1"
+arrow-ipc = "54.1"
+arrow-ord = "54.1"
+arrow-schema = "54.1"
+arrow-arith = "54.1"
+arrow-cast = "54.1"
 async-trait = "0"
-chrono = "0.4.35"
-datafusion = { version = "44.0", default-features = false }
-datafusion-catalog = "44.0"
-datafusion-common = { version = "44.0", default-features = false }
-datafusion-execution = "44.0"
-datafusion-expr = "44.0"
-datafusion-physical-plan = "44.0"
+datafusion = { version = "45.0", default-features = false }
+datafusion-catalog = "45.0"
+datafusion-common = { version = "45.0", default-features = false }
+datafusion-execution = "45.0"
+datafusion-expr = "45.0"
+datafusion-physical-plan = "45.0"
 env_logger = "0.11"
 half = { "version" = "=2.4.1", default-features = false, features = [
     "num-traits",
@@ -55,14 +52,21 @@ half = { "version" = "=2.4.1", default-features = false, features = [
 futures = "0"
 log = "0.4"
 moka = { version = "0.12", features = ["future"] }
-object_store = "0.10.2"
+object_store = "0.11.0"
 pin-project = "1.0.7"
-snafu = "0.7.4"
+snafu = "0.8"
 url = "2"
 num-traits = "0.2"
 rand = "0.8"
 regex = "1.10"
 lazy_static = "1"
+semver = "1.0.25"
+# Temporary pins to work around downstream issues
+# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
+chrono = "=0.4.39"
+# https://github.com/RustCrypto/formats/issues/1684
+base64ct = "=1.6.0"
 # Workaround for: https://github.com/eira-fransham/crunchy/issues/13
 crunchy = "=0.2.2"

ci/parse_requirements.py (new file, 41 lines)
View File

@@ -0,0 +1,41 @@
import argparse

import toml


def parse_dependencies(pyproject_path, extras=None):
    with open(pyproject_path, "r") as file:
        pyproject = toml.load(file)

    dependencies = pyproject.get("project", {}).get("dependencies", [])
    for dependency in dependencies:
        print(dependency)

    optional_dependencies = pyproject.get("project", {}).get(
        "optional-dependencies", {}
    )
    if extras:
        for extra in extras.split(","):
            for dep in optional_dependencies.get(extra, []):
                print(dep)


def main():
    parser = argparse.ArgumentParser(
        description="Generate requirements.txt from pyproject.toml"
    )
    parser.add_argument("path", type=str, help="Path to pyproject.toml")
    parser.add_argument(
        "--extras",
        type=str,
        help="Comma-separated list of extras to include",
        default="",
    )
    args = parser.parse_args()

    parse_dependencies(args.path, args.extras)


if __name__ == "__main__":
    main()

View File

@@ -4,6 +4,9 @@ repo_url: https://github.com/lancedb/lancedb
edit_uri: https://github.com/lancedb/lancedb/tree/main/docs/src edit_uri: https://github.com/lancedb/lancedb/tree/main/docs/src
repo_name: lancedb/lancedb repo_name: lancedb/lancedb
docs_dir: src docs_dir: src
watch:
- src
- ../python/python
theme: theme:
name: "material" name: "material"
@@ -63,6 +66,7 @@ plugins:
- https://arrow.apache.org/docs/objects.inv - https://arrow.apache.org/docs/objects.inv
- https://pandas.pydata.org/docs/objects.inv - https://pandas.pydata.org/docs/objects.inv
- https://lancedb.github.io/lance/objects.inv - https://lancedb.github.io/lance/objects.inv
- https://docs.pydantic.dev/latest/objects.inv
- mkdocs-jupyter - mkdocs-jupyter
- render_swagger: - render_swagger:
allow_arbitrary_locations: true allow_arbitrary_locations: true
@@ -105,8 +109,8 @@ nav:
- 📚 Concepts: - 📚 Concepts:
- Vector search: concepts/vector_search.md - Vector search: concepts/vector_search.md
- Indexing: - Indexing:
- IVFPQ: concepts/index_ivfpq.md - IVFPQ: concepts/index_ivfpq.md
- HNSW: concepts/index_hnsw.md - HNSW: concepts/index_hnsw.md
- Storage: concepts/storage.md - Storage: concepts/storage.md
- Data management: concepts/data_management.md - Data management: concepts/data_management.md
- 🔨 Guides: - 🔨 Guides:
@@ -130,8 +134,8 @@ nav:
- Adaptive RAG: rag/adaptive_rag.md - Adaptive RAG: rag/adaptive_rag.md
- SFR RAG: rag/sfr_rag.md - SFR RAG: rag/sfr_rag.md
- Advanced Techniques: - Advanced Techniques:
- HyDE: rag/advanced_techniques/hyde.md - HyDE: rag/advanced_techniques/hyde.md
- FLARE: rag/advanced_techniques/flare.md - FLARE: rag/advanced_techniques/flare.md
- Reranking: - Reranking:
- Quickstart: reranking/index.md - Quickstart: reranking/index.md
- Cohere Reranker: reranking/cohere.md - Cohere Reranker: reranking/cohere.md
@@ -146,7 +150,7 @@ nav:
- Building Custom Rerankers: reranking/custom_reranker.md - Building Custom Rerankers: reranking/custom_reranker.md
- Example: notebooks/lancedb_reranking.ipynb - Example: notebooks/lancedb_reranking.ipynb
- Filtering: sql.md - Filtering: sql.md
- Versioning & Reproducibility: - Versioning & Reproducibility:
- sync API: notebooks/reproducibility.ipynb - sync API: notebooks/reproducibility.ipynb
- async API: notebooks/reproducibility_async.ipynb - async API: notebooks/reproducibility_async.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
@@ -178,6 +182,7 @@ nav:
- Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md
- Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md
- User-defined embedding functions: embeddings/custom_embedding_function.md - User-defined embedding functions: embeddings/custom_embedding_function.md
- Variables and secrets: embeddings/variables_and_secrets.md
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
- 🔌 Integrations: - 🔌 Integrations:
@@ -240,8 +245,8 @@ nav:
- Concepts: - Concepts:
- Vector search: concepts/vector_search.md - Vector search: concepts/vector_search.md
- Indexing: - Indexing:
- IVFPQ: concepts/index_ivfpq.md - IVFPQ: concepts/index_ivfpq.md
- HNSW: concepts/index_hnsw.md - HNSW: concepts/index_hnsw.md
- Storage: concepts/storage.md - Storage: concepts/storage.md
- Data management: concepts/data_management.md - Data management: concepts/data_management.md
- Guides: - Guides:
@@ -265,8 +270,8 @@ nav:
- Adaptive RAG: rag/adaptive_rag.md - Adaptive RAG: rag/adaptive_rag.md
- SFR RAG: rag/sfr_rag.md - SFR RAG: rag/sfr_rag.md
- Advanced Techniques: - Advanced Techniques:
- HyDE: rag/advanced_techniques/hyde.md - HyDE: rag/advanced_techniques/hyde.md
- FLARE: rag/advanced_techniques/flare.md - FLARE: rag/advanced_techniques/flare.md
- Reranking: - Reranking:
- Quickstart: reranking/index.md - Quickstart: reranking/index.md
- Cohere Reranker: reranking/cohere.md - Cohere Reranker: reranking/cohere.md
@@ -280,7 +285,7 @@ nav:
- Building Custom Rerankers: reranking/custom_reranker.md - Building Custom Rerankers: reranking/custom_reranker.md
- Example: notebooks/lancedb_reranking.ipynb - Example: notebooks/lancedb_reranking.ipynb
- Filtering: sql.md - Filtering: sql.md
- Versioning & Reproducibility: - Versioning & Reproducibility:
- sync API: notebooks/reproducibility.ipynb - sync API: notebooks/reproducibility.ipynb
- async API: notebooks/reproducibility_async.ipynb - async API: notebooks/reproducibility_async.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
@@ -311,6 +316,7 @@ nav:
- Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md - Imagebind embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/imagebind_embedding.md
- Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md - Jina Embeddings: embeddings/available_embedding_models/multimodal_embedding_functions/jina_multimodal_embedding.md
- User-defined embedding functions: embeddings/custom_embedding_function.md - User-defined embedding functions: embeddings/custom_embedding_function.md
- Variables and secrets: embeddings/variables_and_secrets.md
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb - "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb - "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
- Integrations: - Integrations:
@@ -349,8 +355,8 @@ nav:
- 🦀 Rust: - 🦀 Rust:
- Overview: examples/examples_rust.md - Overview: examples/examples_rust.md
- Studies: - Studies:
- studies/overview.md - studies/overview.md
- ↗Improve retrievers with hybrid search and reranking: https://blog.lancedb.com/hybrid-search-and-reranking-report/ - ↗Improve retrievers with hybrid search and reranking: https://blog.lancedb.com/hybrid-search-and-reranking-report/
- API reference: - API reference:
- Overview: api_reference.md - Overview: api_reference.md
- Python: python/python.md - Python: python/python.md
@@ -371,6 +377,7 @@ extra_css:
extra_javascript: extra_javascript:
- "extra_js/init_ask_ai_widget.js" - "extra_js/init_ask_ai_widget.js"
- "extra_js/reo.js"
extra: extra:
analytics: analytics:

View File

@@ -3,6 +3,7 @@ import * as vectordb from "vectordb";
 // --8<-- [end:import]
 
 (async () => {
+  console.log("ann_indexes.ts: start");
 
   // --8<-- [start:ingest]
   const db = await vectordb.connect("data/sample-lancedb");
@@ -49,5 +50,5 @@ import * as vectordb from "vectordb";
     .execute();
   // --8<-- [end:search3]
 
-  console.log("Ann indexes: done");
+  console.log("ann_indexes.ts: done");
 })();

View File

@@ -107,7 +107,6 @@ const example = async () => {
   // --8<-- [start:search]
   const query = await tbl.search([100, 100]).limit(2).execute();
   // --8<-- [end:search]
-  console.log(query);
 
   // --8<-- [start:delete]
   await tbl.delete('item = "fizz"');
@@ -119,8 +118,9 @@
 };
 
 async function main() {
+  console.log("basic_legacy.ts: start");
   await example();
-  console.log("Basic example: done");
+  console.log("basic_legacy.ts: done");
 }
 
 main();

View File

@@ -55,6 +55,14 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
 This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings.
 
+!!! danger "Use sensitive keys to prevent leaking secrets"
+
+    To prevent leaking secrets, such as API keys, you should add any sensitive
+    parameters of an embedding function to the output of the
+    [sensitive_keys()][lancedb.embeddings.base.EmbeddingFunction.sensitive_keys] /
+    [getSensitiveKeys()](../../js/namespaces/embedding/classes/EmbeddingFunction/#getsensitivekeys)
+    method. This prevents users from accidentally instantiating the embedding
+    function with hard-coded secrets.
+
 Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs.
 
 === "Python"

View File

@@ -0,0 +1,53 @@
# Variables and Secrets
Most embedding configuration options are saved in the table's metadata. However,
this isn't always appropriate. For example, API keys should never be stored in the
metadata. Additionally, other configuration options might be best set at runtime,
such as the `device` configuration that controls whether to use GPU or CPU for
inference. If you hardcoded this to GPU, you wouldn't be able to run the code on
a server without one.
To handle these cases, you can set variables on the embedding registry and
reference them in the embedding configuration. These variables will be available
during the runtime of your program, but not saved in the table's metadata. When
the table is loaded from a different process, the variables must be set again.
To set a variable, use the `set_var()` / `setVar()` method on the embedding registry.
To reference a variable, use the syntax `$env:VARIABLE_NAME`. If there is a default
value, you can use the syntax `$env:VARIABLE_NAME:DEFAULT_VALUE`.
## Using variables to set secrets
Sensitive configuration, such as API keys, must either be set as environment
variables or using variables on the embedding registry. If you pass in a hardcoded
value, LanceDB will raise an error. Instead, if you want to set an API key via
configuration, use a variable:
=== "Python"
```python
--8<-- "python/python/tests/docs/test_embeddings_optional.py:register_secret"
```
=== "Typescript"
```typescript
--8<-- "nodejs/examples/embedding.test.ts:register_secret"
```
## Using variables to set the device parameter
Many embedding functions that run locally have a `device` parameter that controls
whether to use GPU or CPU for inference. Because not all computers have a GPU,
it's helpful to be able to set the `device` parameter at runtime, rather than
have it hard coded in the embedding configuration. To make it work even if the
variable isn't set, you could provide a default value of `cpu` in the embedding
configuration.
Some embedding libraries even have a method to detect which devices are available,
which could be used to dynamically set the device at runtime. For example, in Python
you can check if a CUDA GPU is available using `torch.cuda.is_available()`.
```python
--8<-- "python/python/tests/docs/test_embeddings_optional.py:register_device"
```
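
For instance, a small sketch of wiring that together with the `set_var()` API described above (assuming PyTorch is installed; the variable name is a placeholder):

```python
import torch

from lancedb.embeddings import get_registry

# Choose the device at runtime instead of hard-coding it in the embedding
# configuration stored with the table.
device = "cuda" if torch.cuda.is_available() else "cpu"
get_registry().set_var("device", device)
```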

docs/src/extra_js/reo.js (new file, 1 line)
View File

@@ -0,0 +1 @@
!function(){var e,t,n;e="9627b71b382d201",t=function(){Reo.init({clientID:"9627b71b382d201"})},(n=document.createElement("script")).src="https://static.reo.dev/"+e+"/reo.js",n.defer=!0,n.onload=t,document.head.appendChild(n)}();

View File

@@ -8,6 +8,23 @@
An embedding function that automatically creates vector representation for a given column. An embedding function that automatically creates vector representation for a given column.
It's important subclasses pass the **original** options to the super constructor
and then pass those options to `resolveVariables` to resolve any variables before
using them.
## Example
```ts
class MyEmbeddingFunction extends EmbeddingFunction {
constructor(options: {model: string, timeout: number}) {
super(optionsRaw);
const options = this.resolveVariables(optionsRaw);
this.model = options.model;
this.timeout = options.timeout;
}
}
```
## Extended by ## Extended by
- [`TextEmbeddingFunction`](TextEmbeddingFunction.md) - [`TextEmbeddingFunction`](TextEmbeddingFunction.md)
@@ -82,12 +99,33 @@ The datatype of the embeddings
*** ***
### getSensitiveKeys()
```ts
protected getSensitiveKeys(): string[]
```
Provide a list of keys in the function options that should be treated as
sensitive. If users pass raw values for these keys, they will be rejected.
#### Returns
`string`[]
***
### init()? ### init()?
```ts ```ts
optional init(): Promise<void> optional init(): Promise<void>
``` ```
Optionally load any resources needed for the embedding function.
This method is called after the embedding function has been initialized
but before any embeddings are computed. It is useful for loading local models
or other resources that are needed for the embedding function to work.
#### Returns #### Returns
`Promise`&lt;`void`&gt; `Promise`&lt;`void`&gt;
@@ -108,6 +146,24 @@ The number of dimensions of the embeddings
*** ***
### resolveVariables()
```ts
protected resolveVariables(config): Partial<M>
```
Apply variables to the config.
#### Parameters
* **config**: `Partial`&lt;`M`&gt;
#### Returns
`Partial`&lt;`M`&gt;
***
### sourceField() ### sourceField()
```ts ```ts
@@ -134,37 +190,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d
### toJSON() ### toJSON()
```ts ```ts
abstract toJSON(): Partial<M> toJSON(): Record<string, any>
``` ```
Convert the embedding function to a JSON object Get the original arguments to the constructor, to serialize them so they
It is used to serialize the embedding function to the schema can be used to recreate the embedding function later.
It's important that any object returned by this method contains all the necessary
information to recreate the embedding function
It should return the same object that was passed to the constructor
If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
#### Returns #### Returns
`Partial`&lt;`M`&gt; `Record`&lt;`string`, `any`&gt;
#### Example
```ts
class MyEmbeddingFunction extends EmbeddingFunction {
constructor(options: {model: string, timeout: number}) {
super();
this.model = options.model;
this.timeout = options.timeout;
}
toJSON() {
return {
model: this.model,
timeout: this.timeout,
};
}
```
*** ***

View File

@@ -80,6 +80,28 @@ getTableMetadata(functions): Map<string, string>
*** ***
### getVar()
```ts
getVar(name): undefined | string
```
Get a variable.
#### Parameters
* **name**: `string`
#### Returns
`undefined` \| `string`
#### See
[setVar](EmbeddingFunctionRegistry.md#setvar)
***
### length() ### length()
```ts ```ts
@@ -145,3 +167,31 @@ reset the registry to the initial state
#### Returns #### Returns
`void` `void`
***
### setVar()
```ts
setVar(name, value): void
```
Set a variable. These can be accessed in the embedding function
configuration using the syntax `$var:variable_name`. If they are not
set, an error will be thrown letting you know which key is unset. If you
want to supply a default value, you can add an additional part in the
configuration like so: `$var:variable_name:default_value`. Default values
can be used for runtime configurations that are not sensitive, such as
whether to use a GPU for inference.
The name must not contain colons. The default value can contain colons.
#### Parameters
* **name**: `string`
* **value**: `string`
#### Returns
`void`

View File

@@ -114,12 +114,37 @@ abstract generateEmbeddings(texts, ...args): Promise<number[][] | Float32Array[]
*** ***
### getSensitiveKeys()
```ts
protected getSensitiveKeys(): string[]
```
Provide a list of keys in the function options that should be treated as
sensitive. If users pass raw values for these keys, they will be rejected.
#### Returns
`string`[]
#### Inherited from
[`EmbeddingFunction`](EmbeddingFunction.md).[`getSensitiveKeys`](EmbeddingFunction.md#getsensitivekeys)
***
### init()? ### init()?
```ts ```ts
optional init(): Promise<void> optional init(): Promise<void>
``` ```
Optionally load any resources needed for the embedding function.
This method is called after the embedding function has been initialized
but before any embeddings are computed. It is useful for loading local models
or other resources that are needed for the embedding function to work.
#### Returns #### Returns
`Promise`&lt;`void`&gt; `Promise`&lt;`void`&gt;
@@ -148,6 +173,28 @@ The number of dimensions of the embeddings
*** ***
### resolveVariables()
```ts
protected resolveVariables(config): Partial<M>
```
Apply variables to the config.
#### Parameters
* **config**: `Partial`&lt;`M`&gt;
#### Returns
`Partial`&lt;`M`&gt;
#### Inherited from
[`EmbeddingFunction`](EmbeddingFunction.md).[`resolveVariables`](EmbeddingFunction.md#resolvevariables)
***
### sourceField() ### sourceField()
```ts ```ts
@@ -173,37 +220,15 @@ sourceField is used in combination with `LanceSchema` to provide a declarative d
### toJSON() ### toJSON()
```ts ```ts
abstract toJSON(): Partial<M> toJSON(): Record<string, any>
``` ```
Convert the embedding function to a JSON object Get the original arguments to the constructor, to serialize them so they
It is used to serialize the embedding function to the schema can be used to recreate the embedding function later.
It's important that any object returned by this method contains all the necessary
information to recreate the embedding function
It should return the same object that was passed to the constructor
If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
#### Returns #### Returns
`Partial`&lt;`M`&gt; `Record`&lt;`string`, `any`&gt;
#### Example
```ts
class MyEmbeddingFunction extends EmbeddingFunction {
constructor(options: {model: string, timeout: number}) {
super();
this.model = options.model;
this.timeout = options.timeout;
}
toJSON() {
return {
model: this.model,
timeout: this.timeout,
};
}
```
#### Inherited from #### Inherited from

View File

@@ -9,23 +9,50 @@ LanceDB supports [Polars](https://github.com/pola-rs/polars), a blazingly fast D
First, we connect to a LanceDB database. First, we connect to a LanceDB database.
=== "Sync API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
```
=== "Async API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb_async"
```
```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
```
We can load a Polars `DataFrame` to LanceDB directly. We can load a Polars `DataFrame` to LanceDB directly.
```py === "Sync API"
--8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars" ```py
``` --8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars"
```
=== "Async API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars_async"
```
We can now perform similarity search via the LanceDB Python API. We can now perform similarity search via the LanceDB Python API.
```py === "Sync API"
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
``` ```py
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
```
=== "Async API"
```py
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars_async"
```
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
and also the `_distance` column which is the distance between the query and also the `_distance` column which is the distance between the query
@@ -112,4 +139,3 @@ The reason it's beneficial to not convert the LanceDB Table
to a DataFrame is because the table can potentially be way larger to a DataFrame is because the table can potentially be way larger
than memory, and Polars LazyFrames allow us to work with such than memory, and Polars LazyFrames allow us to work with such
larger-than-memory datasets by not loading it into memory all at once. larger-than-memory datasets by not loading it into memory all at once.

View File

@@ -2,14 +2,19 @@
[Pydantic](https://docs.pydantic.dev/latest/) is a data validation library in Python. [Pydantic](https://docs.pydantic.dev/latest/) is a data validation library in Python.
LanceDB integrates with Pydantic for schema inference, data ingestion, and query result casting. LanceDB integrates with Pydantic for schema inference, data ingestion, and query result casting.
Using [LanceModel][lancedb.pydantic.LanceModel], users can seamlessly
integrate Pydantic with the rest of the LanceDB APIs.
## Schema ```python
LanceDB supports to create Apache Arrow Schema from a --8<-- "python/python/tests/docs/test_pydantic_integration.py:imports"
[Pydantic BaseModel](https://docs.pydantic.dev/latest/api/main/#pydantic.main.BaseModel)
via [pydantic_to_schema()](python.md#lancedb.pydantic.pydantic_to_schema) method. --8<-- "python/python/tests/docs/test_pydantic_integration.py:base_model"
--8<-- "python/python/tests/docs/test_pydantic_integration.py:set_url"
--8<-- "python/python/tests/docs/test_pydantic_integration.py:base_example"
```
::: lancedb.pydantic.pydantic_to_schema
## Vector Field ## Vector Field
@@ -34,3 +39,9 @@ Current supported type conversions:
| `list` | `pyarrow.List` | | `list` | `pyarrow.List` |
| `BaseModel` | `pyarrow.Struct` | | `BaseModel` | `pyarrow.Struct` |
| `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` | | `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
LanceDB supports to create Apache Arrow Schema from a
[Pydantic BaseModel][pydantic.BaseModel]
via [pydantic_to_schema()](python.md#lancedb.pydantic.pydantic_to_schema) method.
::: lancedb.pydantic.pydantic_to_schema

View File

@@ -20,6 +20,7 @@ async function setup() {
 }
 
 async () => {
+  console.log("search_legacy.ts: start");
   await setup();
 
   // --8<-- [start:search1]
@@ -37,5 +38,5 @@ async () => {
     .execute();
   // --8<-- [end:search2]
 
-  console.log("search: done");
+  console.log("search_legacy.ts: done");
 };

View File

@@ -1,6 +1,7 @@
 import * as vectordb from "vectordb";
 
 (async () => {
+  console.log("sql_legacy.ts: start");
   const db = await vectordb.connect("data/sample-lancedb");
 
   let data = [];
@@ -34,5 +35,5 @@ import * as vectordb from "vectordb";
   await tbl.filter("id = 10").limit(10).execute();
   // --8<-- [end:sql_search]
 
-  console.log("SQL search: done");
+  console.log("sql_legacy.ts: done");
 })();

View File

@@ -15,6 +15,7 @@ excluded_globs = [
     "../src/python/duckdb.md",
     "../src/python/pandas_and_pyarrow.md",
     "../src/python/polars_arrow.md",
+    "../src/python/pydantic.md",
     "../src/embeddings/*.md",
     "../src/concepts/*.md",
     "../src/ann_indexes.md",

View File

@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.15.1-beta.3</version>
+    <version>0.18.0-beta.0</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

View File

@@ -6,7 +6,7 @@
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.15.1-beta.3</version>
+  <version>0.18.0-beta.0</version>
   <packaging>pom</packaging>
 
   <name>LanceDB Parent</name>

node/package-lock.json (generated, 68 lines changed)
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -52,14 +52,14 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.15.1-beta.3", "@lancedb/vectordb-darwin-arm64": "0.18.0-beta.0",
"@lancedb/vectordb-darwin-x64": "0.15.1-beta.3", "@lancedb/vectordb-darwin-x64": "0.18.0-beta.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.15.1-beta.3", "@lancedb/vectordb-linux-arm64-gnu": "0.18.0-beta.0",
"@lancedb/vectordb-linux-arm64-musl": "0.15.1-beta.3", "@lancedb/vectordb-linux-arm64-musl": "0.18.0-beta.0",
"@lancedb/vectordb-linux-x64-gnu": "0.15.1-beta.3", "@lancedb/vectordb-linux-x64-gnu": "0.18.0-beta.0",
"@lancedb/vectordb-linux-x64-musl": "0.15.1-beta.3", "@lancedb/vectordb-linux-x64-musl": "0.18.0-beta.0",
"@lancedb/vectordb-win32-arm64-msvc": "0.15.1-beta.3", "@lancedb/vectordb-win32-arm64-msvc": "0.18.0-beta.0",
"@lancedb/vectordb-win32-x64-msvc": "0.15.1-beta.3" "@lancedb/vectordb-win32-x64-msvc": "0.18.0-beta.0"
}, },
"peerDependencies": { "peerDependencies": {
"@apache-arrow/ts": "^14.0.2", "@apache-arrow/ts": "^14.0.2",
@@ -330,9 +330,9 @@
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": { "node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.18.0-beta.0.tgz",
"integrity": "sha512-2GinbODdSsUc+zJQ4BFZPsdraPWHJpDpGf7CsZIqfokwxIRnzVzFfQy+SZhmNhKzFkmtW21yWw6wrJ4FgS7Qtw==", "integrity": "sha512-dLLgMPllYJOiRfPqkqkmoQu48RIa7K4dOF/qFP8Aex3zqeHE/0sFm3DYjtSFc6SR/6yT8u6Y9iFo2cQp5rCFJA==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -343,9 +343,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-darwin-x64": { "node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.18.0-beta.0.tgz",
"integrity": "sha512-nRp5eN6yvx5kvfDEQuh3EHCmwjVNCIm7dXoV6BasepFkOoaHHmjKSIUFW7HjtJOfdFbb+r8UjBJx4cN6Jh2iFg==", "integrity": "sha512-la0eauU0rzHO5eeVjBt8o/5UW4VzRYAuRA7nqUFLX5T6SWP5+UWjqusVVbWGz3ski+8uEX6VhlaFZP5uIJKGIg==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -356,9 +356,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-gnu": { "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.18.0-beta.0.tgz",
"integrity": "sha512-JOyD7Nt3RSfHGWNQjHbZMHsIw1cVWPySxbtDmDqk5QH5IfgDNZLiz/sNbROuQkNvc5SsC6wUmhBUwWBETzW7/g==", "integrity": "sha512-AkXI/lB3yu1Di2G1lhilf89V6qPTppb13aAt+/6gU5/PSfA94y9VXD67D4WyvRbuQghJjDvAavMlWMrJc2NuMw==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -369,9 +369,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-musl": { "node_modules/@lancedb/vectordb-linux-arm64-musl": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.18.0-beta.0.tgz",
"integrity": "sha512-4jTHl1i/4e7wP2U7RMjHr87/gsGJ9tfRJ4ljQIfV+LkA7ROMd/TA5XSnvPesQCDjPNRI4wAyb/BmK18V96VqBg==", "integrity": "sha512-kTVcJ4LA8w/7egY4m0EXOt8c1DeFUquVtyvexO+VzIFeeHfBkkrMI0DkE0CpHmk+gctkG7EY39jzjgLnPvppnw==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -382,9 +382,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-gnu": { "node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.18.0-beta.0.tgz",
"integrity": "sha512-odrNqB/bGL+sweZi6ed9sKft/H5/bca/tDVG/Y39xCJ6swPWxXQK2Zpn7EjqbccI2p2zkrhKcOUBO/bEkOqQng==", "integrity": "sha512-KbtIy5DkaWTsKENm5Q27hjovrR7FRuoHhl0wDJtO/2CUZYlrskjEIfcfkfA2CrEQesBug4s5jgsvNM4Wcp6zoA==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -395,9 +395,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-musl": { "node_modules/@lancedb/vectordb-linux-x64-musl": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.18.0-beta.0.tgz",
"integrity": "sha512-Zml4KgQWzkkMBHZiD30Gs3N56BT5xO01efwO/Q2qB7JKw5Vy9pa6SgFf9woBvKFQRY73fiKqafy+BmGHTgozNg==", "integrity": "sha512-SF07gmoGVExcF5v+IE6kBbCbXJSDyTgC7QCt+MDS1NsgoQ9OH7IyH7r6HJu16tKflUOUKlUHnP0hQOPpv1fWpg==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -408,9 +408,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-win32-arm64-msvc": { "node_modules/@lancedb/vectordb-win32-arm64-msvc": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-arm64-msvc/-/vectordb-win32-arm64-msvc-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-arm64-msvc/-/vectordb-win32-arm64-msvc-0.18.0-beta.0.tgz",
"integrity": "sha512-3BWkK+8JP+js/KoTad7bm26NTR5pq2tvXJkrFB0eaFfsIuUXebS+LIBF22f39He2WMpq3YojT0bMnYxp8qvRkQ==", "integrity": "sha512-YYBuSBGDlxJgSI5gHjDmQo9sl05lAXfzil6QiKfgmUMsBtb2sT+GoUCgG6qzsfe99sWiTf+pMeWDsQgfrj9vNw==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -421,9 +421,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-win32-x64-msvc": { "node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.15.1-beta.3.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.18.0-beta.0.tgz",
"integrity": "sha512-jr8SEisYAX7pQHIbxIDJPkANmxWh5Yohm8ELbMgu76IvLI7bsS7sB9ID+kcj1SiS5m4V6OG2BO1FrEYbPLZ6Dg==", "integrity": "sha512-t9TXeUnMU7YbP+/nUJpStm75aWwUydZj2AK+G2XwDtQrQo4Xg7/NETEbBeogmIOHuidNQYia8jEeQCUon5/+Dw==",
"cpu": [ "cpu": [
"x64" "x64"
], ],

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"private": false, "private": false,
"main": "dist/index.js", "main": "dist/index.js",
@@ -92,13 +92,13 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-x64": "0.15.1-beta.3", "@lancedb/vectordb-darwin-x64": "0.18.0-beta.0",
"@lancedb/vectordb-darwin-arm64": "0.15.1-beta.3", "@lancedb/vectordb-darwin-arm64": "0.18.0-beta.0",
"@lancedb/vectordb-linux-x64-gnu": "0.15.1-beta.3", "@lancedb/vectordb-linux-x64-gnu": "0.18.0-beta.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.15.1-beta.3", "@lancedb/vectordb-linux-arm64-gnu": "0.18.0-beta.0",
"@lancedb/vectordb-linux-x64-musl": "0.15.1-beta.3", "@lancedb/vectordb-linux-x64-musl": "0.18.0-beta.0",
"@lancedb/vectordb-linux-arm64-musl": "0.15.1-beta.3", "@lancedb/vectordb-linux-arm64-musl": "0.18.0-beta.0",
"@lancedb/vectordb-win32-x64-msvc": "0.15.1-beta.3", "@lancedb/vectordb-win32-x64-msvc": "0.18.0-beta.0",
"@lancedb/vectordb-win32-arm64-msvc": "0.15.1-beta.3" "@lancedb/vectordb-win32-arm64-msvc": "0.18.0-beta.0"
} }
} }

View File

@@ -1,7 +1,7 @@
[package] [package]
name = "lancedb-nodejs" name = "lancedb-nodejs"
edition.workspace = true edition.workspace = true
version = "0.15.1-beta.3" version = "0.18.0-beta.0"
license.workspace = true license.workspace = true
description.workspace = true description.workspace = true
repository.workspace = true repository.workspace = true

View File

@@ -17,6 +17,8 @@ import {
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry"; import { getRegistry, register } from "../lancedb/embedding/registry";
const testOpenAIInteg = process.env.OPENAI_API_KEY == null ? test.skip : test;
describe("embedding functions", () => { describe("embedding functions", () => {
let tmpDir: tmp.DirResult; let tmpDir: tmp.DirResult;
beforeEach(() => { beforeEach(() => {
@@ -29,9 +31,6 @@ describe("embedding functions", () => {
it("should be able to create a table with an embedding function", async () => { it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 3; return 3;
} }
@@ -75,9 +74,6 @@ describe("embedding functions", () => {
it("should be able to append and upsert using embedding function", async () => { it("should be able to append and upsert using embedding function", async () => {
@register() @register()
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 3; return 3;
} }
@@ -143,9 +139,6 @@ describe("embedding functions", () => {
it("should be able to create an empty table with an embedding function", async () => { it("should be able to create an empty table with an embedding function", async () => {
@register() @register()
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 3; return 3;
} }
@@ -194,9 +187,6 @@ describe("embedding functions", () => {
it("should error when appending to a table with an unregistered embedding function", async () => { it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock") @register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 3; return 3;
} }
@@ -241,13 +231,35 @@ describe("embedding functions", () => {
`Function "mock" not found in registry`, `Function "mock" not found in registry`,
); );
}); });
testOpenAIInteg("propagates variables through all methods", async () => {
delete process.env.OPENAI_API_KEY;
const registry = getRegistry();
registry.setVar("openai_api_key", "sk-...");
const func = registry.get("openai")?.create({
model: "text-embedding-ada-002",
apiKey: "$var:openai_api_key",
}) as EmbeddingFunction;
const db = await connect("memory://");
const wordsSchema = LanceSchema({
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const tbl = await db.createEmptyTable("words", wordsSchema, {
mode: "overwrite",
});
await tbl.add([{ text: "hello world" }, { text: "goodbye world" }]);
const query = "greetings";
const actual = (await tbl.search(query).limit(1).toArray())[0];
expect(actual).toHaveProperty("text");
});
test.each([new Float16(), new Float32(), new Float64()])( test.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide manual embeddings with multiple float datatype", "should be able to provide manual embeddings with multiple float datatype",
async (floatType) => { async (floatType) => {
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 3; return 3;
} }
@@ -292,10 +304,6 @@ describe("embedding functions", () => {
async (floatType) => { async (floatType) => {
@register("test1") @register("test1")
class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> { class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
embeddingDataType(): Float { embeddingDataType(): Float {
return floatType; return floatType;
} }
@@ -310,9 +318,6 @@ describe("embedding functions", () => {
} }
@register("test") @register("test")
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 3; return 3;
} }

View File

@@ -11,7 +11,11 @@ import * as arrow18 from "apache-arrow-18";
import * as tmp from "tmp"; import * as tmp from "tmp";
import { connect } from "../lancedb"; import { connect } from "../lancedb";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import {
EmbeddingFunction,
FunctionOptions,
LanceSchema,
} from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry"; import { getRegistry, register } from "../lancedb/embedding/registry";
describe.each([arrow15, arrow16, arrow17, arrow18])("LanceSchema", (arrow) => { describe.each([arrow15, arrow16, arrow17, arrow18])("LanceSchema", (arrow) => {
@@ -39,11 +43,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
it("should register a new item to the registry", async () => { it("should register a new item to the registry", async () => {
@register("mock-embedding") @register("mock-embedding")
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() { constructor() {
super(); super();
} }
@@ -89,11 +88,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
}); });
test("should error if registering with the same name", async () => { test("should error if registering with the same name", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() { constructor() {
super(); super();
} }
@@ -114,13 +108,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
}); });
test("schema should contain correct metadata", async () => { test("schema should contain correct metadata", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object { constructor(args: FunctionOptions = {}) {
return {
someText: "hello",
};
}
constructor() {
super(); super();
this.resolveVariables(args);
} }
ndims() { ndims() {
return 3; return 3;
@@ -132,7 +122,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
return data.map(() => [1, 2, 3]); return data.map(() => [1, 2, 3]);
} }
} }
const func = new MockEmbeddingFunction(); const func = new MockEmbeddingFunction({ someText: "hello" });
const schema = LanceSchema({ const schema = LanceSchema({
id: new arrow.Int32(), id: new arrow.Int32(),
@@ -155,3 +145,79 @@ describe.each([arrow15, arrow16, arrow17, arrow18])("Registry", (arrow) => {
expect(schema.metadata).toEqual(expectedMetadata); expect(schema.metadata).toEqual(expectedMetadata);
}); });
}); });
describe("Registry.setVar", () => {
const registry = getRegistry();
beforeEach(() => {
@register("mock-embedding")
// biome-ignore lint/correctness/noUnusedVariables :
class MockEmbeddingFunction extends EmbeddingFunction<string> {
constructor(optionsRaw: FunctionOptions = {}) {
super();
const options = this.resolveVariables(optionsRaw);
expect(optionsRaw["someKey"].startsWith("$var:someName")).toBe(true);
expect(options["someKey"]).toBe("someValue");
if (options["secretKey"]) {
expect(optionsRaw["secretKey"]).toBe("$var:secretKey");
expect(options["secretKey"]).toBe("mySecret");
}
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
embeddingDataType() {
return new arrow18.Float32() as apiArrow.Float;
}
protected getSensitiveKeys() {
return ["secretKey"];
}
}
});
afterEach(() => {
registry.reset();
});
it("Should error if the variable is not set", () => {
console.log(registry.get("mock-embedding"));
expect(() =>
registry.get("mock-embedding")!.create({ someKey: "$var:someName" }),
).toThrow('Variable "someName" not found');
});
it("should use default values if not set", () => {
registry
.get("mock-embedding")!
.create({ someKey: "$var:someName:someValue" });
});
it("should set a variable that the embedding function understand", () => {
registry.setVar("someName", "someValue");
registry.get("mock-embedding")!.create({ someKey: "$var:someName" });
});
it("should reject secrets that aren't passed as variables", () => {
registry.setVar("someName", "someValue");
expect(() =>
registry
.get("mock-embedding")!
.create({ secretKey: "someValue", someKey: "$var:someName" }),
).toThrow(
'The key "secretKey" is sensitive and cannot be set directly. Please use the $var: syntax to set it.',
);
});
it("should not serialize secrets", () => {
registry.setVar("someName", "someValue");
registry.setVar("secretKey", "mySecret");
const func = registry
.get("mock-embedding")!
.create({ secretKey: "$var:secretKey", someKey: "$var:someName" });
expect(func.toJSON()).toEqual({
secretKey: "$var:secretKey",
someKey: "$var:someName",
});
});
});

View File

@@ -175,6 +175,8 @@ maybeDescribe("storage_options", () => {
tableNames = await db.tableNames(); tableNames = await db.tableNames();
expect(tableNames).toEqual([]); expect(tableNames).toEqual([]);
await db.dropAllTables();
}); });
it("can configure encryption at connection and table level", async () => { it("can configure encryption at connection and table level", async () => {
@@ -210,6 +212,8 @@ maybeDescribe("storage_options", () => {
await table.add([{ a: 2, b: 3 }]); await table.add([{ a: 2, b: 3 }]);
await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId); await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId);
await db.dropAllTables();
}); });
}); });
@@ -298,5 +302,32 @@ maybeDescribe("DynamoDB Lock", () => {
const rowCount = await table.countRows(); const rowCount = await table.countRows();
expect(rowCount).toBe(6); expect(rowCount).toBe(6);
await db.dropAllTables();
});
it("clears dynamodb state after dropping all tables", async () => {
const uri = `s3+ddb://${bucket.name}/test?ddbTableName=${commitTable.name}`;
const db = await connect(uri, {
storageOptions: CONFIG,
readConsistencyInterval: 0,
});
await db.createTable("foo", [{ a: 1, b: 2 }]);
await db.createTable("bar", [{ a: 1, b: 2 }]);
let tableNames = await db.tableNames();
expect(tableNames).toEqual(["bar", "foo"]);
await db.dropAllTables();
tableNames = await db.tableNames();
expect(tableNames).toEqual([]);
// We can create a new table with the same name as the one we dropped.
await db.createTable("foo", [{ a: 1, b: 2 }]);
tableNames = await db.tableNames();
expect(tableNames).toEqual(["foo"]);
await db.dropAllTables();
}); });
}); });

View File

@@ -666,11 +666,11 @@ describe("When creating an index", () => {
expect(fs.readdirSync(indexDir)).toHaveLength(1); expect(fs.readdirSync(indexDir)).toHaveLength(1);
for await (const r of tbl.query().where("id > 1").select(["id"])) { for await (const r of tbl.query().where("id > 1").select(["id"])) {
expect(r.numRows).toBe(10); expect(r.numRows).toBe(298);
} }
// should also work with 'filter' alias // should also work with 'filter' alias
for await (const r of tbl.query().filter("id > 1").select(["id"])) { for await (const r of tbl.query().filter("id > 1").select(["id"])) {
expect(r.numRows).toBe(10); expect(r.numRows).toBe(298);
} }
}); });
@@ -1038,9 +1038,6 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
test("can search using a string", async () => { test("can search using a string", async () => {
@register() @register()
class MockEmbeddingFunction extends EmbeddingFunction<string> { class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() { ndims() {
return 1; return 1;
} }

View File

@@ -43,12 +43,17 @@ test("custom embedding function", async () => {
@register("my_embedding") @register("my_embedding")
class MyEmbeddingFunction extends EmbeddingFunction<string> { class MyEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object { constructor(optionsRaw = {}) {
return {}; super();
const options = this.resolveVariables(optionsRaw);
// Initialize using options
} }
ndims() { ndims() {
return 3; return 3;
} }
protected getSensitiveKeys(): string[] {
return [];
}
embeddingDataType(): Float { embeddingDataType(): Float {
return new Float32(); return new Float32();
} }
@@ -94,3 +99,14 @@ test("custom embedding function", async () => {
expect(await table2.countRows()).toBe(2); expect(await table2.countRows()).toBe(2);
}); });
}); });
test("embedding function api_key", async () => {
// --8<-- [start:register_secret]
const registry = getRegistry();
registry.setVar("api_key", "sk-...");
const func = registry.get("openai")!.create({
apiKey: "$var:api_key",
});
// --8<-- [end:register_secret]
});

View File

@@ -15,6 +15,7 @@ import {
newVectorType, newVectorType,
} from "../arrow"; } from "../arrow";
import { sanitizeType } from "../sanitize"; import { sanitizeType } from "../sanitize";
import { getRegistry } from "./registry";
/** /**
* Options for a given embedding function * Options for a given embedding function
@@ -32,6 +33,22 @@ export interface EmbeddingFunctionConstructor<
/** /**
* An embedding function that automatically creates vector representation for a given column. * An embedding function that automatically creates vector representation for a given column.
*
* It's important that subclasses call the super constructor and then pass the
* **original** options to `resolveVariables`, which records them for serialization
* and resolves any variables before they are used.
*
* @example
* ```ts
* class MyEmbeddingFunction extends EmbeddingFunction {
* constructor(optionsRaw: {model: string, timeout: number}) {
* super();
* const options = this.resolveVariables(optionsRaw);
* this.model = options.model;
* this.timeout = options.timeout;
* }
* }
* ```
*/ */
export abstract class EmbeddingFunction< export abstract class EmbeddingFunction<
// biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
@@ -44,33 +61,74 @@ export abstract class EmbeddingFunction<
*/ */
// biome-ignore lint/style/useNamingConvention: we want to keep the name as it is // biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
readonly TOptions!: M; readonly TOptions!: M;
/**
* Convert the embedding function to a JSON object
* It is used to serialize the embedding function to the schema
* It's important that any object returned by this method contains all the necessary
* information to recreate the embedding function
*
* It should return the same object that was passed to the constructor
* If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
*
* @example
* ```ts
* class MyEmbeddingFunction extends EmbeddingFunction {
* constructor(options: {model: string, timeout: number}) {
* super();
* this.model = options.model;
* this.timeout = options.timeout;
* }
* toJSON() {
* return {
* model: this.model,
* timeout: this.timeout,
* };
* }
* ```
*/
abstract toJSON(): Partial<M>;
#config: Partial<M>;
/**
* Get the original arguments to the constructor, to serialize them so they
* can be used to recreate the embedding function later.
*/
// biome-ignore lint/suspicious/noExplicitAny :
toJSON(): Record<string, any> {
return JSON.parse(JSON.stringify(this.#config));
}
constructor() {
this.#config = {};
}
/**
* Provide a list of keys in the function options that should be treated as
* sensitive. If users pass raw values for these keys, they will be rejected.
*/
protected getSensitiveKeys(): string[] {
return [];
}
/**
* Apply variables to the config.
*/
protected resolveVariables(config: Partial<M>): Partial<M> {
this.#config = config;
const registry = getRegistry();
const newConfig = { ...config };
for (const [key_, value] of Object.entries(newConfig)) {
if (
this.getSensitiveKeys().includes(key_) &&
!value.startsWith("$var:")
) {
throw new Error(
`The key "${key_}" is sensitive and cannot be set directly. Please use the $var: syntax to set it.`,
);
}
// Makes TS happy (https://stackoverflow.com/a/78391854)
const key = key_ as keyof M;
if (typeof value === "string" && value.startsWith("$var:")) {
const [name, defaultValue] = value.slice(5).split(":", 2);
const variableValue = registry.getVar(name);
if (!variableValue) {
if (defaultValue) {
// biome-ignore lint/suspicious/noExplicitAny:
newConfig[key] = defaultValue as any;
} else {
throw new Error(`Variable "${name}" not found`);
}
} else {
// biome-ignore lint/suspicious/noExplicitAny:
newConfig[key] = variableValue as any;
}
}
}
return newConfig;
}
/**
* Optionally load any resources needed for the embedding function.
*
* This method is called after the embedding function has been initialized
* but before any embeddings are computed. It is useful for loading local models
* or other resources that are needed for the embedding function to work.
*/
async init?(): Promise<void>; async init?(): Promise<void>;
/** /**
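For orientation, here is a minimal sketch of the subclass pattern the comments above describe, modeled on the mock functions in the test files earlier in this diff. The class name, the registry alias, the `model` option, and its default value are hypothetical, and the relative import paths mirror the test files rather than a published entry point.

```ts
import { Float32 } from "apache-arrow";
import { EmbeddingFunction, FunctionOptions } from "../lancedb/embedding";
import { register } from "../lancedb/embedding/registry";

@register("my-embedding")
class MyEmbeddingFunction extends EmbeddingFunction<string> {
  readonly model: string;

  constructor(optionsRaw: FunctionOptions = {}) {
    super();
    // Hand the *original* options to resolveVariables so the inherited
    // toJSON() serializes the unresolved "$var:" references, then configure
    // the function from the resolved values.
    const options = this.resolveVariables(optionsRaw);
    this.model = options["model"] ?? "my-default-model"; // hypothetical option
  }

  // Raw values for these keys are rejected; callers must use "$var:" syntax.
  protected getSensitiveKeys(): string[] {
    return ["apiKey"];
  }

  ndims() {
    return 3;
  }

  embeddingDataType() {
    return new Float32();
  }

  async computeSourceEmbeddings(data: string[]) {
    // A real implementation would call out to this.model here.
    return data.map(() => [1, 2, 3]);
  }
}
```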

View File

@@ -21,11 +21,13 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
#modelName: OpenAIOptions["model"]; #modelName: OpenAIOptions["model"];
constructor( constructor(
options: Partial<OpenAIOptions> = { optionsRaw: Partial<OpenAIOptions> = {
model: "text-embedding-ada-002", model: "text-embedding-ada-002",
}, },
) { ) {
super(); super();
const options = this.resolveVariables(optionsRaw);
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY; const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
if (!openAIKey) { if (!openAIKey) {
throw new Error("OpenAI API key is required"); throw new Error("OpenAI API key is required");
@@ -52,10 +54,8 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
this.#modelName = modelName; this.#modelName = modelName;
} }
toJSON() { protected getSensitiveKeys(): string[] {
return { return ["apiKey"];
model: this.#modelName,
};
} }
ndims(): number { ndims(): number {

View File

@@ -23,6 +23,7 @@ export interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
*/ */
export class EmbeddingFunctionRegistry { export class EmbeddingFunctionRegistry {
#functions = new Map<string, EmbeddingFunctionConstructor>(); #functions = new Map<string, EmbeddingFunctionConstructor>();
#variables = new Map<string, string>();
/** /**
* Get the number of registered functions * Get the number of registered functions
@@ -82,10 +83,7 @@ export class EmbeddingFunctionRegistry {
}; };
} else { } else {
// biome-ignore lint/suspicious/noExplicitAny: <explanation> // biome-ignore lint/suspicious/noExplicitAny: <explanation>
create = function (options?: any) { create = (options?: any) => new factory(options);
const instance = new factory(options);
return instance;
};
} }
return { return {
@@ -164,6 +162,37 @@ export class EmbeddingFunctionRegistry {
return metadata; return metadata;
} }
/**
* Set a variable. These can be accessed in the embedding function
* configuration using the syntax `$var:variable_name`. If they are not
* set, an error will be thrown letting you know which key is unset. If you
* want to supply a default value, you can add an additional part in the
* configuration like so: `$var:variable_name:default_value`. Default values
* can be used for runtime configurations that are not sensitive, such as
* whether to use a GPU for inference.
*
* The name must not contain colons. The default value can contain colons.
*
* @param name
* @param value
*/
setVar(name: string, value: string): void {
if (name.includes(":")) {
throw new Error("Variable names cannot contain colons");
}
this.#variables.set(name, value);
}
/**
* Get a variable.
* @param name
* @returns
* @see {@link setVar}
*/
getVar(name: string): string | undefined {
return this.#variables.get(name);
}
} }
const _REGISTRY = new EmbeddingFunctionRegistry(); const _REGISTRY = new EmbeddingFunctionRegistry();
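To make the variable flow concrete, here is a short usage sketch that assumes a function registered as "my-embedding" (as in the sketch further up) whose `apiKey` option is sensitive; the variable names, the `device` option, and the placeholder values are illustrative only.

```ts
import { getRegistry } from "../lancedb/embedding/registry";

const registry = getRegistry();

// Variables live in the registry for the current process; set them before
// creating the function.
registry.setVar("api_key", "<your key here>");

const func = registry.get("my-embedding")?.create({
  // Sensitive keys must be supplied through a variable reference.
  apiKey: "$var:api_key",
  // Non-sensitive options may carry a default after a second colon; it is
  // used when the variable has not been set.
  device: "$var:device:cpu",
});

// Serialization keeps the unresolved references, so secrets never end up in
// table metadata.
console.log(func?.toJSON()); // { apiKey: "$var:api_key", device: "$var:device:cpu" }
```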

View File

@@ -44,11 +44,12 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
#ndims?: number; #ndims?: number;
constructor( constructor(
options: Partial<XenovaTransformerOptions> = { optionsRaw: Partial<XenovaTransformerOptions> = {
model: "Xenova/all-MiniLM-L6-v2", model: "Xenova/all-MiniLM-L6-v2",
}, },
) { ) {
super(); super();
const options = this.resolveVariables(optionsRaw);
const modelName = options?.model ?? "Xenova/all-MiniLM-L6-v2"; const modelName = options?.model ?? "Xenova/all-MiniLM-L6-v2";
this.#tokenizerOptions = { this.#tokenizerOptions = {
@@ -59,22 +60,6 @@ export class TransformersEmbeddingFunction extends EmbeddingFunction<
this.#ndims = options.ndims; this.#ndims = options.ndims;
this.#modelName = modelName; this.#modelName = modelName;
} }
toJSON() {
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
const obj: Record<string, any> = {
model: this.#modelName,
};
if (this.#ndims) {
obj["ndims"] = this.#ndims;
}
if (this.#tokenizerOptions) {
obj["tokenizerOptions"] = this.#tokenizerOptions;
}
if (this.#tokenizer) {
obj["tokenizer"] = this.#tokenizer.name;
}
return obj;
}
async init() { async init() {
let transformers; let transformers;

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-x64", "name": "@lancedb/lancedb-darwin-x64",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.darwin-x64.node", "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-gnu", "name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node", "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-musl", "name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node", "main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-gnu", "name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node", "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-musl", "name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node", "main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-arm64-msvc", "name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": [ "os": [
"win32" "win32"
], ],

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-x64-msvc", "name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"os": ["win32"], "os": ["win32"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node", "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"

View File

@@ -11,7 +11,7 @@
"ann" "ann"
], ],
"private": false, "private": false,
"version": "0.15.1-beta.3", "version": "0.18.0-beta.0",
"main": "dist/index.js", "main": "dist/index.js",
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",

56
pyright_report.csv Normal file
View File

@@ -0,0 +1,56 @@
file,errors,warnings,total_issues
python/python/lancedb/arrow.py,0,0,0
python/python/lancedb/background_loop.py,0,0,0
python/python/lancedb/embeddings/__init__.py,0,0,0
python/python/lancedb/exceptions.py,0,0,0
python/python/lancedb/index.py,0,0,0
python/python/lancedb/integrations/__init__.py,0,0,0
python/python/lancedb/remote/__init__.py,0,0,0
python/python/lancedb/remote/errors.py,0,0,0
python/python/lancedb/rerankers/__init__.py,0,0,0
python/python/lancedb/rerankers/answerdotai.py,0,0,0
python/python/lancedb/rerankers/cohere.py,0,0,0
python/python/lancedb/rerankers/colbert.py,0,0,0
python/python/lancedb/rerankers/cross_encoder.py,0,0,0
python/python/lancedb/rerankers/openai.py,0,0,0
python/python/lancedb/rerankers/util.py,0,0,0
python/python/lancedb/rerankers/voyageai.py,0,0,0
python/python/lancedb/schema.py,0,0,0
python/python/lancedb/types.py,0,0,0
python/python/lancedb/__init__.py,0,1,1
python/python/lancedb/conftest.py,1,0,1
python/python/lancedb/embeddings/bedrock.py,1,0,1
python/python/lancedb/merge.py,1,0,1
python/python/lancedb/rerankers/base.py,1,0,1
python/python/lancedb/rerankers/jinaai.py,0,1,1
python/python/lancedb/rerankers/linear_combination.py,1,0,1
python/python/lancedb/embeddings/instructor.py,2,0,2
python/python/lancedb/embeddings/openai.py,2,0,2
python/python/lancedb/embeddings/watsonx.py,2,0,2
python/python/lancedb/embeddings/registry.py,3,0,3
python/python/lancedb/embeddings/sentence_transformers.py,3,0,3
python/python/lancedb/integrations/pyarrow.py,3,0,3
python/python/lancedb/rerankers/rrf.py,3,0,3
python/python/lancedb/dependencies.py,4,0,4
python/python/lancedb/embeddings/gemini_text.py,4,0,4
python/python/lancedb/embeddings/gte.py,4,0,4
python/python/lancedb/embeddings/gte_mlx_model.py,4,0,4
python/python/lancedb/embeddings/ollama.py,4,0,4
python/python/lancedb/embeddings/transformers.py,4,0,4
python/python/lancedb/remote/db.py,5,0,5
python/python/lancedb/context.py,6,0,6
python/python/lancedb/embeddings/cohere.py,6,0,6
python/python/lancedb/fts.py,6,0,6
python/python/lancedb/db.py,9,0,9
python/python/lancedb/embeddings/utils.py,9,0,9
python/python/lancedb/common.py,11,0,11
python/python/lancedb/util.py,13,0,13
python/python/lancedb/embeddings/imagebind.py,14,0,14
python/python/lancedb/embeddings/voyageai.py,15,0,15
python/python/lancedb/embeddings/open_clip.py,16,0,16
python/python/lancedb/pydantic.py,16,0,16
python/python/lancedb/embeddings/base.py,17,0,17
python/python/lancedb/embeddings/jinaai.py,18,1,19
python/python/lancedb/remote/table.py,23,0,23
python/python/lancedb/query.py,47,1,48
python/python/lancedb/table.py,61,0,61

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.19.0" current_version = "0.21.0"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -8,9 +8,9 @@ For general contribution guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
The Python package is a wrapper around the Rust library, `lancedb`. We use The Python package is a wrapper around the Rust library, `lancedb`. We use
[pyo3](https://pyo3.rs/) to create the bindings between Rust and Python. [pyo3](https://pyo3.rs/) to create the bindings between Rust and Python.
* `src/`: Rust bindings source code - `src/`: Rust bindings source code
* `python/lancedb`: Python package source code - `python/lancedb`: Python package source code
* `python/tests`: Unit tests - `python/tests`: Unit tests
## Development environment ## Development environment
@@ -61,6 +61,12 @@ make test
make doctest make doctest
``` ```
Run type checking:
```shell
make typecheck
```
To run a single test, you can use the `pytest` command directly. Provide the path To run a single test, you can use the `pytest` command directly. Provide the path
to the test file, and optionally the test name after `::`. to the test file, and optionally the test name after `::`.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.19.0" version = "0.21.0"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true
@@ -14,21 +14,20 @@ name = "_lancedb"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
arrow = { version = "53.2", features = ["pyarrow"] } arrow = { version = "54.1", features = ["pyarrow"] }
lancedb = { path = "../rust/lancedb", default-features = false } lancedb = { path = "../rust/lancedb", default-features = false }
env_logger.workspace = true env_logger.workspace = true
pyo3 = { version = "0.22.2", features = [ pyo3 = { version = "0.23", features = ["extension-module", "abi3-py39"] }
"extension-module", pyo3-async-runtimes = { version = "0.23", features = [
"abi3-py39", "attributes",
"gil-refs" "tokio-runtime",
] } ] }
pyo3-async-runtimes = { version = "0.22", features = ["attributes", "tokio-runtime"] }
pin-project = "1.1.5" pin-project = "1.1.5"
futures.workspace = true futures.workspace = true
tokio = { version = "1.40", features = ["sync"] } tokio = { version = "1.40", features = ["sync"] }
[build-dependencies] [build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [ pyo3-build-config = { version = "0.23", features = [
"extension-module", "extension-module",
"abi3-py39", "abi3-py39",
] } ] }

View File

@@ -23,10 +23,18 @@ check: ## Check formatting and lints.
fix: ## Fix python lints fix: ## Fix python lints
ruff check python --fix ruff check python --fix
.PHONY: typecheck
typecheck: ## Run type checking with pyright.
pyright
.PHONY: doctest .PHONY: doctest
doctest: ## Run documentation tests. doctest: ## Run documentation tests.
pytest --doctest-modules python/lancedb pytest --doctest-modules python/lancedb
.PHONY: test .PHONY: test
test: ## Run tests. test: ## Run tests.
pytest python/tests -vv --durations=10 -m "not slow" pytest python/tests -vv --durations=10 -m "not slow and not s3_test"
.PHONY: clean
clean:
rm -rf data

View File

@@ -4,8 +4,8 @@ name = "lancedb"
dynamic = ["version"] dynamic = ["version"]
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.23.0",
"tqdm>=4.27.0", "tqdm>=4.27.0",
"pyarrow>=14",
"pydantic>=1.10", "pydantic>=1.10",
"packaging", "packaging",
"overrides>=0.7", "overrides>=0.7",
@@ -54,8 +54,14 @@ tests = [
"polars>=0.19, <=1.3.0", "polars>=0.19, <=1.3.0",
"tantivy", "tantivy",
"pyarrow-stubs", "pyarrow-stubs",
"pylance>=0.23.2",
]
dev = [
"ruff",
"pre-commit",
"pyright",
'typing-extensions>=4.0.0; python_version < "3.11"',
] ]
dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"']
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]
embeddings = [ embeddings = [
@@ -86,7 +92,7 @@ requires = ["maturin>=1.4"]
build-backend = "maturin" build-backend = "maturin"
[tool.ruff.lint] [tool.ruff.lint]
select = ["F", "E", "W", "G", "TCH", "PERF"] select = ["F", "E", "W", "G", "PERF"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py" addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
@@ -97,5 +103,28 @@ markers = [
] ]
[tool.pyright] [tool.pyright]
include = ["python/lancedb/table.py"] include = [
"python/lancedb/index.py",
"python/lancedb/rerankers/util.py",
"python/lancedb/rerankers/__init__.py",
"python/lancedb/rerankers/voyageai.py",
"python/lancedb/rerankers/jinaai.py",
"python/lancedb/rerankers/openai.py",
"python/lancedb/rerankers/cross_encoder.py",
"python/lancedb/rerankers/colbert.py",
"python/lancedb/rerankers/answerdotai.py",
"python/lancedb/rerankers/cohere.py",
"python/lancedb/arrow.py",
"python/lancedb/__init__.py",
"python/lancedb/types.py",
"python/lancedb/integrations/__init__.py",
"python/lancedb/exceptions.py",
"python/lancedb/background_loop.py",
"python/lancedb/schema.py",
"python/lancedb/remote/__init__.py",
"python/lancedb/remote/errors.py",
"python/lancedb/embeddings/__init__.py",
"python/lancedb/_lancedb.pyi",
]
exclude = ["python/tests/"]
pythonVersion = "3.12" pythonVersion = "3.12"

View File

@@ -14,6 +14,7 @@ from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote import ClientConfig from .remote import ClientConfig
from .remote.db import RemoteDBConnection
from .schema import vector from .schema import vector
from .table import AsyncTable from .table import AsyncTable
@@ -86,8 +87,6 @@ def connect(
conn : DBConnection conn : DBConnection
A connection to a LanceDB database. A connection to a LanceDB database.
""" """
from .remote.db import RemoteDBConnection
if isinstance(uri, str) and uri.startswith("db://"): if isinstance(uri, str) and uri.startswith("db://"):
if api_key is None: if api_key is None:
api_key = os.environ.get("LANCEDB_API_KEY") api_key = os.environ.get("LANCEDB_API_KEY")

View File

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa import pyarrow as pa
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .remote import ClientConfig
class Connection(object): class Connection(object):
uri: str uri: str
@@ -71,11 +72,15 @@ async def connect(
region: Optional[str], region: Optional[str],
host_override: Optional[str], host_override: Optional[str],
read_consistency_interval: Optional[float], read_consistency_interval: Optional[float],
client_config: Optional[Union[ClientConfig, Dict[str, Any]]],
storage_options: Optional[Dict[str, str]],
) -> Connection: ... ) -> Connection: ...
class RecordBatchStream: class RecordBatchStream:
@property
def schema(self) -> pa.Schema: ... def schema(self) -> pa.Schema: ...
async def next(self) -> Optional[pa.RecordBatch]: ... def __aiter__(self) -> "RecordBatchStream": ...
async def __anext__(self) -> pa.RecordBatch: ...
class Query: class Query:
def where(self, filter: str): ... def where(self, filter: str): ...
@@ -142,6 +147,10 @@ class CompactionStats:
files_removed: int files_removed: int
files_added: int files_added: int
class CleanupStats:
bytes_removed: int
old_versions: int
class RemovalStats: class RemovalStats:
bytes_removed: int bytes_removed: int
old_versions_removed: int old_versions_removed: int

View File

@@ -2,8 +2,10 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import copy
from typing import List, Union from typing import List, Union
from lancedb.util import add_note
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
@@ -28,13 +30,67 @@ class EmbeddingFunction(BaseModel, ABC):
7 # Setting 0 disables retries. Maybe this should not be enabled by default, 7 # Setting 0 disables retries. Maybe this should not be enabled by default,
) )
_ndims: int = PrivateAttr() _ndims: int = PrivateAttr()
_original_args: dict = PrivateAttr()
@classmethod @classmethod
def create(cls, **kwargs): def create(cls, **kwargs):
""" """
Create an instance of the embedding function Create an instance of the embedding function
""" """
return cls(**kwargs) resolved_kwargs = cls.__resolveVariables(kwargs)
instance = cls(**resolved_kwargs)
instance._original_args = kwargs
return instance
@classmethod
def __resolveVariables(cls, args: dict) -> dict:
"""
Resolve variables in the args
"""
from .registry import EmbeddingFunctionRegistry
new_args = copy.deepcopy(args)
registry = EmbeddingFunctionRegistry.get_instance()
sensitive_keys = cls.sensitive_keys()
for k, v in new_args.items():
if isinstance(v, str) and not v.startswith("$var:") and k in sensitive_keys:
exc = ValueError(
f"Sensitive key '{k}' cannot be set to a hardcoded value"
)
add_note(exc, "Help: Use $var: to set sensitive keys to variables")
raise exc
if isinstance(v, str) and v.startswith("$var:"):
parts = v[5:].split(":", maxsplit=1)
if len(parts) == 1:
try:
new_args[k] = registry.get_var(parts[0])
except KeyError:
exc = ValueError(
"Variable '{}' not found in registry".format(parts[0])
)
add_note(
exc,
"Help: Variables are reset in new Python sessions. "
"Use `registry.set_var` to set variables.",
)
raise exc
else:
name, default = parts
try:
new_args[k] = registry.get_var(name)
except KeyError:
new_args[k] = default
return new_args
@staticmethod
def sensitive_keys() -> List[str]:
"""
Return a list of keys that are sensitive and should not be allowed
to be set to hardcoded values in the config. For example, API keys.
"""
return []
@abstractmethod @abstractmethod
def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]: def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
@@ -103,17 +159,11 @@ class EmbeddingFunction(BaseModel, ABC):
return texts return texts
def safe_model_dump(self): def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION if not hasattr(self, "_original_args"):
raise ValueError(
if PYDANTIC_VERSION.major < 2: "EmbeddingFunction was not created with EmbeddingFunction.create()"
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} )
return self.model_dump( return self._original_args
exclude={
field_name
for field_name in self.model_fields
if field_name.startswith("_")
}
)
@abstractmethod @abstractmethod
def ndims(self) -> int: def ndims(self) -> int:

View File

@@ -57,6 +57,10 @@ class JinaEmbeddings(EmbeddingFunction):
# TODO: fix hardcoding # TODO: fix hardcoding
return 768 return 768
@staticmethod
def sensitive_keys() -> List[str]:
return ["api_key"]
def sanitize_input( def sanitize_input(
self, inputs: Union[TEXT, IMAGES] self, inputs: Union[TEXT, IMAGES]
) -> Union[List[Any], np.ndarray]: ) -> Union[List[Any], np.ndarray]:

View File

@@ -54,6 +54,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
def ndims(self): def ndims(self):
return self._ndims return self._ndims
@staticmethod
def sensitive_keys():
return ["api_key"]
@staticmethod @staticmethod
def model_names(): def model_names():
return [ return [

View File

@@ -41,6 +41,7 @@ class EmbeddingFunctionRegistry:
def __init__(self): def __init__(self):
self._functions = {} self._functions = {}
self._variables = {}
def register(self, alias: str = None): def register(self, alias: str = None):
""" """
@@ -156,6 +157,28 @@ class EmbeddingFunctionRegistry:
metadata = json.dumps(json_data, indent=2).encode("utf-8") metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata} return {"embedding_functions": metadata}
def set_var(self, name: str, value: str) -> None:
"""
Set a variable. These can be accessed in embedding configuration using
the syntax `$var:variable_name`. If they are not set, an error will be
thrown letting you know which variable is missing. If you want to supply
a default value, you can add an additional part in the configuration
like so: `$var:variable_name:default_value`. Default values can be
used for runtime configurations that are not sensitive, such as
whether to use a GPU for inference.
The name must not contain a colon. Default values can contain colons.
"""
if ":" in name:
raise ValueError("Variable names cannot contain colons")
self._variables[name] = value
def get_var(self, name: str) -> str:
"""
Get a variable.
"""
return self._variables[name]
# Global instance # Global instance
__REGISTRY__ = EmbeddingFunctionRegistry() __REGISTRY__ = EmbeddingFunctionRegistry()

View File

@@ -40,6 +40,10 @@ class WatsonxEmbeddings(TextEmbeddingFunction):
url: Optional[str] = None url: Optional[str] = None
params: Optional[Dict] = None params: Optional[Dict] = None
@staticmethod
def sensitive_keys():
return ["api_key"]
@staticmethod @staticmethod
def model_names(): def model_names():
return [ return [

View File

@@ -199,18 +199,29 @@ else:
] ]
def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
if inspect.isclass(tp):
if issubclass(tp, pydantic.BaseModel):
# Struct
fields = _pydantic_model_to_fields(tp)
return pa.struct(fields)
if issubclass(tp, FixedSizeListMixin):
return pa.list_(tp.value_arrow_type(), tp.dim())
return _py_type_to_arrow_type(tp, field)
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType""" """Convert a Pydantic FieldInfo to Arrow DataType"""
if isinstance(field.annotation, (_GenericAlias, GenericAlias)): if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__ origin = field.annotation.__origin__
args = field.annotation.__args__ args = field.annotation.__args__
if origin is list: if origin is list:
child = args[0] child = args[0]
return pa.list_(_py_type_to_arrow_type(child, field)) return pa.list_(_py_type_to_arrow_type(child, field))
elif origin == Union: elif origin == Union:
if len(args) == 2 and args[1] is type(None): if len(args) == 2 and args[1] is type(None):
return _py_type_to_arrow_type(args[0], field) return _pydantic_type_to_arrow_type(args[0], field)
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__ args = field.annotation.__args__
if len(args) == 2: if len(args) == 2:
@@ -218,14 +229,7 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
if typ is type(None): if typ is type(None):
continue continue
return _py_type_to_arrow_type(typ, field) return _py_type_to_arrow_type(typ, field)
elif inspect.isclass(field.annotation): return _pydantic_type_to_arrow_type(field.annotation, field)
if issubclass(field.annotation, pydantic.BaseModel):
# Struct
fields = _pydantic_model_to_fields(field.annotation)
return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
return _py_type_to_arrow_type(field.annotation, field)
def is_nullable(field: FieldInfo) -> bool: def is_nullable(field: FieldInfo) -> bool:
@@ -255,7 +259,8 @@ def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:
def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema: def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
"""Convert a Pydantic model to a PyArrow Schema. """Convert a [Pydantic Model][pydantic.BaseModel] to a
[PyArrow Schema][pyarrow.Schema].
Parameters Parameters
---------- ----------
@@ -265,24 +270,25 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
Returns Returns
------- -------
pyarrow.Schema pyarrow.Schema
The Arrow Schema
Examples Examples
-------- --------
>>> from typing import List, Optional >>> from typing import List, Optional
>>> import pydantic >>> import pydantic
>>> from lancedb.pydantic import pydantic_to_schema >>> from lancedb.pydantic import pydantic_to_schema, Vector
>>> class FooModel(pydantic.BaseModel): >>> class FooModel(pydantic.BaseModel):
... id: int ... id: int
... s: str ... s: str
... vec: List[float] ... vec: Vector(1536) # fixed_size_list<item: float32>[1536]
... li: List[int] ... li: List[int]
... ...
>>> schema = pydantic_to_schema(FooModel) >>> schema = pydantic_to_schema(FooModel)
>>> assert schema == pa.schema([ >>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False), ... pa.field("id", pa.int64(), False),
... pa.field("s", pa.utf8(), False), ... pa.field("s", pa.utf8(), False),
... pa.field("vec", pa.list_(pa.float64()), False), ... pa.field("vec", pa.list_(pa.float32(), 1536)),
... pa.field("li", pa.list_(pa.int64()), False), ... pa.field("li", pa.list_(pa.int64()), False),
... ]) ... ])
""" """
@@ -304,7 +310,7 @@ class LanceModel(pydantic.BaseModel):
... vector: Vector(2) ... vector: Vector(2)
... ...
>>> db = lancedb.connect("./example") >>> db = lancedb.connect("./example")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema()) >>> table = db.create_table("test", schema=TestModel)
>>> table.add([ >>> table.add([
... TestModel(name="test", vector=[1.0, 2.0]) ... TestModel(name="test", vector=[1.0, 2.0])
... ]) ... ])

View File

@@ -110,7 +110,7 @@ class Query(pydantic.BaseModel):
full_text_query: Optional[Union[str, dict]] = None full_text_query: Optional[Union[str, dict]] = None
# top k results to return # top k results to return
k: int k: Optional[int] = None
# # metrics # # metrics
metric: str = "L2" metric: str = "L2"
@@ -257,7 +257,7 @@ class LanceQueryBuilder(ABC):
def __init__(self, table: "Table"): def __init__(self, table: "Table"):
self._table = table self._table = table
self._limit = 10 self._limit = None
self._offset = 0 self._offset = 0
self._columns = None self._columns = None
self._where = None self._where = None
@@ -370,8 +370,7 @@ class LanceQueryBuilder(ABC):
The maximum number of results to return. The maximum number of results to return.
The default query limit is 10 results. The default query limit is 10 results.
For ANN/KNN queries, you must specify a limit. For ANN/KNN queries, you must specify a limit.
Entering 0, a negative number, or None will reset For plain searches, all records are returned if the limit is not set.
the limit to the default value of 10.
*WARNING* if you have a large dataset, setting *WARNING* if you have a large dataset, setting
the limit to a large number, e.g. the table size, the limit to a large number, e.g. the table size,
can potentially result in reading a can potentially result in reading a
@@ -595,6 +594,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
fast_search: bool = False, fast_search: bool = False,
): ):
super().__init__(table) super().__init__(table)
if self._limit is None:
self._limit = 10
self._query = query self._query = query
self._distance_type = "L2" self._distance_type = "L2"
self._nprobes = 20 self._nprobes = 20
@@ -888,6 +889,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
fts_columns: Union[str, List[str]] = [], fts_columns: Union[str, List[str]] = [],
): ):
super().__init__(table) super().__init__(table)
if self._limit is None:
self._limit = 10
self._query = query self._query = query
self._phrase_query = False self._phrase_query = False
self.ordering_field_name = ordering_field_name self.ordering_field_name = ordering_field_name
@@ -1055,7 +1058,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
query = Query( query = Query(
columns=self._columns, columns=self._columns,
filter=self._where, filter=self._where,
k=self._limit or 10, k=self._limit,
with_row_id=self._with_row_id, with_row_id=self._with_row_id,
vector=[], vector=[],
# not actually respected in remote query # not actually respected in remote query

View File

@@ -9,7 +9,8 @@ from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import warnings import warnings
from lancedb import connect_async # Remove this import to fix circular dependency
# from lancedb import connect_async
from lancedb.remote import ClientConfig from lancedb.remote import ClientConfig
import pyarrow as pa import pyarrow as pa
from overrides import override from overrides import override
@@ -78,6 +79,9 @@ class RemoteDBConnection(DBConnection):
self.client_config = client_config self.client_config = client_config
# Import connect_async here to avoid circular import
from lancedb import connect_async
self._conn = LOOP.run( self._conn = LOOP.run(
connect_async( connect_async(
db_url, db_url,

View File

@@ -3,7 +3,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
import deprecation
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@@ -23,15 +25,15 @@ from typing import (
) )
from urllib.parse import urlparse from urllib.parse import urlparse
import lance from . import __version__
from lancedb.arrow import peek_reader from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP from lancedb.background_loop import LOOP
from .dependencies import _check_for_pandas from .dependencies import _check_for_hugging_face, _check_for_pandas
import pyarrow as pa import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
import numpy as np
from lance import LanceDataset from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -39,6 +41,8 @@ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel, model_to_dict
from .query import ( from .query import (
AsyncFTSQuery,
AsyncHybridQuery,
AsyncQuery, AsyncQuery,
AsyncVectorQuery, AsyncVectorQuery,
LanceEmptyQueryBuilder, LanceEmptyQueryBuilder,
@@ -62,24 +66,36 @@ from .index import lang_mapping
if TYPE_CHECKING: if TYPE_CHECKING:
from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats from ._lancedb import (
Table as LanceDBTable,
OptimizeStats,
CleanupStats,
CompactionStats,
)
from .db import LanceDBConnection from .db import LanceDBConnection
from .index import IndexConfig from .index import IndexConfig
from lance.dataset import CleanupStats, ReaderLike
import pandas import pandas
import PIL import PIL
from .types import (
QueryType,
OnBadVectorsType,
AddMode,
CreateMode,
VectorIndexType,
ScalarIndexType,
BaseTokenizerType,
DistanceType,
)
pd = safe_import_pandas() pd = safe_import_pandas()
pl = safe_import_polars() pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _into_pyarrow_reader(data) -> pa.RecordBatchReader: def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
if _check_for_hugging_face(data): from lancedb.dependencies import datasets
# Huggingface datasets
from lance.dependencies import datasets
if _check_for_hugging_face(data):
if isinstance(data, datasets.Dataset): if isinstance(data, datasets.Dataset):
schema = data.features.arrow_schema schema = data.features.arrow_schema
return pa.RecordBatchReader.from_batches(schema, data.data.to_batches()) return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
@@ -171,7 +187,7 @@ def _sanitize_data(
data: "DATA", data: "DATA",
target_schema: Optional[pa.Schema] = None, target_schema: Optional[pa.Schema] = None,
metadata: Optional[dict] = None, # embedding metadata metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error", on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
*, *,
allow_subschema: bool = False, allow_subschema: bool = False,
@@ -317,7 +333,7 @@ def sanitize_create_table(
data, data,
schema: Union[pa.Schema, LanceModel], schema: Union[pa.Schema, LanceModel],
metadata=None, metadata=None,
on_bad_vectors: str = "error", on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
): ):
if inspect.isclass(schema) and issubclass(schema, LanceModel): if inspect.isclass(schema) and issubclass(schema, LanceModel):
@@ -569,9 +585,7 @@ class Table(ABC):
accelerator: Optional[str] = None, accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
*, *,
index_type: Literal[ index_type: VectorIndexType = "IVF_PQ",
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
num_bits: int = 8, num_bits: int = 8,
max_iterations: int = 50, max_iterations: int = 50,
sample_rate: int = 256, sample_rate: int = 256,
@@ -636,7 +650,7 @@ class Table(ABC):
column: str, column: str,
*, *,
replace: bool = True, replace: bool = True,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", index_type: ScalarIndexType = "BTREE",
): ):
"""Create a scalar index on a column. """Create a scalar index on a column.
@@ -701,7 +715,7 @@ class Table(ABC):
tokenizer_name: Optional[str] = None, tokenizer_name: Optional[str] = None,
with_position: bool = True, with_position: bool = True,
# tokenizer configs: # tokenizer configs:
base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple", base_tokenizer: BaseTokenizerType = "simple",
language: str = "English", language: str = "English",
max_token_length: Optional[int] = 40, max_token_length: Optional[int] = 40,
lower_case: bool = True, lower_case: bool = True,
@@ -770,8 +784,8 @@ class Table(ABC):
def add( def add(
self, self,
data: DATA, data: DATA,
mode: str = "append", mode: AddMode = "append",
on_bad_vectors: str = "error", on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
): ):
"""Add more data to the [Table](Table). """Add more data to the [Table](Table).
@@ -953,7 +967,7 @@ class Table(ABC):
self, self,
merge: LanceMergeInsertBuilder, merge: LanceMergeInsertBuilder,
new_data: DATA, new_data: DATA,
on_bad_vectors: str, on_bad_vectors: OnBadVectorsType,
fill_value: float, fill_value: float,
): ... ): ...
@@ -1070,7 +1084,7 @@ class Table(ABC):
older_than: Optional[timedelta] = None, older_than: Optional[timedelta] = None,
*, *,
delete_unverified: bool = False, delete_unverified: bool = False,
) -> CleanupStats: ) -> "CleanupStats":
""" """
Clean up old versions of the table, freeing disk space. Clean up old versions of the table, freeing disk space.
@@ -1381,6 +1395,14 @@ class LanceTable(Table):
def to_lance(self, **kwargs) -> LanceDataset: def to_lance(self, **kwargs) -> LanceDataset:
"""Return the LanceDataset backing this table.""" """Return the LanceDataset backing this table."""
try:
import lance
except ImportError:
raise ImportError(
"The lance library is required to use this function. "
"Please install with `pip install pylance`."
)
return lance.dataset( return lance.dataset(
self._dataset_path, self._dataset_path,
version=self.version, version=self.version,
@@ -1557,10 +1579,10 @@ class LanceTable(Table):
def create_index( def create_index(
self, self,
metric="L2", metric: DistanceType = "l2",
num_partitions=None, num_partitions=None,
num_sub_vectors=None, num_sub_vectors=None,
vector_column_name=VECTOR_COLUMN_NAME, vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True, replace: bool = True,
accelerator: Optional[str] = None, accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
@@ -1646,7 +1668,7 @@ class LanceTable(Table):
column: str, column: str,
*, *,
replace: bool = True, replace: bool = True,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", index_type: ScalarIndexType = "BTREE",
): ):
if index_type == "BTREE": if index_type == "BTREE":
config = BTree() config = BTree()
@@ -1671,7 +1693,7 @@ class LanceTable(Table):
tokenizer_name: Optional[str] = None, tokenizer_name: Optional[str] = None,
with_position: bool = True, with_position: bool = True,
# tokenizer configs: # tokenizer configs:
base_tokenizer: str = "simple", base_tokenizer: BaseTokenizerType = "simple",
language: str = "English", language: str = "English",
max_token_length: Optional[int] = 40, max_token_length: Optional[int] = 40,
lower_case: bool = True, lower_case: bool = True,
@@ -1805,8 +1827,8 @@ class LanceTable(Table):
def add( def add(
self, self,
data: DATA, data: DATA,
mode: str = "append", mode: AddMode = "append",
on_bad_vectors: str = "error", on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
): ):
"""Add data to the table. """Add data to the table.
@@ -1840,7 +1862,7 @@ class LanceTable(Table):
def merge( def merge(
self, self,
other_table: Union[LanceTable, ReaderLike], other_table: Union[LanceTable, DATA],
left_on: str, left_on: str,
right_on: Optional[str] = None, right_on: Optional[str] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None,
@@ -1890,12 +1912,13 @@ class LanceTable(Table):
1 2 b e 1 2 b e
2 3 c f 2 3 c f
""" """
if isinstance(schema, LanceModel):
schema = schema.to_arrow_schema()
if isinstance(other_table, LanceTable): if isinstance(other_table, LanceTable):
other_table = other_table.to_lance() other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset): else:
other_table = other_table.to_table() other_table = _sanitize_data(
other_table,
schema,
)
self.to_lance().merge( self.to_lance().merge(
other_table, left_on=left_on, right_on=right_on, schema=schema other_table, left_on=left_on, right_on=right_on, schema=schema
) )
@@ -2043,7 +2066,7 @@ class LanceTable(Table):
query_type, query_type,
vector_column_name=vector_column_name, vector_column_name=vector_column_name,
ordering_field_name=ordering_field_name, ordering_field_name=ordering_field_name,
fts_columns=fts_columns, fts_columns=fts_columns or [],
) )
@classmethod @classmethod
@@ -2053,13 +2076,13 @@ class LanceTable(Table):
name: str, name: str,
data: Optional[DATA] = None, data: Optional[DATA] = None,
schema: Optional[pa.Schema] = None, schema: Optional[pa.Schema] = None,
mode: Literal["create", "overwrite"] = "create", mode: CreateMode = "create",
exist_ok: bool = False, exist_ok: bool = False,
on_bad_vectors: str = "error", on_bad_vectors: OnBadVectorsType = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*, *,
storage_options: Optional[Dict[str, str]] = None, storage_options: Optional[Dict[str, str | bool]] = None,
data_storage_version: Optional[str] = None, data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None, enable_v2_manifest_paths: Optional[bool] = None,
): ):
@@ -2213,17 +2236,22 @@ class LanceTable(Table):
self, self,
merge: LanceMergeInsertBuilder, merge: LanceMergeInsertBuilder,
new_data: DATA, new_data: DATA,
on_bad_vectors: str, on_bad_vectors: OnBadVectorsType,
fill_value: float, fill_value: float,
): ):
LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)) LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,
details="Use `Table.optimize` instead.",
)
def cleanup_old_versions( def cleanup_old_versions(
self, self,
older_than: Optional[timedelta] = None, older_than: Optional[timedelta] = None,
*, *,
delete_unverified: bool = False, delete_unverified: bool = False,
) -> CleanupStats: ) -> "CleanupStats":
""" """
Clean up old versions of the table, freeing disk space. Clean up old versions of the table, freeing disk space.
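A hedged migration sketch for the two deprecated helpers above: the deprecation notes point at `Table.optimize`, and the async tests later in this diff pass `cleanup_older_than` / `delete_unverified` to it, so this assumes the sync `optimize` mirrors those keyword arguments.

```python
from datetime import timedelta

# Hypothetical open table `tbl`; one call replaces cleanup_old_versions()
# and compact_files(): compact fragments and drop versions older than a week.
tbl.optimize(cleanup_older_than=timedelta(days=7))
```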
@@ -2248,6 +2276,11 @@ class LanceTable(Table):
older_than, delete_unverified=delete_unverified older_than, delete_unverified=delete_unverified
) )
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,
details="Use `Table.optimize` instead.",
)
def compact_files(self, *args, **kwargs) -> CompactionStats: def compact_files(self, *args, **kwargs) -> CompactionStats:
""" """
Run the compaction process on the table. Run the compaction process on the table.
@@ -2379,6 +2412,19 @@ class LanceTable(Table):
""" """
LOOP.run(self._table.migrate_v2_manifest_paths()) LOOP.run(self._table.migrate_v2_manifest_paths())
def replace_field_metadata(self, field_name: str, new_metadata: Dict[str, str]):
"""
Replace the metadata of a field in the schema
Parameters
----------
field_name: str
The name of the field to replace the metadata for
new_metadata: dict
The new metadata to set
"""
LOOP.run(self._table.replace_field_metadata(field_name, new_metadata))
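A short usage sketch for the new field-metadata helper; the table and field names here are illustrative, assuming `tbl` is an open `LanceTable` with a `text` column.

```python
# Hypothetical: annotate the "text" field on an existing LanceTable `tbl`.
tbl.replace_field_metadata("text", {"description": "raw document body"})
# The metadata round-trips as bytes on the Arrow schema.
print(tbl.schema.field("text").metadata)  # expect {b"description": b"raw document body"}
```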
def _handle_bad_vectors( def _handle_bad_vectors(
reader: pa.RecordBatchReader, reader: pa.RecordBatchReader,
@@ -2679,7 +2725,7 @@ class AsyncTable:
self.close() self.close()
def is_open(self) -> bool: def is_open(self) -> bool:
"""Return True if the table is closed.""" """Return True if the table is open."""
return self._inner.is_open() return self._inner.is_open()
def close(self): def close(self):
@@ -2702,6 +2748,19 @@ class AsyncTable:
""" """
return await self._inner.schema() return await self._inner.schema()
async def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
"""
Get the embedding functions for the table
Returns
-------
funcs: Dict[str, EmbeddingFunctionConfig]
A mapping of the vector column to the embedding function
or empty dict if not configured.
"""
schema = await self.schema()
return EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
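A minimal sketch of how the new accessor might be used, assuming `tbl` is an `AsyncTable` created from a schema with a configured embedding function (as in the `Words` examples elsewhere in this diff):

```python
funcs = await tbl.embedding_functions()
if "vector" in funcs:
    conf = funcs["vector"]
    # EmbeddingFunctionConfig exposes the source column and the function itself.
    print(conf.source_column, type(conf.function).__name__)
else:
    print("no embedding functions configured")
```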
async def count_rows(self, filter: Optional[str] = None) -> int: async def count_rows(self, filter: Optional[str] = None) -> int:
""" """
Count the number of rows in the table. Count the number of rows in the table.
@@ -2828,7 +2887,7 @@ class AsyncTable:
data: DATA, data: DATA,
*, *,
mode: Optional[Literal["append", "overwrite"]] = "append", mode: Optional[Literal["append", "overwrite"]] = "append",
on_bad_vectors: Optional[str] = None, on_bad_vectors: Optional[OnBadVectorsType] = None,
fill_value: Optional[float] = None, fill_value: Optional[float] = None,
): ):
"""Add more data to the [Table](Table). """Add more data to the [Table](Table).
@@ -2931,6 +2990,234 @@ class AsyncTable:
return LanceMergeInsertBuilder(self, on) return LanceMergeInsertBuilder(self, on)
@overload
async def search(
self,
query: Optional[str] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["auto"] = ...,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: ...
@overload
async def search(
self,
query: Optional[str] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["hybrid"] = ...,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncHybridQuery: ...
@overload
async def search(
self,
query: Optional[Union[VEC, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["auto"] = ...,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncVectorQuery: ...
@overload
async def search(
self,
query: Optional[str] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["fts"] = ...,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncFTSQuery: ...
@overload
async def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: Literal["vector"] = ...,
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncVectorQuery: ...
async def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: QueryType = "auto",
ordering_field_name: Optional[str] = None,
fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
and [full-text search][experimental-full-text-search].
All query options are defined in [AsyncQuery][lancedb.query.AsyncQuery].
Parameters
----------
query: list/np.ndarray/str/PIL.Image.Image, default None
The targeted vector to search for.
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, optional
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", "hybrid", or "auto"
- If "auto" then the query type is inferred from the query;
- If `query` is a list/np.ndarray then the query type is
"vector";
- If `query` is a PIL.Image.Image then a vector search is performed,
or an error is raised if no corresponding embedding function is found.
- If `query` is a string, the query type is "fts" when the table has no
embedding functions; otherwise it is "vector", or "hybrid" when the
embedding's source column also has an FTS index.
Returns
-------
AsyncVectorQuery, AsyncFTSQuery, or AsyncHybridQuery
A query builder object representing the query.
"""
def is_embedding(query):
return isinstance(query, (list, np.ndarray, pa.Array, pa.ChunkedArray))
async def get_embedding_func(
vector_column_name: Optional[str],
query_type: QueryType,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
) -> Tuple[str, EmbeddingFunctionConfig]:
schema = await self.schema()
vector_column_name = infer_vector_column_name(
schema=schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(
schema.metadata
)
func = funcs.get(vector_column_name)
if func is None:
error = ValueError(
f"Column '{vector_column_name}' has no registered "
"embedding function."
)
if len(funcs) > 0:
add_note(
error,
"Embedding functions are registered for columns: "
f"{list(funcs.keys())}",
)
else:
add_note(
error, "No embedding functions are registered for any columns."
)
raise error
return vector_column_name, func
async def make_embedding(embedding, query):
if embedding is not None:
loop = asyncio.get_running_loop()
# This function is likely to block, since it either calls an expensive
# function or makes an HTTP request to an embeddings REST API.
return (
await loop.run_in_executor(
None,
embedding.function.compute_query_embeddings_with_retry,
query,
)
)[0]
else:
return None
if query_type == "auto":
# Infer the query type.
if is_embedding(query):
vector_query = query
query_type = "vector"
elif isinstance(query, str):
try:
(
indices,
(vector_column_name, embedding_conf),
) = await asyncio.gather(
self.list_indices(),
get_embedding_func(vector_column_name, "auto", query),
)
except ValueError as e:
if "Column" in str(
e
) and "has no registered embedding function" in str(e):
# If the column has no registered embedding function,
# then it's an FTS query.
query_type = "fts"
else:
raise e
else:
if embedding_conf is not None:
vector_query = await make_embedding(embedding_conf, query)
if any(
i.columns[0] == embedding_conf.source_column
and i.index_type == "FTS"
for i in indices
):
query_type = "hybrid"
else:
query_type = "vector"
else:
query_type = "fts"
else:
# it's an image or something else embeddable.
query_type = "vector"
elif query_type == "vector":
if is_embedding(query):
vector_query = query
else:
vector_column_name, embedding_conf = await get_embedding_func(
vector_column_name, query_type, query
)
vector_query = await make_embedding(embedding_conf, query)
elif query_type == "hybrid":
if is_embedding(query):
raise ValueError("Hybrid search requires a text query")
else:
vector_column_name, embedding_conf = await get_embedding_func(
vector_column_name, query_type, query
)
vector_query = await make_embedding(embedding_conf, query)
if query_type == "vector":
builder = self.query().nearest_to(vector_query)
if vector_column_name:
builder = builder.column(vector_column_name)
return builder
elif query_type == "fts":
return self.query().nearest_to_text(query, columns=fts_columns or [])
elif query_type == "hybrid":
builder = self.query().nearest_to(vector_query)
if vector_column_name:
builder = builder.column(vector_column_name)
return builder.nearest_to_text(query, columns=fts_columns or [])
else:
raise ValueError(f"Unknown query type: '{query_type}'")
def vector_search( def vector_search(
self, self,
query_vector: Union[VEC, Tuple], query_vector: Union[VEC, Tuple],
@@ -2950,7 +3237,9 @@ class AsyncTable:
# The sync remote table calls into this method, so we need to map the # The sync remote table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only # query to the async version of the query and run that here. This is only
# used for that code path right now. # used for that code path right now.
async_query = self.query().limit(query.k) async_query = self.query()
if query.k is not None:
async_query = async_query.limit(query.k)
if query.offset > 0: if query.offset > 0:
async_query = async_query.offset(query.offset) async_query = async_query.offset(query.offset)
if query.columns: if query.columns:
@@ -2997,7 +3286,7 @@ class AsyncTable:
self, self,
merge: LanceMergeInsertBuilder, merge: LanceMergeInsertBuilder,
new_data: DATA, new_data: DATA,
on_bad_vectors: str, on_bad_vectors: OnBadVectorsType,
fill_value: float, fill_value: float,
): ):
schema = await self.schema() schema = await self.schema()
@@ -3366,6 +3655,21 @@ class AsyncTable:
""" """
await self._inner.migrate_manifest_paths_v2() await self._inner.migrate_manifest_paths_v2()
async def replace_field_metadata(
self, field_name: str, new_metadata: dict[str, str]
):
"""
Replace the metadata of a field in the schema
Parameters
----------
field_name: str
The name of the field to replace the metadata for
new_metadata: dict
The new metadata to set
"""
await self._inner.replace_field_metadata(field_name, new_metadata)
@dataclass @dataclass
class IndexStatistics: class IndexStatistics:

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import Literal
# Query type literals
QueryType = Literal["vector", "fts", "hybrid", "auto"]
# Distance type literals
DistanceType = Literal["l2", "cosine", "dot"]
DistanceTypeWithHamming = Literal["l2", "cosine", "dot", "hamming"]
# Vector handling literals
OnBadVectorsType = Literal["error", "drop", "fill", "null"]
# Mode literals
AddMode = Literal["append", "overwrite"]
CreateMode = Literal["create", "overwrite"]
# Index type literals
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
IndexType = Literal[
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
]
# Tokenizer literals
BaseTokenizerType = Literal["simple", "raw", "whitespace"]
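A small sketch of how these aliases tighten user-facing annotations, assuming the new module lands at `lancedb.types` as the `TYPE_CHECKING` import earlier in this diff suggests:

```python
from lancedb.types import AddMode, OnBadVectorsType  # assumed module path

def add_rows(tbl, data, mode: AddMode = "append",
             on_bad_vectors: OnBadVectorsType = "error") -> None:
    # A type checker now rejects e.g. mode="upsert", which a plain `str` hint allowed.
    tbl.add(data, mode=mode, on_bad_vectors=on_bad_vectors)
```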

View File

@@ -75,6 +75,6 @@ async def test_binary_vector_async():
query = np.random.randint(0, 2, size=256) query = np.random.randint(0, 2, size=256)
packed_query = np.packbits(query) packed_query = np.packbits(query)
await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow() await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
# --8<-- [end:async_binary_vector] # --8<-- [end:async_binary_vector]
await db.drop_table("my_binary_vectors") await db.drop_table("my_binary_vectors")

View File

@@ -53,13 +53,13 @@ async def test_binary_vector_async():
query = np.random.random(256) query = np.random.random(256)
# Search for the vectors within the range of [0.1, 0.5) # Search for the vectors within the range of [0.1, 0.5)
await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow() await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()
# Search for the vectors with the distance less than 0.5 # Search for the vectors with the distance less than 0.5
await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow() await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()
# Search for the vectors with the distance greater or equal to 0.1 # Search for the vectors with the distance greater or equal to 0.1
await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow() await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()
# --8<-- [end:async_distance_range] # --8<-- [end:async_distance_range]
await db.drop_table("my_table") await db.drop_table("my_table")

View File

@@ -28,3 +28,49 @@ def test_embeddings_openai():
actual = table.search(query).limit(1).to_pydantic(Words)[0] actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text) print(actual.text)
# --8<-- [end:openai_embeddings] # --8<-- [end:openai_embeddings]
@pytest.mark.slow
@pytest.mark.asyncio
async def test_embeddings_openai_async():
uri = "memory://"
# --8<-- [start:async_openai_embeddings]
db = await lancedb.connect_async(uri)
func = get_registry().get("openai").create(name="text-embedding-ada-002")
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = await db.create_table("words", schema=Words, mode="overwrite")
await table.add([{"text": "hello world"}, {"text": "goodbye world"}])
query = "greetings"
actual = (await (await table.search(query)).limit(1).to_pydantic(Words))[0]
print(actual.text)
# --8<-- [end:async_openai_embeddings]
def test_embeddings_secret():
# --8<-- [start:register_secret]
registry = get_registry()
registry.set_var("api_key", "sk-...")
func = registry.get("openai").create(api_key="$var:api_key")
# --8<-- [end:register_secret]
try:
import torch
except ImportError:
pytest.skip("torch not installed")
# --8<-- [start:register_device]
import torch
registry = get_registry()
if torch.cuda.is_available():
registry.set_var("device", "cuda")
func = registry.get("huggingface").create(device="$var:device:cpu")
# --8<-- [end:register_device]
assert func.device == ("cuda" if torch.cuda.is_available() else "cpu")

View File

@@ -72,8 +72,7 @@ async def test_ann_index_async():
# --8<-- [end:create_ann_index_async] # --8<-- [end:create_ann_index_async]
# --8<-- [start:vector_search_async] # --8<-- [start:vector_search_async]
await ( await (
async_tbl.query() (await async_tbl.search(np.random.random((32))))
.nearest_to(np.random.random((32)))
.limit(2) .limit(2)
.nprobes(20) .nprobes(20)
.refine_factor(10) .refine_factor(10)
@@ -82,18 +81,14 @@ async def test_ann_index_async():
# --8<-- [end:vector_search_async] # --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_async_with_filter] # --8<-- [start:vector_search_async_with_filter]
await ( await (
async_tbl.query() (await async_tbl.search(np.random.random((32))))
.nearest_to(np.random.random((32)))
.where("item != 'item 1141'") .where("item != 'item 1141'")
.to_pandas() .to_pandas()
) )
# --8<-- [end:vector_search_async_with_filter] # --8<-- [end:vector_search_async_with_filter]
# --8<-- [start:vector_search_async_with_select] # --8<-- [start:vector_search_async_with_select]
await ( await (
async_tbl.query() (await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
.nearest_to(np.random.random((32)))
.select(["vector"])
.to_pandas()
) )
# --8<-- [end:vector_search_async_with_select] # --8<-- [end:vector_search_async_with_select]
@@ -164,7 +159,7 @@ async def test_scalar_index_async():
{"book_id": 3, "vector": [5.0, 6]}, {"book_id": 3, "vector": [5.0, 6]},
] ]
async_tbl = await async_db.create_table("book_with_embeddings_async", data) async_tbl = await async_db.create_table("book_with_embeddings_async", data)
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas()) (await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
# --8<-- [end:vector_search_with_scalar_index_async] # --8<-- [end:vector_search_with_scalar_index_async]
# --8<-- [start:update_scalar_index_async] # --8<-- [start:update_scalar_index_async]
await async_tbl.add([{"vector": [7, 8], "book_id": 4}]) await async_tbl.add([{"vector": [7, 8], "book_id": 4}])

View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
# --8<-- [start:imports]
import lancedb
from lancedb.pydantic import Vector, LanceModel
# --8<-- [end:imports]
def test_pydantic_model(tmp_path):
# --8<-- [start:base_model]
class PersonModel(LanceModel):
name: str
age: int
vector: Vector(2)
# --8<-- [end:base_model]
# --8<-- [start:set_url]
url = "./example"
# --8<-- [end:set_url]
url = tmp_path
# --8<-- [start:base_example]
db = lancedb.connect(url)
table = db.create_table("person", schema=PersonModel)
table.add(
[
PersonModel(name="bob", age=1, vector=[1.0, 2.0]),
PersonModel(name="alice", age=2, vector=[3.0, 4.0]),
]
)
assert table.count_rows() == 2
person = table.search([0.0, 0.0]).limit(1).to_pydantic(PersonModel)
assert person[0].name == "bob"
# --8<-- [end:base_example]

View File

@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():
query_vector = [100, 100] query_vector = [100, 100]
# Pandas DataFrame # Pandas DataFrame
df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas() df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
print(df) print(df)
# --8<-- [end:vector_search_async] # --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_with_filter_async] # --8<-- [start:vector_search_with_filter_async]
# Apply the filter via LanceDB # Apply the filter via LanceDB
results = ( results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
)
assert len(results) == 1 assert len(results) == 1
assert results["item"].iloc[0] == "foo" assert results["item"].iloc[0] == "foo"
# Apply the filter via Pandas # Apply the filter via Pandas
df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas() df = results = await (await async_tbl.search([100, 100])).to_pandas()
results = df[df.price < 15] results = df[df.price < 15]
assert len(results) == 1 assert len(results) == 1
assert results["item"].iloc[0] == "foo" assert results["item"].iloc[0] == "foo"
@@ -188,3 +186,26 @@ def test_polars():
# --8<-- [start:print_table_lazyform] # --8<-- [start:print_table_lazyform]
print(ldf.first().collect()) print(ldf.first().collect())
# --8<-- [end:print_table_lazyform] # --8<-- [end:print_table_lazyform]
@pytest.mark.asyncio
async def test_polars_async():
uri = "data/sample-lancedb"
db = await lancedb.connect_async(uri)
# --8<-- [start:create_table_polars_async]
data = pl.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
table = await db.create_table("pl_table_async", data=data)
# --8<-- [end:create_table_polars_async]
# --8<-- [start:vector_search_polars_async]
query = [3.0, 4.0]
result = await (await table.search(query)).limit(1).to_polars()
print(result)
print(type(result))
# --8<-- [end:vector_search_polars_async]

View File

@@ -117,12 +117,11 @@ async def test_vector_search_async():
for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32")) for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
] ]
async_tbl = await async_db.create_table("vector_search_async", data=data) async_tbl = await async_db.create_table("vector_search_async", data=data)
(await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list()) (await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
# --8<-- [end:exhaustive_search_async] # --8<-- [end:exhaustive_search_async]
# --8<-- [start:exhaustive_search_async_cosine] # --8<-- [start:exhaustive_search_async_cosine]
( (
await async_tbl.query() await (await async_tbl.search(np.random.random((1536))))
.nearest_to(np.random.random((1536)))
.distance_type("cosine") .distance_type("cosine")
.limit(10) .limit(10)
.to_list() .to_list()
@@ -145,13 +144,13 @@ async def test_vector_search_async():
async_tbl = await async_db.create_table("documents_async", data=data) async_tbl = await async_db.create_table("documents_async", data=data)
# --8<-- [end:create_table_async_with_nested_schema] # --8<-- [end:create_table_async_with_nested_schema]
# --8<-- [start:search_result_async_as_pyarrow] # --8<-- [start:search_result_async_as_pyarrow]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow() await (await async_tbl.search(np.random.randn(1536))).to_arrow()
# --8<-- [end:search_result_async_as_pyarrow] # --8<-- [end:search_result_async_as_pyarrow]
# --8<-- [start:search_result_async_as_pandas] # --8<-- [start:search_result_async_as_pandas]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas() await (await async_tbl.search(np.random.randn(1536))).to_pandas()
# --8<-- [end:search_result_async_as_pandas] # --8<-- [end:search_result_async_as_pandas]
# --8<-- [start:search_result_async_as_list] # --8<-- [start:search_result_async_as_list]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_list() await (await async_tbl.search(np.random.randn(1536))).to_list()
# --8<-- [end:search_result_async_as_list] # --8<-- [end:search_result_async_as_list]
@@ -219,9 +218,7 @@ async def test_fts_native_async():
# async API uses our native FTS algorithm # async API uses our native FTS algorithm
await async_tbl.create_index("text", config=FTS()) await async_tbl.create_index("text", config=FTS())
await ( await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
)
# [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}] # [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
# ... # ...
# --8<-- [end:basic_fts_async] # --8<-- [end:basic_fts_async]
@@ -235,18 +232,11 @@ async def test_fts_native_async():
) )
# --8<-- [end:fts_config_folding_async] # --8<-- [end:fts_config_folding_async]
# --8<-- [start:fts_prefiltering_async] # --8<-- [start:fts_prefiltering_async]
await ( await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
async_tbl.query()
.nearest_to_text("puppy")
.limit(10)
.where("text='foo'")
.to_list()
)
# --8<-- [end:fts_prefiltering_async] # --8<-- [end:fts_prefiltering_async]
# --8<-- [start:fts_postfiltering_async] # --8<-- [start:fts_postfiltering_async]
await ( await (
async_tbl.query() (await async_tbl.search("puppy"))
.nearest_to_text("puppy")
.limit(10) .limit(10)
.where("text='foo'") .where("text='foo'")
.postfilter() .postfilter()
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
# Create a fts index before the hybrid search # Create a fts index before the hybrid search
await async_tbl.create_index("text", config=FTS()) await async_tbl.create_index("text", config=FTS())
text_query = "flower moon" text_query = "flower moon"
vector_query = embeddings.compute_query_embeddings(text_query)[0]
# hybrid search with default re-ranker # hybrid search with default re-ranker
await ( await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
async_tbl.query()
.nearest_to(vector_query)
.nearest_to_text(text_query)
.to_pandas()
)
# --8<-- [end:basic_hybrid_search_async] # --8<-- [end:basic_hybrid_search_async]
# --8<-- [start:hybrid_search_pass_vector_text_async] # --8<-- [start:hybrid_search_pass_vector_text_async]
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5] vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union import os
from typing import List, Optional, Union
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import lance import lance
@@ -56,7 +57,7 @@ def test_embedding_function(tmp_path):
conf = EmbeddingFunctionConfig( conf = EmbeddingFunctionConfig(
source_column="text", source_column="text",
vector_column="vector", vector_column="vector",
function=MockTextEmbeddingFunction(), function=MockTextEmbeddingFunction.create(),
) )
metadata = registry.get_table_metadata([conf]) metadata = registry.get_table_metadata([conf])
table = table.replace_schema_metadata(metadata) table = table.replace_schema_metadata(metadata)
@@ -80,6 +81,57 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected) assert np.allclose(actual, expected)
def test_embedding_function_variables():
@register("variable-testing")
class VariableTestingFunction(TextEmbeddingFunction):
key1: str
secret_key: Optional[str] = None
@staticmethod
def sensitive_keys():
return ["secret_key"]
def ndims(self):
pass
def generate_embeddings(self, _texts):
pass
registry = EmbeddingFunctionRegistry.get_instance()
# Should error if variable is not set
with pytest.raises(ValueError, match="Variable 'test' not found"):
registry.get("variable-testing").create(
key1="$var:test",
)
# Should use default values if not set
func = registry.get("variable-testing").create(key1="$var:test:some_value")
assert func.key1 == "some_value"
# Should set a variable that the embedding function understands
registry.set_var("test", "some_value")
func = registry.get("variable-testing").create(key1="$var:test")
assert func.key1 == "some_value"
# Should reject secrets that aren't passed in as variables
with pytest.raises(
ValueError,
match="Sensitive key 'secret_key' cannot be set to a hardcoded value",
):
registry.get("variable-testing").create(
key1="whatever", secret_key="some_value"
)
# Should not serialize secrets.
registry.set_var("secret", "secret_value")
func = registry.get("variable-testing").create(
key1="whatever", secret_key="$var:secret"
)
assert func.secret_key == "secret_value"
assert func.safe_model_dump()["secret_key"] == "$var:secret"
def test_embedding_with_bad_results(tmp_path): def test_embedding_with_bad_results(tmp_path):
@register("null-embedding") @register("null-embedding")
class NullEmbeddingFunction(TextEmbeddingFunction): class NullEmbeddingFunction(TextEmbeddingFunction):
@@ -91,9 +143,11 @@ def test_embedding_with_bad_results(tmp_path):
) -> list[Union[np.array, None]]: ) -> list[Union[np.array, None]]:
# Return None, which is bad if field is non-nullable # Return None, which is bad if field is non-nullable
a = [ a = [
np.full(self.ndims(), np.nan) (
if i % 2 == 0 np.full(self.ndims(), np.nan)
else np.random.randn(self.ndims()) if i % 2 == 0
else np.random.randn(self.ndims())
)
for i in range(len(texts)) for i in range(len(texts))
] ]
return a return a
@@ -341,6 +395,7 @@ def test_add_optional_vector(tmp_path):
assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all() assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()
@pytest.mark.slow
@pytest.mark.parametrize( @pytest.mark.parametrize(
"embedding_type", "embedding_type",
[ [
@@ -358,23 +413,23 @@ def test_embedding_function_safe_model_dump(embedding_type):
# Note: Some embedding types might require specific parameters # Note: Some embedding types might require specific parameters
try: try:
model = registry.get(embedding_type).create() model = registry.get(embedding_type).create({"max_retries": 1})
except Exception as e: except Exception as e:
pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}") pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}")
dumped_model = model.safe_model_dump() dumped_model = model.safe_model_dump()
assert all( assert all(not k.startswith("_") for k in dumped_model.keys()), (
not k.startswith("_") for k in dumped_model.keys() f"{embedding_type}: Dumped model contains keys starting with underscore"
), f"{embedding_type}: Dumped model contains keys starting with underscore" )
assert ( assert "max_retries" in dumped_model, (
"max_retries" in dumped_model f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model" )
assert isinstance( assert isinstance(dumped_model, dict), (
dumped_model, dict f"{embedding_type}: Dumped model is not a dictionary"
), f"{embedding_type}: Dumped model is not a dictionary" )
for key in model.__dict__: for key in model.__dict__:
if key.startswith("_"): if key.startswith("_"):
@@ -391,3 +446,33 @@ def test_retry(mock_sleep):
result = test_function() result = test_function()
assert mock_sleep.call_count == 9 assert mock_sleep.call_count == 9
assert result == "result" assert result == "result"
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OpenAI API key not set"
)
def test_openai_propagates_api_key(monkeypatch):
# Make sure that if we set it as a variable, the API key is propagated
api_key = os.environ["OPENAI_API_KEY"]
monkeypatch.delenv("OPENAI_API_KEY")
uri = "memory://"
registry = get_registry()
registry.set_var("open_api_key", api_key)
func = registry.get("openai").create(
name="text-embedding-ada-002",
max_retries=0,
api_key="$var:open_api_key",
)
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
db = lancedb.connect(uri)
table = db.create_table("words", schema=Words, mode="overwrite")
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
assert len(actual.text) > 0

View File

@@ -174,6 +174,10 @@ def test_search_fts(table, use_tantivy):
assert len(results) == 5 assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score assert len(results[0]) == 3 # id, text, _score
# Default limit of 10
results = table.search("puppy").select(["id", "text"]).to_list()
assert len(results) == 10
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fts_select_async(async_table): async def test_fts_select_async(async_table):

View File

@@ -129,6 +129,6 @@ def test_normalize_scores():
if invert: if invert:
expected = pc.subtract(1.0, expected) expected = pc.subtract(1.0, expected)
assert pc.equal( assert pc.equal(result, expected), (
result, expected f"Expected {expected} but got {result} for invert={invert}"
), f"Expected {expected} but got {result} for invert={invert}" )

View File

@@ -10,6 +10,7 @@ import pyarrow as pa
import pydantic import pydantic
import pytest import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import BaseModel
from pydantic import Field from pydantic import Field
@@ -252,3 +253,104 @@ def test_lance_model():
t = TestModel() t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3]) assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_optional_nested_model():
class WAMedia(BaseModel):
url: str
mimetype: str
filename: Optional[str]
error: Optional[str]
data: bytes
class WALocation(BaseModel):
description: Optional[str]
latitude: str
longitude: str
class ReplyToMessage(BaseModel):
id: str
participant: str
body: str
class Message(BaseModel):
id: str
timestamp: int
from_: str
fromMe: bool
to: str
body: str
hasMedia: Optional[bool]
media: WAMedia
mediaUrl: Optional[str]
ack: Optional[int]
ackName: Optional[str]
author: Optional[str]
location: Optional[WALocation]
vCards: Optional[List[str]]
replyTo: Optional[ReplyToMessage]
class AnyEvent(LanceModel):
id: str
session: str
metadata: Optional[str] = None
engine: str
event: str
class MessageEvent(AnyEvent):
payload: Message
schema = pydantic_to_schema(MessageEvent)
payload = schema.field("payload")
assert payload.type == pa.struct(
[
pa.field("id", pa.utf8(), False),
pa.field("timestamp", pa.int64(), False),
pa.field("from_", pa.utf8(), False),
pa.field("fromMe", pa.bool_(), False),
pa.field("to", pa.utf8(), False),
pa.field("body", pa.utf8(), False),
pa.field("hasMedia", pa.bool_(), True),
pa.field(
"media",
pa.struct(
[
pa.field("url", pa.utf8(), False),
pa.field("mimetype", pa.utf8(), False),
pa.field("filename", pa.utf8(), True),
pa.field("error", pa.utf8(), True),
pa.field("data", pa.binary(), False),
]
),
False,
),
pa.field("mediaUrl", pa.utf8(), True),
pa.field("ack", pa.int64(), True),
pa.field("ackName", pa.utf8(), True),
pa.field("author", pa.utf8(), True),
pa.field(
"location",
pa.struct(
[
pa.field("description", pa.utf8(), True),
pa.field("latitude", pa.utf8(), False),
pa.field("longitude", pa.utf8(), False),
]
),
True, # Optional
),
pa.field("vCards", pa.list_(pa.utf8()), True),
pa.field(
"replyTo",
pa.struct(
[
pa.field("id", pa.utf8(), False),
pa.field("participant", pa.utf8(), False),
pa.field("body", pa.utf8(), False),
]
),
True,
),
]
)

View File

@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union
import unittest.mock as mock import unittest.mock as mock
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
import lancedb import lancedb
from lancedb.index import IvfPq, FTS from lancedb.db import AsyncConnection
from lancedb.rerankers.cross_encoder import CrossEncoderReranker from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np import numpy as np
import pandas.testing as tm import pandas.testing as tm
import pyarrow as pa import pyarrow as pa
import pyarrow.compute as pc
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from lancedb.query import ( from lancedb.query import (
AsyncFTSQuery,
AsyncHybridQuery,
AsyncQueryBase, AsyncQueryBase,
AsyncVectorQuery,
LanceVectorQueryBuilder, LanceVectorQueryBuilder,
Query, Query,
) )
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable from lancedb.table import AsyncTable, LanceTable
from utils import exception_output
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@@ -232,6 +242,71 @@ async def test_distance_range_async(table_async: AsyncTable):
assert res["_distance"].to_pylist() == [min_dist, max_dist] assert res["_distance"].to_pylist() == [min_dist, max_dist]
@pytest.mark.asyncio
async def test_distance_range_with_new_rows_async():
conn = await lancedb.connect_async(
"memory://", read_consistency_interval=timedelta(seconds=0)
)
data = pa.table(
{
"vector": pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.rand(256, 2)
),
}
)
table = await conn.create_table("test", data)
table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
q = [0, 0]
rs = await table.query().nearest_to(q).to_arrow()
dists = rs["_distance"].to_pylist()
min_dist = dists[0]
max_dist = dists[-1]
# append more rows so that execution plan would be mixed with ANN & Flat KNN
new_data = pa.table(
{
"vector": pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(4, 2)),
}
)
await table.add(new_data)
res = (
await table.query()
.nearest_to(q)
.distance_range(upper_bound=min_dist)
.to_arrow()
)
assert len(res) == 0
res = (
await table.query()
.nearest_to(q)
.distance_range(lower_bound=max_dist)
.to_arrow()
)
for dist in res["_distance"].to_pylist():
assert dist >= max_dist
res = (
await table.query()
.nearest_to(q)
.distance_range(upper_bound=max_dist)
.to_arrow()
)
for dist in res["_distance"].to_pylist():
assert dist < max_dist
res = (
await table.query()
.nearest_to(q)
.distance_range(lower_bound=min_dist)
.to_arrow()
)
for dist in res["_distance"].to_pylist():
assert dist >= min_dist
@pytest.mark.parametrize( @pytest.mark.parametrize(
"multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True "multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
) )
@@ -651,3 +726,100 @@ async def test_query_with_f16(tmp_path: Path):
tbl = await db.create_table("test", df) tbl = await db.create_table("test", df)
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas() results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
assert len(results) == 2 assert len(results) == 2
@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
nrows = 1000
data = pa.table(
{
"text": [str(i) for i in range(nrows)],
}
)
@register("test2")
class TestEmbedding(TextEmbeddingFunction):
def ndims(self):
return 4
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
embeddings = []
for text in texts:
vec = np.array([float(text) / 1000] * self.ndims())
embeddings.append(vec)
return embeddings
registry = get_registry()
func = registry.get("test2").create()
class TestModel(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
tbl = await mem_db_async.create_table("test", data, schema=TestModel)
funcs = await tbl.embedding_functions()
assert len(funcs) == 1
# No FTS or vector index
# Search for vector -> vector query
q = [0.1] * 4
query = await tbl.search(q)
assert isinstance(query, AsyncVectorQuery)
# Search for string -> vector query
query = await tbl.search("0.1")
assert isinstance(query, AsyncVectorQuery)
await tbl.create_index("text", config=FTS())
query = await tbl.search("0.1")
assert isinstance(query, AsyncHybridQuery)
data_with_vecs = await tbl.to_arrow()
data_with_vecs = data_with_vecs.replace_schema_metadata(None)
tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
with pytest.raises(
Exception,
match=(
"Cannot perform full text search unless an INVERTED index has been created"
),
):
query = await (await tbl2.search("0.1")).to_arrow()
@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
nrows, ndims = 1000, 16
data = pa.table(
{
"text": [str(i) for i in range(nrows)],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(nrows * ndims).cast(pa.float32()), ndims
),
}
)
table = await mem_db_async.create_table("test", data)
await table.create_index("text", config=FTS())
# Validate that specifying fts, vector or hybrid gets the right query.
q = [0.1] * ndims
query = await table.search(q, query_type="vector")
assert isinstance(query, AsyncVectorQuery)
query = await table.search("0.1", query_type="fts")
assert isinstance(query, AsyncFTSQuery)
with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
await table.search("0.1", query_type="foo")
with pytest.raises(
ValueError, match="Column 'vector' has no registered embedding function"
) as e:
await table.search("0.1", query_type="vector")
assert "No embedding functions are registered for any columns" in exception_output(
e
)

View File

@@ -9,6 +9,7 @@ import json
import threading import threading
from unittest.mock import MagicMock from unittest.mock import MagicMock
import uuid import uuid
from packaging.version import Version
import lancedb import lancedb
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
@@ -32,15 +33,16 @@ def make_mock_http_handler(handler):
@contextlib.contextmanager @contextlib.contextmanager
def mock_lancedb_connection(handler): def mock_lancedb_connection(handler):
with http.server.HTTPServer( with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler) ("localhost", 0), make_mock_http_handler(handler)
) as server: ) as server:
port = server.server_address[1]
handle = threading.Thread(target=server.serve_forever) handle = threading.Thread(target=server.serve_forever)
handle.start() handle.start()
db = lancedb.connect( db = lancedb.connect(
"db://dev", "db://dev",
api_key="fake", api_key="fake",
host_override="http://localhost:8080", host_override=f"http://localhost:{port}",
client_config={ client_config={
"retry_config": {"retries": 2}, "retry_config": {"retries": 2},
"timeout_config": { "timeout_config": {
@@ -59,15 +61,16 @@ def mock_lancedb_connection(handler):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler, **client_config): async def mock_lancedb_connection_async(handler, **client_config):
with http.server.HTTPServer( with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler) ("localhost", 0), make_mock_http_handler(handler)
) as server: ) as server:
port = server.server_address[1]
handle = threading.Thread(target=server.serve_forever) handle = threading.Thread(target=server.serve_forever)
handle.start() handle.start()
db = await lancedb.connect_async( db = await lancedb.connect_async(
"db://dev", "db://dev",
api_key="fake", api_key="fake",
host_override="http://localhost:8080", host_override=f"http://localhost:{port}",
client_config={ client_config={
"retry_config": {"retries": 2}, "retry_config": {"retries": 2},
"timeout_config": { "timeout_config": {
@@ -275,11 +278,12 @@ def test_table_create_indices():
@contextlib.contextmanager @contextlib.contextmanager
def query_test_table(query_handler): def query_test_table(query_handler, *, server_version=Version("0.1.0")):
def handler(request): def handler(request):
if request.path == "/v1/table/test/describe/": if request.path == "/v1/table/test/describe/":
request.send_response(200) request.send_response(200)
request.send_header("Content-Type", "application/json") request.send_header("Content-Type", "application/json")
request.send_header("phalanx-version", str(server_version))
request.end_headers() request.end_headers()
request.wfile.write(b"{}") request.wfile.write(b"{}")
elif request.path == "/v1/table/test/query/": elif request.path == "/v1/table/test/query/":
@@ -336,6 +340,7 @@ def test_query_sync_empty_query():
"filter": "true", "filter": "true",
"vector": [], "vector": [],
"columns": ["id"], "columns": ["id"],
"prefilter": False,
"version": None, "version": None,
} }
@@ -385,11 +390,25 @@ def test_query_sync_maximal():
) )
def test_query_sync_multiple_vectors(): @pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
def handler(_body): def test_query_sync_batch_queries(server_version):
return pa.table({"id": [1]}) def handler(body):
# TODO: we will add the ability to get the server version,
# so that we can decide how to perform batch queries.
vectors = body["vector"]
if server_version >= Version(
"0.2.0"
): # we can handle batch queries in single request since 0.2.0
assert len(vectors) == 2
res = []
for i, vector in enumerate(vectors):
res.append({"id": 1, "query_index": i})
return pa.Table.from_pylist(res)
else:
assert len(vectors) == 3 # matching dim
return pa.table({"id": [1]})
with query_test_table(handler) as table: with query_test_table(handler, server_version=server_version) as table:
results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list() results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
assert len(results) == 2 assert len(results) == 2
results.sort(key=lambda x: x["query_index"]) results.sort(key=lambda x: x["query_index"])
@@ -404,6 +423,7 @@ def test_query_sync_fts():
"columns": [], "columns": [],
}, },
"k": 10, "k": 10,
"prefilter": True,
"vector": [], "vector": [],
"version": None, "version": None,
} }
@@ -421,6 +441,7 @@ def test_query_sync_fts():
}, },
"k": 42, "k": 42,
"vector": [], "vector": [],
"prefilter": True,
"with_row_id": True, "with_row_id": True,
"version": None, "version": None,
} }
@@ -447,6 +468,7 @@ def test_query_sync_hybrid():
}, },
"k": 42, "k": 42,
"vector": [], "vector": [],
"prefilter": True,
"with_row_id": True, "with_row_id": True,
"version": None, "version": None,
} }

View File

@@ -32,8 +32,8 @@ pytest.importorskip("lancedb.fts")
def get_test_table(tmp_path, use_tantivy): def get_test_table(tmp_path, use_tantivy):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
# Create a LanceDB table schema with a vector and a text column # Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")() emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
meta_emb = EmbeddingFunctionRegistry.get_instance().get("test")() meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
class MyTable(LanceModel): class MyTable(LanceModel):
text: str = emb.SourceField() text: str = emb.SourceField()
@@ -131,9 +131,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
"represents the relevance of the result to the query & should " "represents the relevance of the result to the query & should "
"be descending." "be descending."
) )
assert np.all( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
np.diff(result.column("_relevance_score").to_numpy()) <= 0 ascending_relevance_err
), ascending_relevance_err )
# Vector search setting # Vector search setting
result = ( result = (
@@ -143,9 +143,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
.to_arrow() .to_arrow()
) )
assert len(result) == 30 assert len(result) == 30
assert np.all( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
np.diff(result.column("_relevance_score").to_numpy()) <= 0 ascending_relevance_err
), ascending_relevance_err )
result_explicit = ( result_explicit = (
table.search(query_vector, vector_column_name="vector") table.search(query_vector, vector_column_name="vector")
.rerank(reranker=reranker, query_string=query) .rerank(reranker=reranker, query_string=query)
@@ -168,9 +168,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
.to_arrow() .to_arrow()
) )
assert len(result) > 0 assert len(result) > 0
assert np.all( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
np.diff(result.column("_relevance_score").to_numpy()) <= 0 ascending_relevance_err
), ascending_relevance_err )
# empty FTS results # empty FTS results
query = "abcxyz" * 100 query = "abcxyz" * 100
@@ -185,9 +185,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
# should return _relevance_score column # should return _relevance_score column
assert "_relevance_score" in result.column_names assert "_relevance_score" in result.column_names
assert np.all( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
np.diff(result.column("_relevance_score").to_numpy()) <= 0 ascending_relevance_err
), ascending_relevance_err )
# Multi-vector search setting # Multi-vector search setting
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True) rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
@@ -262,9 +262,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
"represents the relevance of the result to the query & should " "represents the relevance of the result to the query & should "
"be descending." "be descending."
) )
assert np.all( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
np.diff(result.column("_relevance_score").to_numpy()) <= 0 ascending_relevance_err
), ascending_relevance_err )
# Test with empty FTS results # Test with empty FTS results
query = "abcxyz" * 100 query = "abcxyz" * 100
@@ -278,9 +278,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
) )
# should return _relevance_score column # should return _relevance_score column
assert "_relevance_score" in result.column_names assert "_relevance_score" in result.column_names
assert np.all( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
np.diff(result.column("_relevance_score").to_numpy()) <= 0 ascending_relevance_err
), ascending_relevance_err )
@pytest.mark.parametrize("use_tantivy", [True, False]) @pytest.mark.parametrize("use_tantivy", [True, False])
@@ -405,7 +405,9 @@ def test_answerdotai_reranker(tmp_path, use_tantivy):
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" os.environ.get("OPENAI_API_KEY") is None
or os.environ.get("OPENAI_BASE_URL") is not None,
reason="OPENAI_API_KEY not set",
) )
@pytest.mark.parametrize("use_tantivy", [True, False]) @pytest.mark.parametrize("use_tantivy", [True, False])
def test_openai_reranker(tmp_path, use_tantivy): def test_openai_reranker(tmp_path, use_tantivy):

View File

@@ -252,3 +252,27 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
    db.drop_table("test_ddb_sync")
    assert db.table_names() == []
    db.drop_database()
@pytest.mark.s3_test
def test_s3_dynamodb_drop_all_tables(s3_bucket: str, commit_table: str, monkeypatch):
for key, value in CONFIG.items():
monkeypatch.setenv(key.upper(), value)
uri = f"s3+ddb://{s3_bucket}/test2?ddbTableName={commit_table}"
db = lancedb.connect(uri, read_consistency_interval=timedelta(0))
data = pa.table({"x": ["a", "b", "c"]})
db.create_table("foo", data)
db.create_table("bar", data)
assert db.table_names() == ["bar", "foo"]
# dropping all tables should clear multiple tables
db.drop_all_tables()
assert db.table_names() == []
# create a new table with the same name to ensure DDB is clean
db.create_table("foo", data)
assert db.table_names() == ["foo"]
db.drop_all_tables()

View File

@@ -887,7 +887,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
        text: str
        vector: Vector(10)
-    func = MockTextEmbeddingFunction()
+    func = MockTextEmbeddingFunction.create()
    texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
    df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
@@ -934,7 +934,7 @@ def test_create_f16_table(mem_db: DBConnection):
def test_add_with_embedding_function(mem_db: DBConnection):
-    emb = EmbeddingFunctionRegistry.get_instance().get("test")()
+    emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -1025,13 +1025,13 @@ def test_empty_query(mem_db: DBConnection):
    table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
    df = table.search().select(["id"]).to_pandas()
-    assert len(df) == 10
+    assert len(df) == 100
    # None is the same as default
    df = table.search().select(["id"]).limit(None).to_pandas()
-    assert len(df) == 10
+    assert len(df) == 100
    # invalid limist is the same as None, wihch is the same as default
    df = table.search().select(["id"]).limit(-1).to_pandas()
-    assert len(df) == 10
+    assert len(df) == 100
    # valid limit should work
    df = table.search().select(["id"]).limit(42).to_pandas()
    assert len(df) == 42
@@ -1128,7 +1128,7 @@ def test_count_rows(mem_db: DBConnection):
def setup_hybrid_search_table(db: DBConnection, embedding_func):
    # Create a LanceDB table schema with a vector and a text column
-    emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
+    emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func).create()
    class MyTable(LanceModel):
        text: str = emb.SourceField()
@@ -1481,3 +1481,12 @@ async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_pat
        cleanup_older_than=timedelta(seconds=0), delete_unverified=True
    )
    assert stats.prune.old_versions_removed == 2
def test_replace_field_metadata(tmp_path):
db = lancedb.connect(tmp_path)
table = db.create_table("my_table", data=[{"x": 0}])
table.replace_field_metadata("x", {"foo": "bar"})
schema = table.schema
field = schema[0].metadata
assert field == {b"foo": b"bar"}

View File

@@ -127,7 +127,7 @@ def test_append_vector_columns():
        conf = EmbeddingFunctionConfig(
            source_column="text",
            vector_column="vector",
-            function=MockTextEmbeddingFunction(),
+            function=MockTextEmbeddingFunction.create(),
        )
        metadata = registry.get_table_metadata([conf])
@@ -434,7 +434,7 @@ def test_sanitize_data(
        conf = EmbeddingFunctionConfig(
            source_column="text",
            vector_column="vector",
-            function=MockTextEmbeddingFunction(),
+            function=MockTextEmbeddingFunction.create(),
        )
        metadata = registry.get_table_metadata([conf])
    else:

View File

@@ -43,7 +43,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
            } => Python::with_gil(|py| {
                let message = err.to_string();
                let http_err_cls = py
-                    .import_bound(intern!(py, "lancedb.remote.errors"))?
+                    .import(intern!(py, "lancedb.remote.errors"))?
                    .getattr(intern!(py, "HttpError"))?;
                let err = http_err_cls.call1((
                    message,
@@ -63,7 +63,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
                    err.setattr(intern!(py, "__cause__"), cause_err)?;
                }
-                Err(PyErr::from_value_bound(err))
+                Err(PyErr::from_value(err))
            }),
            LanceError::Retry {
                request_id,
@@ -85,7 +85,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
                let message = err.to_string();
                let retry_error_cls = py
-                    .import_bound(intern!(py, "lancedb.remote.errors"))?
+                    .import(intern!(py, "lancedb.remote.errors"))?
                    .getattr("RetryError")?;
                let err = retry_error_cls.call1((
                    message,
@@ -100,7 +100,7 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
                ))?;
                err.setattr(intern!(py, "__cause__"), cause_err)?;
-                Err(PyErr::from_value_bound(err))
+                Err(PyErr::from_value(err))
            }),
            _ => self.runtime_error(),
        },
@@ -127,18 +127,16 @@ fn http_from_rust_error(
    status_code: Option<u16>,
) -> PyResult<PyErr> {
    let message = err.to_string();
-    let http_err_cls = py
-        .import_bound("lancedb.remote.errors")?
-        .getattr("HttpError")?;
+    let http_err_cls = py.import("lancedb.remote.errors")?.getattr("HttpError")?;
    let py_err = http_err_cls.call1((message, request_id, status_code))?;
    // Reset the traceback since it doesn't provide additional information.
-    let py_err = py_err.call_method1(intern!(py, "with_traceback"), (PyNone::get_bound(py),))?;
+    let py_err = py_err.call_method1(intern!(py, "with_traceback"), (PyNone::get(py),))?;
    if let Some(cause) = err.source() {
        let cause_err = http_from_rust_error(py, cause, request_id, status_code)?;
        py_err.setattr(intern!(py, "__cause__"), cause_err)?;
    }
-    Ok(PyErr::from_value_bound(py_err))
+    Ok(PyErr::from_value(py_err))
}

View File

@@ -7,29 +7,32 @@ use lancedb::index::{
    vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
    Index as LanceDbIndex,
};
+use pyo3::types::PyStringMethods;
+use pyo3::IntoPyObject;
use pyo3::{
    exceptions::{PyKeyError, PyValueError},
    intern, pyclass, pymethods,
    types::PyAnyMethods,
-    Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python,
+    Bound, FromPyObject, PyAny, PyResult, Python,
};
use crate::util::parse_distance_type;
-pub fn class_name<'a>(ob: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
-    let full_name: &str = ob
+pub fn class_name(ob: &'_ Bound<'_, PyAny>) -> PyResult<String> {
+    let full_name = ob
        .getattr(intern!(ob.py(), "__class__"))?
-        .getattr(intern!(ob.py(), "__name__"))?
-        .extract()?;
+        .getattr(intern!(ob.py(), "__name__"))?;
+    let full_name = full_name.downcast()?.to_string_lossy();
    match full_name.rsplit_once('.') {
-        Some((_, name)) => Ok(name),
-        None => Ok(full_name),
+        Some((_, name)) => Ok(name.to_string()),
+        None => Ok(full_name.to_string()),
    }
}
pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<LanceDbIndex> {
    if let Some(source) = source {
-        match class_name(source)? {
+        match class_name(source)?.as_str() {
            "BTree" => Ok(LanceDbIndex::BTree(BTreeIndexBuilder::default())),
            "Bitmap" => Ok(LanceDbIndex::Bitmap(Default::default())),
            "LabelList" => Ok(LanceDbIndex::LabelList(Default::default())),
@@ -196,11 +199,11 @@ impl IndexConfig {
    // For backwards-compatibility with the old sync SDK, we also support getting
    // attributes via __getitem__.
-    pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
+    pub fn __getitem__<'a>(&self, key: String, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        match key.as_str() {
-            "index_type" => Ok(self.index_type.clone().into_py(py)),
-            "columns" => Ok(self.columns.clone().into_py(py)),
-            "name" | "index_name" => Ok(self.name.clone().into_py(py)),
+            "index_type" => Ok(self.index_type.clone().into_pyobject(py)?.into_any()),
+            "columns" => Ok(self.columns.clone().into_pyobject(py)?.into_any()),
+            "name" | "index_name" => Ok(self.name.clone().into_pyobject(py)?.into_any()),
            _ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
        }
    }

View File

@@ -10,12 +10,13 @@ use lancedb::table::{
    Table as LanceDbTable,
};
use pyo3::{
-    exceptions::{PyRuntimeError, PyValueError},
+    exceptions::{PyKeyError, PyRuntimeError, PyValueError},
    pyclass, pymethods,
    types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
-    Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject,
+    Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
};
use pyo3_async_runtimes::tokio::future_into_py;
+use std::collections::HashMap;
use crate::{
    error::PythonErrorExt,
@@ -221,7 +222,7 @@ impl Table {
            let stats = inner.index_stats(&index_name).await.infer_error()?;
            if let Some(stats) = stats {
                Python::with_gil(|py| {
-                    let dict = PyDict::new_bound(py);
+                    let dict = PyDict::new(py);
                    dict.set_item("num_indexed_rows", stats.num_indexed_rows)?;
                    dict.set_item("num_unindexed_rows", stats.num_unindexed_rows)?;
                    dict.set_item("index_type", stats.index_type.to_string())?;
@@ -234,7 +235,7 @@ impl Table {
dict.set_item("num_indices", num_indices)?; dict.set_item("num_indices", num_indices)?;
} }
Ok(Some(dict.to_object(py))) Ok(Some(dict.unbind()))
}) })
} else { } else {
Ok(None) Ok(None)
@@ -265,7 +266,7 @@ impl Table {
                versions
                    .iter()
                    .map(|v| {
-                        let dict = PyDict::new_bound(py);
+                        let dict = PyDict::new(py);
                        dict.set_item("version", v.version).unwrap();
                        dict.set_item(
                            "timestamp",
@@ -274,14 +275,13 @@ impl Table {
                        .unwrap();
                        let tup: Vec<(&String, &String)> = v.metadata.iter().collect();
-                        dict.set_item("metadata", tup.into_py_dict_bound(py))
-                            .unwrap();
-                        dict.to_object(py)
+                        dict.set_item("metadata", tup.into_py_dict(py)?).unwrap();
+                        Ok(dict.unbind())
                    })
-                    .collect::<Vec<_>>()
+                    .collect::<PyResult<Vec<_>>>()
            });
-            Ok(versions_as_dict)
+            versions_as_dict
        })
    }
@@ -486,6 +486,37 @@ impl Table {
            Ok(())
        })
    }
pub fn replace_field_metadata<'a>(
self_: PyRef<'a, Self>,
field_name: String,
metadata: &Bound<'_, PyDict>,
) -> PyResult<Bound<'a, PyAny>> {
let mut new_metadata = HashMap::<String, String>::new();
for (column_name, value) in metadata.into_iter() {
let key: String = column_name.extract()?;
let value: String = value.extract()?;
new_metadata.insert(key, value);
}
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let native_tbl = inner
.as_native()
.ok_or_else(|| PyValueError::new_err("This cannot be run on a remote table"))?;
let schema = native_tbl.manifest().await.infer_error()?.schema;
let field = schema
.field(&field_name)
.ok_or_else(|| PyKeyError::new_err(format!("Field {} not found", field_name)))?;
native_tbl
.replace_field_metadata(vec![(field.id as u32, new_metadata)])
.await
.infer_error()?;
Ok(())
})
}
}
#[derive(FromPyObject)]

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
-version = "0.15.1-beta.3"
+version = "0.18.0-beta.0"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
-version = "0.15.1-beta.3"
+version = "0.18.0-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -70,6 +70,7 @@ candle-core = { version = "0.6.0", optional = true }
candle-transformers = { version = "0.6.0", optional = true }
candle-nn = { version = "0.6.0", optional = true }
tokenizers = { version = "0.19.1", optional = true }
+semver = { workspace = true }
# For a workaround, see workspace Cargo.toml
crunchy.workspace = true
@@ -87,6 +88,7 @@ aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
datafusion.workspace = true
http-body = "1" # Matching reqwest
+rstest = "0.23.0"
[features]

View File

@@ -4,12 +4,14 @@
use std::{pin::Pin, sync::Arc};
pub use arrow_schema;
-use futures::{Stream, StreamExt};
+use datafusion_common::DataFusionError;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+use futures::{Stream, StreamExt, TryStreamExt};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
-use crate::error::Result;
+use crate::{error::Result, Error};
/// An iterator of batches that also has a schema
pub trait RecordBatchReader: Iterator<Item = Result<arrow_array::RecordBatch>> {
@@ -65,6 +67,20 @@ impl<I: lance::io::RecordBatchStream + 'static> From<I> for SendableRecordBatchS
    }
}
pub trait SendableRecordBatchStreamExt {
fn into_df_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream;
}
impl SendableRecordBatchStreamExt for SendableRecordBatchStream {
fn into_df_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream {
let schema = self.schema();
Box::pin(RecordBatchStreamAdapter::new(
schema,
self.map_err(|ldb_err| DataFusionError::External(ldb_err.into())),
))
}
}
/// A simple RecordBatchStream formed from the two parts (stream + schema)
#[pin_project::pin_project]
pub struct SimpleRecordBatchStream<S: Stream<Item = Result<arrow_array::RecordBatch>>> {
@@ -101,7 +117,7 @@ impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> RecordBatchStream
/// used in methods like [`crate::connection::Connection::create_table`]
/// or [`crate::table::Table::add`]
pub trait IntoArrow {
-    /// Convert the data into an Arrow array
+    /// Convert the data into an iterator of Arrow batches
    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
}
@@ -113,11 +129,38 @@ impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
    }
}
/// A trait for converting incoming data to Arrow asynchronously
///
/// Serves the same purpose as [`IntoArrow`], but for asynchronous data.
///
/// Note: Arrow has no async equivalent to RecordBatchReader and so
/// the crate's own [`SendableRecordBatchStream`] is used for this purpose.
pub trait IntoArrowStream {
/// Convert the data into a stream of Arrow batches
fn into_arrow(self) -> Result<SendableRecordBatchStream>;
}
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
    pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
        Self { schema, stream }
    }
}
impl IntoArrowStream for SendableRecordBatchStream {
fn into_arrow(self) -> Result<SendableRecordBatchStream> {
Ok(self)
}
}
impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
fn into_arrow(self) -> Result<SendableRecordBatchStream> {
let schema = self.schema();
let stream = self.map_err(|df_err| Error::Runtime {
message: df_err.to_string(),
});
Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema)))
}
}
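For orientation (this sketch is not part of the diff): the two impls above, together with `into_df_stream`, let callers move between the crate's stream type and DataFusion's in either direction. The helper function names below are illustrative only.

```rust
// Hand a LanceDB stream to a DataFusion consumer; `into_df_stream` wraps
// any LanceDB error in `DataFusionError::External`.
fn to_datafusion(
    stream: SendableRecordBatchStream,
) -> datafusion_physical_plan::SendableRecordBatchStream {
    stream.into_df_stream()
}

// Accept either stream flavor at an API boundary; both `IntoArrowStream`
// impls above normalize to the crate's own `SendableRecordBatchStream`.
fn normalize<S: IntoArrowStream>(source: S) -> Result<SendableRecordBatchStream> {
    source.into_arrow()
}
```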
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Catalog implementation for managing databases
pub mod listing;
use std::collections::HashMap;
use std::sync::Arc;
use crate::database::Database;
use crate::error::Result;
use async_trait::async_trait;
/// Request parameters for listing databases
#[derive(Clone, Debug, Default)]
pub struct DatabaseNamesRequest {
/// Start listing after this name (exclusive)
pub start_after: Option<String>,
/// Maximum number of names to return
pub limit: Option<u32>,
}
/// Request to open an existing database
#[derive(Clone, Debug)]
pub struct OpenDatabaseRequest {
/// The name of the database to open
pub name: String,
/// A map of database-specific options
///
/// Consult the catalog / database implementation to determine which options are available
pub database_options: HashMap<String, String>,
}
/// Database creation mode
///
/// The default behavior is Create
pub enum CreateDatabaseMode {
/// Create new database, error if exists
Create,
/// Open existing database if present
ExistOk,
/// Overwrite existing database
Overwrite,
}
impl Default for CreateDatabaseMode {
fn default() -> Self {
Self::Create
}
}
/// Request to create a new database
pub struct CreateDatabaseRequest {
/// The name of the database to create
pub name: String,
/// The creation mode
pub mode: CreateDatabaseMode,
/// A map of catalog-specific options, consult your catalog implementation to determine what's available
pub options: HashMap<String, String>,
}
#[async_trait]
pub trait Catalog: Send + Sync + std::fmt::Debug + 'static {
/// List database names with pagination
async fn database_names(&self, request: DatabaseNamesRequest) -> Result<Vec<String>>;
/// Create a new database
async fn create_database(&self, request: CreateDatabaseRequest) -> Result<Arc<dyn Database>>;
/// Open existing database
async fn open_database(&self, request: OpenDatabaseRequest) -> Result<Arc<dyn Database>>;
/// Rename database
async fn rename_database(&self, old_name: &str, new_name: &str) -> Result<()>;
/// Delete database
async fn drop_database(&self, name: &str) -> Result<()>;
/// Delete all databases
async fn drop_all_databases(&self) -> Result<()>;
}
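For orientation (not part of the diff): a condensed sketch of how this trait is driven, in the spirit of the listing-catalog tests further down. It assumes `ConnectRequest`, the catalog types, and `Result` are publicly exported under `lancedb::*`; the database name is illustrative.

```rust
use std::collections::HashMap;
use std::sync::Arc;

use lancedb::catalog::listing::ListingCatalog;
use lancedb::catalog::{Catalog, CreateDatabaseMode, CreateDatabaseRequest, DatabaseNamesRequest};
use lancedb::connection::ConnectRequest;
use lancedb::database::Database;

// Assumed to be called with a ConnectRequest pointing at a local directory,
// built the same way the listing-catalog tests construct theirs.
async fn bootstrap(request: &ConnectRequest) -> lancedb::Result<Arc<dyn Database>> {
    let catalog = ListingCatalog::connect(request).await?;
    // Every subfolder under the catalog root is treated as a database.
    let db = catalog
        .create_database(CreateDatabaseRequest {
            name: "analytics".into(),          // illustrative name
            mode: CreateDatabaseMode::ExistOk, // reuse it if it already exists
            options: HashMap::new(),
        })
        .await?;
    let names = catalog
        .database_names(DatabaseNamesRequest::default())
        .await?;
    assert!(names.contains(&"analytics".to_string()));
    Ok(db)
}
```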

View File

@@ -0,0 +1,569 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Catalog implementation based on a local file system.
use std::collections::HashMap;
use std::fs::create_dir_all;
use std::path::Path;
use std::sync::Arc;
use super::{
Catalog, CreateDatabaseMode, CreateDatabaseRequest, DatabaseNamesRequest, OpenDatabaseRequest,
};
use crate::connection::ConnectRequest;
use crate::database::listing::ListingDatabase;
use crate::database::Database;
use crate::error::{CreateDirSnafu, Error, Result};
use async_trait::async_trait;
use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry};
use lance_io::local::to_local_path;
use object_store::path::Path as ObjectStorePath;
use snafu::ResultExt;
/// A catalog implementation that works by listing subfolders in a directory
///
/// The listing catalog will be created with a base folder specified by the URI. Every subfolder
/// in this base folder will be considered a database. These will be opened as a
/// [`crate::database::listing::ListingDatabase`]
#[derive(Debug)]
pub struct ListingCatalog {
object_store: ObjectStore,
uri: String,
base_path: ObjectStorePath,
storage_options: HashMap<String, String>,
}
impl ListingCatalog {
/// Try to create a local directory to store the lancedb dataset
pub fn try_create_dir(path: &str) -> core::result::Result<(), std::io::Error> {
let path = Path::new(path);
if !path.try_exists()? {
create_dir_all(path)?;
}
Ok(())
}
pub fn uri(&self) -> &str {
&self.uri
}
async fn open_path(path: &str) -> Result<Self> {
let (object_store, base_path) = ObjectStore::from_path(path).unwrap();
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
}
Ok(Self {
uri: path.to_string(),
base_path,
object_store,
storage_options: HashMap::new(),
})
}
pub async fn connect(request: &ConnectRequest) -> Result<Self> {
let uri = &request.uri;
let parse_res = url::Url::parse(uri);
match parse_res {
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
Ok(url) => {
let plain_uri = url.to_string();
let registry = Arc::new(ObjectStoreRegistry::default());
let storage_options = request.storage_options.clone();
let os_params = ObjectStoreParams {
storage_options: Some(storage_options.clone()),
..Default::default()
};
let (object_store, base_path) =
ObjectStore::from_uri_and_params(registry, &plain_uri, &os_params).await?;
if object_store.is_local() {
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
}
Ok(Self {
uri: String::from(url.clone()),
base_path,
object_store,
storage_options,
})
}
Err(_) => Self::open_path(uri).await,
}
}
fn database_path(&self, name: &str) -> ObjectStorePath {
self.base_path.child(name.replace('\\', "/"))
}
}
#[async_trait]
impl Catalog for ListingCatalog {
async fn database_names(&self, request: DatabaseNamesRequest) -> Result<Vec<String>> {
let mut f = self
.object_store
.read_dir(self.base_path.clone())
.await?
.iter()
.map(Path::new)
.filter_map(|p| p.file_name().and_then(|s| s.to_str().map(String::from)))
.collect::<Vec<String>>();
f.sort();
if let Some(start_after) = request.start_after {
let index = f
.iter()
.position(|name| name.as_str() > start_after.as_str())
.unwrap_or(f.len());
f.drain(0..index);
}
if let Some(limit) = request.limit {
f.truncate(limit as usize);
}
Ok(f)
}
async fn create_database(&self, request: CreateDatabaseRequest) -> Result<Arc<dyn Database>> {
let db_path = self.database_path(&request.name);
let db_path_str = to_local_path(&db_path);
let exists = Path::new(&db_path_str).exists();
match request.mode {
CreateDatabaseMode::Create if exists => {
return Err(Error::DatabaseAlreadyExists { name: request.name })
}
CreateDatabaseMode::Create => {
create_dir_all(db_path.to_string()).unwrap();
}
CreateDatabaseMode::ExistOk => {
if !exists {
create_dir_all(db_path.to_string()).unwrap();
}
}
CreateDatabaseMode::Overwrite => {
if exists {
self.drop_database(&request.name).await?;
}
create_dir_all(db_path.to_string()).unwrap();
}
}
let db_uri = format!("/{}/{}", self.base_path, request.name);
let connect_request = ConnectRequest {
uri: db_uri,
api_key: None,
region: None,
host_override: None,
#[cfg(feature = "remote")]
client_config: Default::default(),
read_consistency_interval: None,
storage_options: self.storage_options.clone(),
};
Ok(Arc::new(
ListingDatabase::connect_with_options(&connect_request).await?,
))
}
async fn open_database(&self, request: OpenDatabaseRequest) -> Result<Arc<dyn Database>> {
let db_path = self.database_path(&request.name);
let db_path_str = to_local_path(&db_path);
let exists = Path::new(&db_path_str).exists();
if !exists {
return Err(Error::DatabaseNotFound { name: request.name });
}
let connect_request = ConnectRequest {
uri: db_path.to_string(),
api_key: None,
region: None,
host_override: None,
#[cfg(feature = "remote")]
client_config: Default::default(),
read_consistency_interval: None,
storage_options: self.storage_options.clone(),
};
Ok(Arc::new(
ListingDatabase::connect_with_options(&connect_request).await?,
))
}
async fn rename_database(&self, _old_name: &str, _new_name: &str) -> Result<()> {
Err(Error::NotSupported {
message: "rename_database is not supported in LanceDB OSS yet".to_string(),
})
}
async fn drop_database(&self, name: &str) -> Result<()> {
let db_path = self.database_path(name);
self.object_store
.remove_dir_all(db_path.clone())
.await
.map_err(|err| match err {
lance::Error::NotFound { .. } => Error::DatabaseNotFound {
name: name.to_owned(),
},
_ => Error::from(err),
})?;
Ok(())
}
async fn drop_all_databases(&self) -> Result<()> {
self.object_store
.remove_dir_all(self.base_path.clone())
.await?;
Ok(())
}
}
#[cfg(all(test, not(windows)))]
mod tests {
use super::*;
/// file:/// URIs with drive letters do not work correctly on Windows
#[cfg(windows)]
fn path_to_uri(path: PathBuf) -> String {
path.to_str().unwrap().to_string()
}
#[cfg(not(windows))]
fn path_to_uri(path: PathBuf) -> String {
Url::from_file_path(path).unwrap().to_string()
}
async fn setup_catalog() -> (TempDir, ListingCatalog) {
let tempdir = tempfile::tempdir().unwrap();
let catalog_path = tempdir.path().join("catalog");
std::fs::create_dir_all(&catalog_path).unwrap();
let uri = path_to_uri(catalog_path);
let request = ConnectRequest {
uri: uri.clone(),
api_key: None,
region: None,
host_override: None,
#[cfg(feature = "remote")]
client_config: Default::default(),
storage_options: HashMap::new(),
read_consistency_interval: None,
};
let catalog = ListingCatalog::connect(&request).await.unwrap();
(tempdir, catalog)
}
use crate::database::{CreateTableData, CreateTableRequest, TableNamesRequest};
use crate::table::TableDefinition;
use arrow_schema::Field;
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::{tempdir, TempDir};
use url::Url;
#[tokio::test]
async fn test_database_names() {
let (_tempdir, catalog) = setup_catalog().await;
let names = catalog
.database_names(DatabaseNamesRequest::default())
.await
.unwrap();
assert!(names.is_empty());
}
#[tokio::test]
async fn test_create_database() {
let (_tempdir, catalog) = setup_catalog().await;
catalog
.create_database(CreateDatabaseRequest {
name: "db1".into(),
mode: CreateDatabaseMode::Create,
options: HashMap::new(),
})
.await
.unwrap();
let names = catalog
.database_names(DatabaseNamesRequest::default())
.await
.unwrap();
assert_eq!(names, vec!["db1"]);
}
#[tokio::test]
async fn test_create_database_exist_ok() {
let (_tempdir, catalog) = setup_catalog().await;
let db1 = catalog
.create_database(CreateDatabaseRequest {
name: "db_exist_ok".into(),
mode: CreateDatabaseMode::ExistOk,
options: HashMap::new(),
})
.await
.unwrap();
let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::<Field>::default()));
db1.create_table(CreateTableRequest {
name: "test_table".parse().unwrap(),
data: CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
mode: Default::default(),
write_options: Default::default(),
})
.await
.unwrap();
let db2 = catalog
.create_database(CreateDatabaseRequest {
name: "db_exist_ok".into(),
mode: CreateDatabaseMode::ExistOk,
options: HashMap::new(),
})
.await
.unwrap();
let tables = db2.table_names(TableNamesRequest::default()).await.unwrap();
assert_eq!(tables, vec!["test_table".to_string()]);
}
#[tokio::test]
async fn test_create_database_overwrite() {
let (_tempdir, catalog) = setup_catalog().await;
let db = catalog
.create_database(CreateDatabaseRequest {
name: "db_overwrite".into(),
mode: CreateDatabaseMode::Create,
options: HashMap::new(),
})
.await
.unwrap();
let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::<Field>::default()));
db.create_table(CreateTableRequest {
name: "old_table".parse().unwrap(),
data: CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
mode: Default::default(),
write_options: Default::default(),
})
.await
.unwrap();
let tables = db.table_names(TableNamesRequest::default()).await.unwrap();
assert!(!tables.is_empty());
let new_db = catalog
.create_database(CreateDatabaseRequest {
name: "db_overwrite".into(),
mode: CreateDatabaseMode::Overwrite,
options: HashMap::new(),
})
.await
.unwrap();
let tables = new_db
.table_names(TableNamesRequest::default())
.await
.unwrap();
assert!(tables.is_empty());
}
#[tokio::test]
async fn test_create_database_overwrite_non_existing() {
let (_tempdir, catalog) = setup_catalog().await;
catalog
.create_database(CreateDatabaseRequest {
name: "new_db".into(),
mode: CreateDatabaseMode::Overwrite,
options: HashMap::new(),
})
.await
.unwrap();
let names = catalog
.database_names(DatabaseNamesRequest::default())
.await
.unwrap();
assert!(names.contains(&"new_db".to_string()));
}
#[tokio::test]
async fn test_open_database() {
let (_tempdir, catalog) = setup_catalog().await;
// Test open non-existent
let result = catalog
.open_database(OpenDatabaseRequest {
name: "missing".into(),
database_options: HashMap::new(),
})
.await;
assert!(matches!(
result.unwrap_err(),
Error::DatabaseNotFound { name } if name == "missing"
));
// Create and open
catalog
.create_database(CreateDatabaseRequest {
name: "valid_db".into(),
mode: CreateDatabaseMode::Create,
options: HashMap::new(),
})
.await
.unwrap();
let db = catalog
.open_database(OpenDatabaseRequest {
name: "valid_db".into(),
database_options: HashMap::new(),
})
.await
.unwrap();
assert_eq!(
db.table_names(TableNamesRequest::default()).await.unwrap(),
Vec::<String>::new()
);
}
#[tokio::test]
async fn test_drop_database() {
let (_tempdir, catalog) = setup_catalog().await;
// Create test database
catalog
.create_database(CreateDatabaseRequest {
name: "to_drop".into(),
mode: CreateDatabaseMode::Create,
options: HashMap::new(),
})
.await
.unwrap();
let names = catalog
.database_names(DatabaseNamesRequest::default())
.await
.unwrap();
assert!(!names.is_empty());
// Drop database
catalog.drop_database("to_drop").await.unwrap();
let names = catalog
.database_names(DatabaseNamesRequest::default())
.await
.unwrap();
assert!(names.is_empty());
}
#[tokio::test]
async fn test_drop_all_databases() {
let (_tempdir, catalog) = setup_catalog().await;
catalog
.create_database(CreateDatabaseRequest {
name: "db1".into(),
mode: CreateDatabaseMode::Create,
options: HashMap::new(),
})
.await
.unwrap();
catalog
.create_database(CreateDatabaseRequest {
name: "db2".into(),
mode: CreateDatabaseMode::Create,
options: HashMap::new(),
})
.await
.unwrap();
catalog.drop_all_databases().await.unwrap();
let names = catalog
.database_names(DatabaseNamesRequest::default())
.await
.unwrap();
assert!(names.is_empty());
}
#[tokio::test]
async fn test_rename_database_unsupported() {
let (_tempdir, catalog) = setup_catalog().await;
let result = catalog.rename_database("old", "new").await;
assert!(matches!(
result.unwrap_err(),
Error::NotSupported { message } if message.contains("rename_database")
));
}
#[tokio::test]
async fn test_connect_local_path() {
let tmp_dir = tempdir().unwrap();
let path = tmp_dir.path().to_str().unwrap();
let request = ConnectRequest {
uri: path.to_string(),
api_key: None,
region: None,
host_override: None,
#[cfg(feature = "remote")]
client_config: Default::default(),
storage_options: HashMap::new(),
read_consistency_interval: None,
};
let catalog = ListingCatalog::connect(&request).await.unwrap();
assert!(catalog.object_store.is_local());
assert_eq!(catalog.uri, path);
}
#[tokio::test]
async fn test_connect_file_scheme() {
let tmp_dir = tempdir().unwrap();
let path = tmp_dir.path();
let uri = path_to_uri(path.to_path_buf());
let request = ConnectRequest {
uri: uri.clone(),
api_key: None,
region: None,
host_override: None,
#[cfg(feature = "remote")]
client_config: Default::default(),
storage_options: HashMap::new(),
read_consistency_interval: None,
};
let catalog = ListingCatalog::connect(&request).await.unwrap();
assert!(catalog.object_store.is_local());
assert_eq!(catalog.uri, uri);
}
#[tokio::test]
async fn test_connect_invalid_uri_fallback() {
let invalid_uri = "invalid:///path";
let request = ConnectRequest {
uri: invalid_uri.to_string(),
api_key: None,
region: None,
host_override: None,
#[cfg(feature = "remote")]
client_config: Default::default(),
storage_options: HashMap::new(),
read_consistency_interval: None,
};
let result = ListingCatalog::connect(&request).await;
assert!(result.is_err());
}
}

View File

@@ -11,7 +11,7 @@ use arrow_schema::{Field, SchemaRef};
use lance::dataset::ReadParams;
use object_store::aws::AwsCredential;
-use crate::arrow::IntoArrow;
+use crate::arrow::{IntoArrow, IntoArrowStream, SendableRecordBatchStream};
use crate::database::listing::{
    ListingDatabase, OPT_NEW_TABLE_STORAGE_VERSION, OPT_NEW_TABLE_V2_MANIFEST_PATHS,
};
@@ -75,6 +75,14 @@ impl IntoArrow for NoData {
    }
}
// Stores the value given from the initial CreateTableBuilder::new call
// and defers errors until `execute` is called
enum CreateTableBuilderInitialData {
None,
Iterator(Result<Box<dyn RecordBatchReader + Send>>),
Stream(Result<SendableRecordBatchStream>),
}
/// A builder for configuring a [`Connection::create_table`] operation /// A builder for configuring a [`Connection::create_table`] operation
pub struct CreateTableBuilder<const HAS_DATA: bool> { pub struct CreateTableBuilder<const HAS_DATA: bool> {
parent: Arc<dyn Database>, parent: Arc<dyn Database>,
@@ -83,7 +91,7 @@ pub struct CreateTableBuilder<const HAS_DATA: bool> {
    request: CreateTableRequest,
    // This is a bit clumsy but we defer errors until `execute` is called
    // to maintain backwards compatibility
-    data: Option<Result<Box<dyn RecordBatchReader + Send>>>,
+    data: CreateTableBuilderInitialData,
}
// Builder methods that only apply when we have initial data
@@ -103,7 +111,26 @@ impl CreateTableBuilder<true> {
            ),
            embeddings: Vec::new(),
            embedding_registry,
-            data: Some(data.into_arrow()),
+            data: CreateTableBuilderInitialData::Iterator(data.into_arrow()),
        }
    }
fn new_streaming<T: IntoArrowStream>(
parent: Arc<dyn Database>,
name: String,
data: T,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
let dummy_schema = Arc::new(arrow_schema::Schema::new(Vec::<Field>::default()));
Self {
parent,
request: CreateTableRequest::new(
name,
CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
),
embeddings: Vec::new(),
embedding_registry,
data: CreateTableBuilderInitialData::Stream(data.into_arrow()),
        }
    }
@@ -125,17 +152,37 @@ impl CreateTableBuilder<true> {
    }
    fn into_request(self) -> Result<CreateTableRequest> {
-        let data = if self.embeddings.is_empty() {
-            self.data.unwrap()?
+        if self.embeddings.is_empty() {
+            match self.data {
+                CreateTableBuilderInitialData::Iterator(maybe_iter) => {
+                    let data = maybe_iter?;
+                    Ok(CreateTableRequest {
+                        data: CreateTableData::Data(data),
+                        ..self.request
+                    })
+                }
+                CreateTableBuilderInitialData::None => {
+                    unreachable!("No data provided for CreateTableBuilder<true>")
+                }
+                CreateTableBuilderInitialData::Stream(maybe_stream) => {
+                    let data = maybe_stream?;
+                    Ok(CreateTableRequest {
+                        data: CreateTableData::StreamingData(data),
+                        ..self.request
+                    })
+                }
+            }
        } else {
-            let data = self.data.unwrap()?;
-            Box::new(WithEmbeddings::new(data, self.embeddings))
-        };
-        let req = self.request;
-        Ok(CreateTableRequest {
-            data: CreateTableData::Data(data),
-            ..req
-        })
+            let CreateTableBuilderInitialData::Iterator(maybe_iter) = self.data else {
+                return Err(Error::NotSupported { message: "Creating a table with embeddings is currently not support when the input is streaming".to_string() });
+            };
+            let data = maybe_iter?;
+            let data = Box::new(WithEmbeddings::new(data, self.embeddings));
+            Ok(CreateTableRequest {
+                data: CreateTableData::Data(data),
+                ..self.request
+            })
+        }
    }
@@ -151,7 +198,7 @@ impl CreateTableBuilder<false> {
        Self {
            parent,
            request: CreateTableRequest::new(name, CreateTableData::Empty(table_definition)),
-            data: None,
+            data: CreateTableBuilderInitialData::None,
            embeddings: Vec::default(),
            embedding_registry,
        }
@@ -432,7 +479,7 @@ impl Connection {
        TableNamesBuilder::new(self.internal.clone())
    }
-    /// Create a new table from data
+    /// Create a new table from an iterator of data
    ///
    /// # Parameters
    ///
@@ -451,6 +498,25 @@ impl Connection {
        )
    }
/// Create a new table from a stream of data
///
/// # Parameters
///
/// * `name` - The name of the table
/// * `initial_data` - The initial data to write to the table
pub fn create_table_streaming<T: IntoArrowStream>(
&self,
name: impl Into<String>,
initial_data: T,
) -> CreateTableBuilder<true> {
CreateTableBuilder::<true>::new_streaming(
self.internal.clone(),
name.into(),
initial_data,
self.embedding_registry.clone(),
)
}
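A minimal sketch of the new streaming path (not part of the diff), mirroring the `test_create_table_streaming` test further down. The table name is a placeholder and the `lancedb::*` paths assume the crate's public re-exports.

```rust
use std::sync::Arc;

use futures::stream;
use lancedb::arrow::{SendableRecordBatchStream, SimpleRecordBatchStream};
use lancedb::connection::Connection;

// Wrap already-materialized batches in the crate's stream type and feed them
// to the streaming builder; any DataFusion SendableRecordBatchStream is also
// accepted via the IntoArrowStream impls in arrow.rs.
async fn create_from_stream(
    db: &Connection,
    schema: Arc<arrow_schema::Schema>,
    batches: Vec<arrow_array::RecordBatch>,
) -> lancedb::Result<lancedb::Table> {
    let source: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream::new(
        stream::iter(batches.into_iter().map(Ok::<_, lancedb::Error>)),
        schema,
    ));
    db.create_table_streaming("streamed_table", source) // table name is illustrative
        .execute()
        .await
}
```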
    /// Create an empty table with a given schema
    ///
    /// # Parameters
@@ -788,12 +854,16 @@ mod test_utils {
mod tests {
    use std::fs::create_dir_all;
+    use arrow::compute::concat_batches;
    use arrow_array::RecordBatchReader;
    use arrow_schema::{DataType, Field, Schema};
-    use futures::TryStreamExt;
+    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+    use futures::{stream, TryStreamExt};
+    use lance::error::{ArrowResult, DataFusionResult};
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
    use tempfile::tempdir;
+    use crate::arrow::SimpleRecordBatchStream;
    use crate::database::listing::{ListingDatabaseOptions, NewTableConfig};
    use crate::query::QueryBase;
    use crate::query::{ExecutableQuery, QueryExecutionOptions};
@@ -976,6 +1046,63 @@ mod tests {
        assert_eq!(batches.len(), 1);
    }
#[tokio::test]
async fn test_create_table_streaming() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let batches = make_data().collect::<ArrowResult<Vec<_>>>().unwrap();
let schema = batches.first().unwrap().schema();
let one_batch = concat_batches(&schema, batches.iter()).unwrap();
let ldb_stream = stream::iter(batches.clone().into_iter().map(Result::Ok));
let ldb_stream: SendableRecordBatchStream =
Box::pin(SimpleRecordBatchStream::new(ldb_stream, schema.clone()));
let tbl1 = db
.create_table_streaming("one", ldb_stream)
.execute()
.await
.unwrap();
let df_stream = stream::iter(batches.into_iter().map(DataFusionResult::Ok));
let df_stream: datafusion_physical_plan::SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), df_stream));
let tbl2 = db
.create_table_streaming("two", df_stream)
.execute()
.await
.unwrap();
let tbl1_data = tbl1
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let tbl1_data = concat_batches(&schema, tbl1_data.iter()).unwrap();
assert_eq!(tbl1_data, one_batch);
let tbl2_data = tbl2
.query()
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let tbl2_data = concat_batches(&schema, tbl2_data.iter()).unwrap();
assert_eq!(tbl2_data, one_batch);
}
    #[tokio::test]
    async fn drop_table() {
        let tmp_dir = tempdir().unwrap();

View File

@@ -18,8 +18,13 @@ use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::RecordBatchReader;
+use async_trait::async_trait;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+use futures::stream;
use lance::dataset::ReadParams;
+use lance_datafusion::utils::StreamingWriteSource;
+use crate::arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt};
use crate::error::Result;
use crate::table::{BaseTable, TableDefinition, WriteOptions};
@@ -81,12 +86,41 @@ impl Default for CreateTableMode {
/// The data to start a table or a schema to create an empty table
pub enum CreateTableData {
-    /// Creates a table using data, no schema required as it will be obtained from the data
+    /// Creates a table using an iterator of data, the schema will be obtained from the data
    Data(Box<dyn RecordBatchReader + Send>),
+    /// Creates a table using a stream of data, the schema will be obtained from the data
+    StreamingData(SendableRecordBatchStream),
    /// Creates an empty table, the definition / schema must be provided separately
    Empty(TableDefinition),
}
impl CreateTableData {
pub fn schema(&self) -> Arc<arrow_schema::Schema> {
match self {
Self::Data(reader) => reader.schema(),
Self::StreamingData(stream) => stream.schema(),
Self::Empty(definition) => definition.schema.clone(),
}
}
}
#[async_trait]
impl StreamingWriteSource for CreateTableData {
fn arrow_schema(&self) -> Arc<arrow_schema::Schema> {
self.schema()
}
fn into_stream(self) -> datafusion_physical_plan::SendableRecordBatchStream {
match self {
Self::Data(reader) => reader.into_stream(),
Self::StreamingData(stream) => stream.into_df_stream(),
Self::Empty(table_definition) => {
let schema = table_definition.schema.clone();
Box::pin(RecordBatchStreamAdapter::new(schema, stream::empty()))
}
}
}
}
/// A request to create a table
pub struct CreateTableRequest {
    /// The name of the new table

View File

@@ -7,9 +7,9 @@ use std::fs::create_dir_all;
use std::path::Path;
use std::{collections::HashMap, sync::Arc};
-use arrow_array::RecordBatchIterator;
use lance::dataset::{ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry, WrappingObjectStore};
+use lance_datafusion::utils::StreamingWriteSource;
use lance_encoding::version::LanceFileVersion;
use lance_table::io::commit::commit_handler_from_url;
use object_store::local::LocalFileSystem;
@@ -22,8 +22,8 @@ use crate::table::NativeTable;
use crate::utils::validate_table_name;
use super::{
-    BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
-    OpenTableRequest, TableNamesRequest,
+    BaseTable, CreateTableMode, CreateTableRequest, Database, DatabaseOptions, OpenTableRequest,
+    TableNamesRequest,
};
/// File extension to indicate a lance table
@@ -322,6 +322,37 @@ impl ListingDatabase {
        Ok(uri)
    }
async fn drop_tables(&self, names: Vec<String>) -> Result<()> {
let object_store_params = ObjectStoreParams {
storage_options: Some(self.storage_options.clone()),
..Default::default()
};
let mut uri = self.uri.clone();
if let Some(query_string) = &self.query_string {
uri.push_str(&format!("?{}", query_string));
}
let commit_handler = commit_handler_from_url(&uri, &Some(object_store_params)).await?;
for name in names {
let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
let full_path = self.base_path.child(dir_name.clone());
commit_handler.delete(&full_path).await?;
self.object_store
.remove_dir_all(full_path.clone())
.await
.map_err(|err| match err {
// this error is not lance::Error::DatasetNotFound, as the method
// `remove_dir_all` may be used to remove something not be a dataset
lance::Error::NotFound { .. } => Error::TableNotFound {
name: name.to_owned(),
},
_ => Error::from(err),
})?;
}
Ok(())
}
}
#[async_trait::async_trait]
@@ -401,19 +432,12 @@ impl Database for ListingDatabase {
            write_params.mode = WriteMode::Overwrite;
        }
-        let data = match request.data {
-            CreateTableData::Data(data) => data,
-            CreateTableData::Empty(table_definition) => {
-                let schema = table_definition.schema.clone();
-                Box::new(RecordBatchIterator::new(vec![], schema))
-            }
-        };
-        let data_schema = data.schema();
+        let data_schema = request.data.arrow_schema();
        match NativeTable::create(
            &table_uri,
            &request.name,
-            data,
+            request.data,
            self.store_wrapper.clone(),
            Some(write_params),
            self.read_consistency_interval,
@@ -500,40 +524,12 @@ impl Database for ListingDatabase {
    }
    async fn drop_table(&self, name: &str) -> Result<()> {
-        let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
-        let full_path = self.base_path.child(dir_name.clone());
-        self.object_store
-            .remove_dir_all(full_path.clone())
-            .await
-            .map_err(|err| match err {
-                // this error is not lance::Error::DatasetNotFound,
-                // as the method `remove_dir_all` may be used to remove something not be a dataset
-                lance::Error::NotFound { .. } => Error::TableNotFound {
-                    name: name.to_owned(),
-                },
-                _ => Error::from(err),
-            })?;
-        let object_store_params = ObjectStoreParams {
-            storage_options: Some(self.storage_options.clone()),
-            ..Default::default()
-        };
-        let mut uri = self.uri.clone();
-        if let Some(query_string) = &self.query_string {
-            uri.push_str(&format!("?{}", query_string));
-        }
-        let commit_handler = commit_handler_from_url(&uri, &Some(object_store_params))
-            .await
-            .unwrap();
-        commit_handler.delete(&full_path).await.unwrap();
-        Ok(())
+        self.drop_tables(vec![name.to_string()]).await
    }
    async fn drop_all_tables(&self) -> Result<()> {
-        self.object_store
-            .remove_dir_all(self.base_path.clone())
-            .await?;
-        Ok(())
+        let tables = self.table_names(TableNamesRequest::default()).await?;
+        self.drop_tables(tables).await
    }
    fn as_any(&self) -> &dyn std::any::Any {

View File

@@ -15,6 +15,10 @@ pub enum Error {
    InvalidInput { message: String },
    #[snafu(display("Table '{name}' was not found"))]
    TableNotFound { name: String },
#[snafu(display("Database '{name}' was not found"))]
DatabaseNotFound { name: String },
#[snafu(display("Database '{name}' already exists."))]
DatabaseAlreadyExists { name: String },
#[snafu(display("Index '{name}' was not found"))] #[snafu(display("Index '{name}' was not found"))]
IndexNotFound { name: String }, IndexNotFound { name: String },
#[snafu(display("Embedding function '{name}' was not found. : {reason}"))] #[snafu(display("Embedding function '{name}' was not found. : {reason}"))]

View File

@@ -23,7 +23,19 @@ impl VectorIndex {
        let fields = index
            .fields
            .iter()
-            .map(|i| manifest.schema.fields[*i as usize].name.clone())
+            .map(|field_id| {
manifest
.schema
.field_by_id(*field_id)
.unwrap_or_else(|| {
panic!(
"field {field_id} of index {} must exist in schema",
index.name
)
})
.name
.clone()
})
            .collect();
        Self {
            columns: fields,

View File

@@ -191,6 +191,7 @@
//! ```
pub mod arrow;
+pub mod catalog;
pub mod connection;
pub mod data;
pub mod database;

View File

@@ -7,6 +7,7 @@ use std::sync::Arc;
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
+use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use half::f16;
@@ -464,11 +465,14 @@ impl<T: HasQuery> QueryBase for T {
    }
    fn only_if(mut self, filter: impl AsRef<str>) -> Self {
-        self.mut_query().filter = Some(filter.as_ref().to_string());
+        self.mut_query().filter = Some(QueryFilter::Sql(filter.as_ref().to_string()));
        self
    }
    fn full_text_search(mut self, query: FullTextSearchQuery) -> Self {
+        if self.mut_query().limit.is_none() {
+            self.mut_query().limit = Some(DEFAULT_TOP_K);
+        }
        self.mut_query().full_text_search = Some(query);
        self
    }
@@ -577,6 +581,17 @@ pub trait ExecutableQuery {
    fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
}
/// A query filter that can be applied to a query
#[derive(Clone, Debug)]
pub enum QueryFilter {
/// The filter is an SQL string
Sql(String),
/// The filter is a Substrait ExtendedExpression message with a single expression
Substrait(Arc<[u8]>),
/// The filter is a Datafusion expression
Datafusion(Expr),
}
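For orientation (not part of the diff): a small sketch of how the variants can be built. The SQL text mirrors what `only_if` now produces; the DataFusion expression is an equivalent hand-built predicate, and the function name is illustrative.

```rust
use datafusion_expr::{col, lit};

// The string-based `only_if` path above produces the Sql variant; the same
// predicate can also be supplied programmatically as a DataFusion expression.
fn example_filters() -> (QueryFilter, QueryFilter) {
    let sql = QueryFilter::Sql("id > 10".to_string());
    let programmatic = QueryFilter::Datafusion(col("id").gt(lit(10)));
    (sql, programmatic)
}
```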
/// A basic query into a table without any kind of search
///
/// This will result in a (potentially filtered) scan if executed
@@ -589,7 +604,7 @@ pub struct QueryRequest {
    pub offset: Option<usize>,
    /// Apply filter to the returned rows.
-    pub filter: Option<String>,
+    pub filter: Option<QueryFilter>,
    /// Perform a full text search on the table.
    pub full_text_search: Option<FullTextSearchQuery>,
@@ -622,7 +637,7 @@ pub struct QueryRequest {
impl Default for QueryRequest {
    fn default() -> Self {
        Self {
-            limit: Some(DEFAULT_TOP_K),
+            limit: None,
            offset: None,
            filter: None,
            full_text_search: None,
@@ -707,6 +722,11 @@ impl Query {
        let mut vector_query = self.into_vector();
        let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
        vector_query.request.query_vector.push(query_vector);
if vector_query.request.base.limit.is_none() {
vector_query.request.base.limit = Some(DEFAULT_TOP_K);
}
        Ok(vector_query)
    }

View File

@@ -19,12 +19,41 @@ use crate::database::{
};
use crate::error::Result;
use crate::table::BaseTable;
+use crate::Error;
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
-use super::util::batches_to_ipc_bytes;
+use super::util::{batches_to_ipc_bytes, parse_server_version};
use super::ARROW_STREAM_CONTENT_TYPE;
// the versions of the server that we support
// for any new feature that we need to change the SDK behavior, we should bump the server version,
// and add a feature flag as method of `ServerVersion` here.
pub const DEFAULT_SERVER_VERSION: semver::Version = semver::Version::new(0, 1, 0);
#[derive(Debug, Clone)]
pub struct ServerVersion(pub semver::Version);
impl Default for ServerVersion {
fn default() -> Self {
Self(DEFAULT_SERVER_VERSION.clone())
}
}
impl ServerVersion {
pub fn parse(version: &str) -> Result<Self> {
let version = Self(
semver::Version::parse(version).map_err(|e| Error::InvalidInput {
message: e.to_string(),
})?,
);
Ok(version)
}
pub fn support_multivector(&self) -> bool {
self.0 >= semver::Version::new(0, 2, 0)
}
}
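For orientation (not part of the diff): a sketch of how the version gate is intended to be consulted; the function and its argument are illustrative, while `parse`, `default`, and `support_multivector` come from the impl above.

```rust
// Decide SDK behavior from the version reported by the server
// (the header parsing itself lives in `parse_server_version`).
fn negotiated_version(reported: Option<&str>) -> Result<ServerVersion> {
    let version = match reported {
        Some(raw) => ServerVersion::parse(raw)?,
        // Older servers that do not report a version fall back to 0.1.0.
        None => ServerVersion::default(),
    };
    if version.support_multivector() {
        // Safe to send multiple query vectors in a single request.
    }
    Ok(version)
}
```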
#[derive(Deserialize)]
struct ListTablesResponse {
    tables: Vec<String>,
@@ -33,7 +62,7 @@ struct ListTablesResponse {
#[derive(Debug)]
pub struct RemoteDatabase<S: HttpSend = Sender> {
    client: RestfulLanceDbClient<S>,
-    table_cache: Cache<String, ()>,
+    table_cache: Cache<String, Arc<RemoteTable<S>>>,
}
impl RemoteDatabase {
@@ -115,13 +144,19 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
        }
        let (request_id, rsp) = self.client.send(req, true).await?;
        let rsp = self.client.check_response(&request_id, rsp).await?;
+        let version = parse_server_version(&request_id, &rsp)?;
        let tables = rsp
            .json::<ListTablesResponse>()
            .await
            .err_to_http(request_id)?
            .tables;
        for table in &tables {
-            self.table_cache.insert(table.clone(), ()).await;
+            let remote_table = Arc::new(RemoteTable::new(
+                self.client.clone(),
+                table.clone(),
+                version.clone(),
+            ));
+            self.table_cache.insert(table.clone(), remote_table).await;
        }
        Ok(tables)
    }
@@ -129,6 +164,11 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> { async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
let data = match request.data { let data = match request.data {
CreateTableData::Data(data) => data, CreateTableData::Data(data) => data,
CreateTableData::StreamingData(_) => {
return Err(Error::NotSupported {
message: "Creating a remote table from a streaming source".to_string(),
})
}
CreateTableData::Empty(table_definition) => { CreateTableData::Empty(table_definition) => {
let schema = table_definition.schema.clone(); let schema = table_definition.schema.clone();
Box::new(RecordBatchIterator::new(vec![], schema)) Box::new(RecordBatchIterator::new(vec![], schema))
@@ -187,34 +227,42 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
return Err(crate::Error::InvalidInput { message: body }); return Err(crate::Error::InvalidInput { message: body });
} }
} }
let rsp = self.client.check_response(&request_id, rsp).await?;
self.client.check_response(&request_id, rsp).await?; let version = parse_server_version(&request_id, &rsp)?;
let table = Arc::new(RemoteTable::new(
self.table_cache.insert(request.name.clone(), ()).await;
Ok(Arc::new(RemoteTable::new(
self.client.clone(), self.client.clone(),
request.name, request.name.clone(),
))) version,
));
self.table_cache
.insert(request.name.clone(), table.clone())
.await;
Ok(table)
} }
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> { async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
// We describe the table to confirm it exists before moving on. // We describe the table to confirm it exists before moving on.
if self.table_cache.get(&request.name).await.is_none() { if let Some(table) = self.table_cache.get(&request.name).await {
Ok(table.clone())
} else {
let req = self let req = self
.client .client
.post(&format!("/v1/table/{}/describe/", request.name)); .post(&format!("/v1/table/{}/describe/", request.name));
let (request_id, resp) = self.client.send(req, true).await?; let (request_id, rsp) = self.client.send(req, true).await?;
if resp.status() == StatusCode::NOT_FOUND { if rsp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: request.name }); return Err(crate::Error::TableNotFound { name: request.name });
} }
self.client.check_response(&request_id, resp).await?; let rsp = self.client.check_response(&request_id, rsp).await?;
let version = parse_server_version(&request_id, &rsp)?;
let table = Arc::new(RemoteTable::new(
self.client.clone(),
request.name.clone(),
version,
));
self.table_cache.insert(request.name, table.clone()).await;
Ok(table)
} }
Ok(Arc::new(RemoteTable::new(
self.client.clone(),
request.name,
)))
} }
async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> { async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> {
@@ -224,8 +272,10 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
let req = req.json(&serde_json::json!({ "new_table_name": new_name })); let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
let (request_id, resp) = self.client.send(req, false).await?; let (request_id, resp) = self.client.send(req, false).await?;
self.client.check_response(&request_id, resp).await?; self.client.check_response(&request_id, resp).await?;
self.table_cache.remove(current_name).await; let table = self.table_cache.remove(current_name).await;
self.table_cache.insert(new_name.into(), ()).await; if let Some(table) = table {
self.table_cache.insert(new_name.into(), table).await;
}
Ok(()) Ok(())
} }

View File

@@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
use crate::index::Index; use crate::index::Index;
use crate::index::IndexStatistics; use crate::index::IndexStatistics;
use crate::query::{QueryRequest, Select, VectorQueryRequest}; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter}; use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error, Table}; use crate::{DistanceType, Error, Table};
@@ -41,6 +41,7 @@ use crate::{
use super::client::RequestResultExt; use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::db::ServerVersion;
use super::ARROW_STREAM_CONTENT_TYPE; use super::ARROW_STREAM_CONTENT_TYPE;
#[derive(Debug)] #[derive(Debug)]
@@ -48,15 +49,21 @@ pub struct RemoteTable<S: HttpSend = Sender> {
#[allow(dead_code)] #[allow(dead_code)]
client: RestfulLanceDbClient<S>, client: RestfulLanceDbClient<S>,
name: String, name: String,
server_version: ServerVersion,
version: RwLock<Option<u64>>, version: RwLock<Option<u64>>,
} }
impl<S: HttpSend> RemoteTable<S> { impl<S: HttpSend> RemoteTable<S> {
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self { pub fn new(
client: RestfulLanceDbClient<S>,
name: String,
server_version: ServerVersion,
) -> Self {
Self { Self {
client, client,
name, name,
server_version,
version: RwLock::new(None), version: RwLock::new(None),
} }
} }
@@ -149,16 +156,23 @@ impl<S: HttpSend> RemoteTable<S> {
} }
fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> { fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
body["prefilter"] = params.prefilter.into();
if let Some(offset) = params.offset { if let Some(offset) = params.offset {
body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset)); body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
} }
if let Some(limit) = params.limit { // Server requires k.
body["k"] = serde_json::Value::Number(serde_json::Number::from(limit)); let limit = params.limit.unwrap_or(usize::MAX);
} body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
if let Some(filter) = &params.filter { if let Some(filter) = &params.filter {
body["filter"] = serde_json::Value::String(filter.clone()); if let QueryFilter::Sql(filter) = filter {
body["filter"] = serde_json::Value::String(filter.clone());
} else {
return Err(Error::NotSupported {
message: "querying a remote table with a non-sql filter".to_string(),
});
}
} }
match &params.select { match &params.select {
@@ -205,13 +219,13 @@ impl<S: HttpSend> RemoteTable<S> {
} }
fn apply_vector_query_params( fn apply_vector_query_params(
&self,
mut body: serde_json::Value, mut body: serde_json::Value,
query: &VectorQueryRequest, query: &VectorQueryRequest,
) -> Result<Vec<serde_json::Value>> { ) -> Result<Vec<serde_json::Value>> {
Self::apply_query_params(&mut body, &query.base)?; Self::apply_query_params(&mut body, &query.base)?;
// Apply general parameters, before we dispatch based on number of query vectors. // Apply general parameters, before we dispatch based on number of query vectors.
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into(); body["nprobes"] = query.nprobes.into();
body["lower_bound"] = query.lower_bound.into(); body["lower_bound"] = query.lower_bound.into();
@@ -250,26 +264,40 @@ impl<S: HttpSend> RemoteTable<S> {
} }
} }
match query.query_vector.len() { let bodies = match query.query_vector.len() {
0 => { 0 => {
// Server takes empty vector, not null or undefined. // Server takes empty vector, not null or undefined.
body["vector"] = serde_json::Value::Array(Vec::new()); body["vector"] = serde_json::Value::Array(Vec::new());
Ok(vec![body]) vec![body]
} }
1 => { 1 => {
body["vector"] = vector_to_json(&query.query_vector[0])?; body["vector"] = vector_to_json(&query.query_vector[0])?;
Ok(vec![body]) vec![body]
} }
_ => { _ => {
let mut bodies = Vec::with_capacity(query.query_vector.len()); if self.server_version.support_multivector() {
for vector in &query.query_vector { let vectors = query
let mut body = body.clone(); .query_vector
body["vector"] = vector_to_json(vector)?; .iter()
bodies.push(body); .map(vector_to_json)
.collect::<Result<Vec<_>>>()?;
body["vector"] = serde_json::Value::Array(vectors);
vec![body]
} else {
// Server does not support multiple vectors in a single query,
// so we need to send one request per vector.
let mut bodies = Vec::with_capacity(query.query_vector.len());
for vector in &query.query_vector {
let mut body = body.clone();
body["vector"] = vector_to_json(vector)?;
bodies.push(body);
}
bodies
} }
Ok(bodies)
} }
} };
Ok(bodies)
} }
async fn check_mutable(&self) -> Result<()> { async fn check_mutable(&self) -> Result<()> {
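To make the two request shapes concrete, here is a standalone `serde_json` sketch: newer servers accept all query vectors in one body, while older servers get one body per vector (the field names mirror the bodies built above; the `k` value is arbitrary):

```rust
use serde_json::json;

fn main() {
    let vectors = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];

    // Newer servers (multivector support): one body whose "vector" is an array of arrays.
    let combined = json!({ "k": 10, "vector": &vectors });
    assert_eq!(combined["vector"].as_array().unwrap().len(), 2);

    // Older servers: one body per query vector, each with a flat "vector" array.
    let fan_out: Vec<_> = vectors
        .iter()
        .map(|v| json!({ "k": 10, "vector": v }))
        .collect();
    assert_eq!(fan_out.len(), 2);
}
```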
@@ -300,33 +328,28 @@ impl<S: HttpSend> RemoteTable<S> {
let version = self.current_version().await; let version = self.current_version().await;
let mut body = serde_json::json!({ "version": version }); let mut body = serde_json::json!({ "version": version });
match query { let requests = match query {
AnyQuery::Query(query) => { AnyQuery::Query(query) => {
Self::apply_query_params(&mut body, query)?; Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed. // Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new()); body["vector"] = serde_json::Value::Array(Vec::new());
vec![request.json(&body)]
let request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let stream = self.read_arrow_stream(&request_id, response).await?;
Ok(vec![stream])
} }
AnyQuery::VectorQuery(query) => { AnyQuery::VectorQuery(query) => {
let bodies = Self::apply_vector_query_params(body, query)?; let bodies = self.apply_vector_query_params(body, query)?;
let mut futures = Vec::with_capacity(bodies.len()); bodies
for body in bodies { .into_iter()
let request = request.try_clone().unwrap().json(&body); .map(|body| request.try_clone().unwrap().json(&body))
let future = async move { .collect()
let (request_id, response) = self.client.send(request, true).await?;
self.read_arrow_stream(&request_id, response).await
};
futures.push(future);
}
futures::future::try_join_all(futures).await
} }
} };
let futures = requests.into_iter().map(|req| async move {
let (request_id, response) = self.client.send(req, true).await?;
self.read_arrow_stream(&request_id, response).await
});
let streams = futures::future::try_join_all(futures).await?;
Ok(streams)
} }
} }
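The dispatch above builds one HTTP request per body and awaits them concurrently with `futures::future::try_join_all`, which fails fast on the first error. A minimal standalone sketch of that pattern, with stand-in async work instead of the real client call:

```rust
use futures::future::try_join_all;

#[tokio::main]
async fn main() -> Result<(), String> {
    let bodies = vec!["body-0", "body-1", "body-2"];

    // One future per request body; the real code sends the request and reads an Arrow stream.
    let futures = bodies.into_iter().map(|body| async move {
        Ok::<_, String>(format!("stream for {body}"))
    });

    let streams = try_join_all(futures).await?;
    assert_eq!(streams.len(), 3);
    Ok(())
}
```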
@@ -349,7 +372,7 @@ mod test_utils {
use crate::remote::client::test_utils::MockSender; use crate::remote::client::test_utils::MockSender;
impl RemoteTable<MockSender> { impl RemoteTable<MockSender> {
pub fn new_mock<F, T>(name: String, handler: F) -> Self pub fn new_mock<F, T>(name: String, handler: F, version: Option<semver::Version>) -> Self
where where
F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static, F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
T: Into<reqwest::Body>, T: Into<reqwest::Body>,
@@ -358,6 +381,7 @@ mod test_utils {
Self { Self {
client, client,
name, name,
server_version: version.map(ServerVersion).unwrap_or_default(),
version: RwLock::new(None), version: RwLock::new(None),
} }
} }
@@ -499,7 +523,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
options: QueryExecutionOptions, options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> { ) -> Result<Arc<dyn ExecutionPlan>> {
let streams = self.execute_query(query, options).await?; let streams = self.execute_query(query, options).await?;
if streams.len() == 1 { if streams.len() == 1 {
let stream = streams.into_iter().next().unwrap(); let stream = streams.into_iter().next().unwrap();
Ok(Arc::new(OneShotExec::new(stream))) Ok(Arc::new(OneShotExec::new(stream)))
@@ -917,8 +940,10 @@ mod tests {
use futures::{future::BoxFuture, StreamExt, TryFutureExt}; use futures::{future::BoxFuture, StreamExt, TryFutureExt};
use lance_index::scalar::FullTextSearchQuery; use lance_index::scalar::FullTextSearchQuery;
use reqwest::Body; use reqwest::Body;
use rstest::rstest;
use crate::index::vector::IvfFlatIndexBuilder; use crate::index::vector::IvfFlatIndexBuilder;
use crate::remote::db::DEFAULT_SERVER_VERSION;
use crate::remote::JSON_CONTENT_TYPE; use crate::remote::JSON_CONTENT_TYPE;
use crate::{ use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType}, index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
@@ -1326,6 +1351,52 @@ mod tests {
table.delete("id in (1, 2, 3)").await.unwrap(); table.delete("id in (1, 2, 3)").await.unwrap();
} }
#[tokio::test]
async fn test_query_plain() {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"k": usize::MAX,
"prefilter": true,
"vector": [], // Empty vector means no vector query.
"version": null,
});
assert_eq!(body, expected_body);
let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let data = table
.query()
.execute()
.await
.unwrap()
.collect::<Vec<_>>()
.await;
assert_eq!(data.len(), 1);
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test] #[tokio::test]
async fn test_query_vector_default_values() { async fn test_query_vector_default_values() {
let expected_data = RecordBatch::try_new( let expected_data = RecordBatch::try_new(
@@ -1379,6 +1450,55 @@ mod tests {
assert_eq!(data[0].as_ref().unwrap(), &expected_data); assert_eq!(data[0].as_ref().unwrap(), &expected_data);
} }
#[tokio::test]
async fn test_query_fts_default_values() {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"full_text_query": {
"columns": [],
"query": "test",
},
"prefilter": true,
"version": null,
"k": 10,
"vector": [],
});
assert_eq!(body, expected_body);
let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let data = table
.query()
.full_text_search(FullTextSearchQuery::new("test".to_owned()))
.execute()
.await;
let data = data.unwrap().collect::<Vec<_>>().await;
assert_eq!(data.len(), 1);
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test] #[tokio::test]
async fn test_query_vector_all_params() { async fn test_query_vector_all_params() {
let table = Table::new_with_handler("my_table", |request| { let table = Table::new_with_handler("my_table", |request| {
@@ -1461,6 +1581,7 @@ mod tests {
"k": 10, "k": 10,
"vector": [], "vector": [],
"with_row_id": true, "with_row_id": true,
"prefilter": true,
"version": null "version": null
}); });
assert_eq!(body, expected_body); assert_eq!(body, expected_body);
@@ -1491,20 +1612,47 @@ mod tests {
.unwrap(); .unwrap();
} }
#[rstest]
#[case(DEFAULT_SERVER_VERSION.clone())]
#[case(semver::Version::new(0, 2, 0))]
#[tokio::test] #[tokio::test]
async fn test_query_multiple_vectors() { async fn test_batch_queries(#[case] version: semver::Version) {
let table = Table::new_with_handler("my_table", |request| { let table = Table::new_with_handler_version("my_table", version.clone(), move |request| {
assert_eq!(request.method(), "POST"); assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/query/"); assert_eq!(request.url().path(), "/v1/table/my_table/query/");
assert_eq!( assert_eq!(
request.headers().get("Content-Type").unwrap(), request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE JSON_CONTENT_TYPE
); );
let data = RecordBatch::try_new( let body: serde_json::Value =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], let query_vectors = body["vector"].as_array().unwrap();
) let version = ServerVersion(version.clone());
.unwrap(); let data = if version.support_multivector() {
assert_eq!(query_vectors.len(), 2);
assert_eq!(query_vectors[0].as_array().unwrap().len(), 3);
assert_eq!(query_vectors[1].as_array().unwrap().len(), 3);
RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("query_index", DataType::Int32, false),
])),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])),
Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])),
],
)
.unwrap()
} else {
// It's a single flat vector, so the length here is the vector dimension (dim).
assert_eq!(query_vectors.len(), 3);
RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap()
};
let response_body = write_ipc_file(&data); let response_body = write_ipc_file(&data);
http::Response::builder() http::Response::builder()
.status(200) .status(200)
@@ -1519,8 +1667,6 @@ mod tests {
.unwrap() .unwrap()
.add_query_vector(vec![0.4, 0.5, 0.6]) .add_query_vector(vec![0.4, 0.5, 0.6])
.unwrap(); .unwrap();
let plan = query.explain_plan(true).await.unwrap();
assert!(plan.contains("UnionExec"), "Plan: {}", plan);
let results = query let results = query
.execute() .execute()

View File

@@ -4,9 +4,12 @@
use std::io::Cursor; use std::io::Cursor;
use arrow_array::RecordBatchReader; use arrow_array::RecordBatchReader;
use reqwest::Response;
use crate::Result; use crate::Result;
use super::db::ServerVersion;
pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> { pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> {
const WRITE_BUF_SIZE: usize = 4096; const WRITE_BUF_SIZE: usize = 4096;
let buf = Vec::with_capacity(WRITE_BUF_SIZE); let buf = Vec::with_capacity(WRITE_BUF_SIZE);
@@ -22,3 +25,24 @@ pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>>
} }
Ok(buf.into_inner()) Ok(buf.into_inner())
} }
pub fn parse_server_version(req_id: &str, rsp: &Response) -> Result<ServerVersion> {
let version = rsp
.headers()
.get("phalanx-version")
.map(|v| {
let v = v.to_str().map_err(|e| crate::Error::Http {
source: e.into(),
request_id: req_id.to_string(),
status_code: Some(rsp.status()),
})?;
ServerVersion::parse(v).map_err(|e| crate::Error::Http {
source: e.into(),
request_id: req_id.to_string(),
status_code: Some(rsp.status()),
})
})
.transpose()?
.unwrap_or_default();
Ok(version)
}
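The server reports its version in the `phalanx-version` response header, and the client falls back to `DEFAULT_SERVER_VERSION` when the header is absent. A standalone sketch of just the parse-or-default step (the 0.1.0 default mirrors the constant above; `version_from_header` is illustrative only):

```rust
// Parse an optional header value with semver, defaulting to 0.1.0 when it is missing.
fn version_from_header(header: Option<&str>) -> Result<semver::Version, semver::Error> {
    header
        .map(semver::Version::parse)
        .transpose()
        .map(|parsed| parsed.unwrap_or(semver::Version::new(0, 1, 0)))
}

fn main() {
    // Missing header: fall back to the default server version.
    assert_eq!(version_from_header(None).unwrap(), semver::Version::new(0, 1, 0));
    // Present header: parsed with semver.
    assert_eq!(
        version_from_header(Some("0.2.0")).unwrap(),
        semver::Version::new(0, 2, 0)
    );
}
```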

View File

@@ -28,13 +28,13 @@ pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams; pub use lance::dataset::ReadParams;
pub use lance::dataset::Version; pub use lance::dataset::Version;
use lance::dataset::{ use lance::dataset::{
Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
WriteParams,
}; };
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::index::vector::utils::infer_vector_dim; use lance::index::vector::utils::infer_vector_dim;
use lance::io::WrappingObjectStore; use lance::io::WrappingObjectStore;
use lance_datafusion::exec::execute_plan; use lance_datafusion::exec::execute_plan;
use lance_datafusion::utils::StreamingWriteSource;
use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams; use lance_index::vector::pq::PQBuildParams;
@@ -62,7 +62,7 @@ use crate::index::{
}; };
use crate::index::{IndexConfig, IndexStatisticsImpl}; use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{ use crate::query::{
IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery, IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQuery,
VectorQueryRequest, DEFAULT_TOP_K, VectorQueryRequest, DEFAULT_TOP_K,
}; };
use crate::utils::{ use crate::utils::{
@@ -509,6 +509,27 @@ mod test_utils {
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock( let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(), name.into(),
handler, handler,
None,
));
Self {
inner,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
pub fn new_with_handler_version<T>(
name: impl Into<String>,
version: semver::Version,
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
) -> Self
where
T: Into<reqwest::Body>,
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler,
Some(version),
)); ));
Self { Self {
inner, inner,
@@ -1243,7 +1264,7 @@ impl NativeTable {
pub async fn create( pub async fn create(
uri: &str, uri: &str,
name: &str, name: &str,
batches: impl RecordBatchReader + Send + 'static, batches: impl StreamingWriteSource,
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>, write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: Option<WriteParams>, params: Option<WriteParams>,
read_consistency_interval: Option<std::time::Duration>, read_consistency_interval: Option<std::time::Duration>,
@@ -1258,7 +1279,9 @@ impl NativeTable {
None => params, None => params,
}; };
let dataset = Dataset::write(batches, uri, Some(params)) let insert_builder = InsertBuilder::new(uri).with_params(&params);
let dataset = insert_builder
.execute_stream(batches)
.await .await
.map_err(|e| match e { .map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists { lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -1266,6 +1289,7 @@ impl NativeTable {
}, },
source => Error::Lance { source }, source => Error::Lance { source },
})?; })?;
Ok(Self { Ok(Self {
name: name.to_string(), name: name.to_string(),
uri: uri.to_string(), uri: uri.to_string(),
@@ -1380,10 +1404,11 @@ impl NativeTable {
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> { pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
let dataset = self.dataset.get().await?; let dataset = self.dataset.get().await?;
let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?; let mf = dataset.manifest();
let indices = dataset.load_indices().await?;
Ok(indices Ok(indices
.iter() .iter()
.map(|i| VectorIndex::new_from_format(&(mf.0), i)) .map(|i| VectorIndex::new_from_format(mf, i))
.collect()) .collect())
} }
@@ -1995,8 +2020,8 @@ impl BaseTable for NativeTable {
}; };
let ds_ref = self.dataset.get().await?; let ds_ref = self.dataset.get().await?;
let mut column = query.column.clone();
let schema = ds_ref.schema(); let schema = ds_ref.schema();
let mut column = query.column.clone();
let mut query_vector = query.query_vector.first().cloned(); let mut query_vector = query.query_vector.first().cloned();
if query.query_vector.len() > 1 { if query.query_vector.len() > 1 {
@@ -2124,7 +2149,17 @@ impl BaseTable for NativeTable {
} }
if let Some(filter) = &query.base.filter { if let Some(filter) = &query.base.filter {
scanner.filter(filter)?; match filter {
QueryFilter::Sql(sql) => {
scanner.filter(sql)?;
}
QueryFilter::Substrait(substrait) => {
scanner.filter_substrait(substrait)?;
}
QueryFilter::Datafusion(expr) => {
scanner.filter_expr(expr.clone());
}
}
} }
if let Some(fts) = &query.base.full_text_search { if let Some(fts) = &query.base.full_text_search {
@@ -2359,8 +2394,9 @@ mod tests {
use arrow_data::ArrayDataBuilder; use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema, TimeUnit}; use arrow_schema::{DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt; use futures::TryStreamExt;
use lance::dataset::{Dataset, WriteMode}; use lance::dataset::WriteMode;
use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance::Dataset;
use rand::Rng; use rand::Rng;
use tempfile::tempdir; use tempfile::tempdir;
@@ -2410,6 +2446,7 @@ mod tests {
let uri = tmp_dir.path().to_str().unwrap(); let uri = tmp_dir.path().to_str().unwrap();
let batches = make_test_batches(); let batches = make_test_batches();
let batches = Box::new(batches) as Box<dyn RecordBatchReader + Send>;
let table = NativeTable::create(uri, "test", batches, None, None, None) let table = NativeTable::create(uri, "test", batches, None, None, None)
.await .await
.unwrap(); .unwrap();

View File

@@ -4,6 +4,7 @@
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers. //! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use arrow_array::RecordBatch;
use arrow_schema::Schema as ArrowSchema; use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait; use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider}; use datafusion_catalog::{Session, TableProvider};
@@ -17,7 +18,7 @@ use futures::{TryFutureExt, TryStreamExt};
use super::{AnyQuery, BaseTable}; use super::{AnyQuery, BaseTable};
use crate::{ use crate::{
query::{QueryExecutionOptions, QueryRequest, Select}, query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
Result, Result,
}; };
@@ -104,7 +105,9 @@ impl ExecutionPlan for MetadataEraserExec {
) -> DataFusionResult<SendableRecordBatchStream> { ) -> DataFusionResult<SendableRecordBatchStream> {
let stream = self.input.execute(partition, context)?; let stream = self.input.execute(partition, context)?;
let schema = self.schema.clone(); let schema = self.schema.clone();
let stream = stream.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap()); let stream = stream.map_ok(move |batch| {
RecordBatch::try_new(schema.clone(), batch.columns().to_vec()).unwrap()
});
Ok( Ok(
Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream)) Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
as SendableRecordBatchStream, as SendableRecordBatchStream,
@@ -148,7 +151,7 @@ impl TableProvider for BaseTableAdapter {
async fn scan( async fn scan(
&self, &self,
_state: &dyn Session, state: &dyn Session,
projection: Option<&Vec<usize>>, projection: Option<&Vec<usize>>,
filters: &[Expr], filters: &[Expr],
limit: Option<usize>, limit: Option<usize>,
@@ -161,16 +164,28 @@ impl TableProvider for BaseTableAdapter {
.collect(); .collect();
query.select = Select::Columns(field_names); query.select = Select::Columns(field_names);
} }
assert!(filters.is_empty()); if !filters.is_empty() {
let first = filters.first().unwrap().clone();
let filter = filters[1..]
.iter()
.fold(first, |acc, expr| acc.and(expr.clone()));
query.filter = Some(QueryFilter::Datafusion(filter));
}
if let Some(limit) = limit { if let Some(limit) = limit {
query.limit = Some(limit); query.limit = Some(limit);
} else { } else {
// Need to override the default of 10 // Need to override the default of 10
query.limit = None; query.limit = None;
} }
let options = QueryExecutionOptions {
max_batch_length: state.config().batch_size() as u32,
..Default::default()
};
let plan = self let plan = self
.table .table
.create_plan(&AnyQuery::Query(query), QueryExecutionOptions::default()) .create_plan(&AnyQuery::Query(query), options)
.map_err(|err| DataFusionError::External(err.into())) .map_err(|err| DataFusionError::External(err.into()))
.await?; .await?;
Ok(Arc::new(MetadataEraserExec::new(plan))) Ok(Arc::new(MetadataEraserExec::new(plan)))
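The pushed-down DataFusion filters are folded into a single conjunction before being attached to the query. A standalone sketch of that fold using only `datafusion-expr` (`conjoin` is an illustrative helper, not part of the crate):

```rust
use datafusion_expr::{col, lit, Expr};

// Combine zero or more pushed-down filters into one AND-ed expression.
fn conjoin(filters: &[Expr]) -> Option<Expr> {
    let mut iter = filters.iter().cloned();
    let first = iter.next()?;
    Some(iter.fold(first, |acc, expr| acc.and(expr)))
}

fn main() {
    let filters = vec![col("i").gt_eq(lit(5)), col("indexed").eq(lit(5))];
    let combined = conjoin(&filters).expect("non-empty filter list");
    // Prints the AND-ed predicate.
    println!("{combined}");
}
```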
@@ -180,11 +195,7 @@ impl TableProvider for BaseTableAdapter {
&self, &self,
filters: &[&Expr], filters: &[&Expr],
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> { ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
// TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
Ok(vec![
TableProviderFilterPushDown::Unsupported;
filters.len()
])
} }
fn statistics(&self) -> Option<Statistics> { fn statistics(&self) -> Option<Statistics> {
@@ -197,67 +208,291 @@ impl TableProvider for BaseTableAdapter {
pub mod tests { pub mod tests {
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow::array::AsArray;
use arrow_array::{
BinaryArray, Float64Array, Int32Array, Int64Array, RecordBatch, RecordBatchIterator,
RecordBatchReader, StringArray, UInt32Array,
};
use arrow_schema::{DataType, Field, Schema}; use arrow_schema::{DataType, Field, Schema};
use datafusion::{datasource::provider_as_source, prelude::SessionContext}; use datafusion::{
datasource::provider_as_source,
prelude::{SessionConfig, SessionContext},
};
use datafusion_catalog::TableProvider; use datafusion_catalog::TableProvider;
use datafusion_expr::LogicalPlanBuilder; use datafusion_execution::SendableRecordBatchStream;
use futures::TryStreamExt; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
use futures::{StreamExt, TryStreamExt};
use tempfile::tempdir; use tempfile::tempdir;
use crate::{connect, table::datafusion::BaseTableAdapter}; use crate::{
connect,
index::{scalar::BTreeIndexBuilder, Index},
table::datafusion::BaseTableAdapter,
};
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]); let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
let schema = Arc::new( let schema = Arc::new(
Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(metadata), Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("indexed", DataType::UInt32, false),
])
.with_metadata(metadata),
); );
RecordBatchIterator::new( RecordBatchIterator::new(
vec![RecordBatch::try_new( vec![RecordBatch::try_new(
schema.clone(), schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))], vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
],
)], )],
schema, schema,
) )
} }
#[tokio::test] fn make_tbl_two_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
async fn test_metadata_erased() { let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
let tmp_dir = tempdir().unwrap(); let schema = Arc::new(
let dataset_path = tmp_dir.path().join("test.lance"); Schema::new(vec![
let uri = dataset_path.to_str().unwrap(); Field::new("ints", DataType::Int64, true),
Field::new("strings", DataType::Utf8, true),
let db = connect(uri).execute().await.unwrap(); Field::new("floats", DataType::Float64, true),
Field::new("jsons", DataType::Utf8, true),
let tbl = db Field::new("bins", DataType::Binary, true),
.create_table("foo", make_test_batches()) Field::new("nodates", DataType::Utf8, true),
.execute() ])
.await .with_metadata(metadata),
.unwrap();
let provider = Arc::new(
BaseTableAdapter::try_new(tbl.base_table().clone())
.await
.unwrap(),
); );
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from_iter_values(0..1000)),
Arc::new(StringArray::from_iter_values(
(0..1000).map(|i| i.to_string()),
)),
Arc::new(Float64Array::from_iter_values((0..1000).map(|i| i as f64))),
Arc::new(StringArray::from_iter_values(
(0..1000).map(|i| format!("{{\"i\":{}}}", i)),
)),
Arc::new(BinaryArray::from_iter_values(
(0..1000).map(|i| (i as u32).to_be_bytes().to_vec()),
)),
Arc::new(StringArray::from_iter_values(
(0..1000).map(|i| i.to_string()),
)),
],
)],
schema,
)
}
assert!(provider.schema().metadata().is_empty()); struct TestFixture {
_tmp_dir: tempfile::TempDir,
// An adapter for a table with make_test_batches batches
adapter: Arc<BaseTableAdapter>,
// An adapter for a table with make_tbl_two_test_batches batches
adapter2: Arc<BaseTableAdapter>,
}
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None) impl TestFixture {
async fn new() -> Self {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tbl = db
.create_table("foo", make_test_batches())
.execute()
.await
.unwrap();
tbl.create_index(&["indexed"], Index::BTree(BTreeIndexBuilder::default()))
.execute()
.await
.unwrap();
let tbl2 = db
.create_table("tbl2", make_tbl_two_test_batches())
.execute()
.await
.unwrap();
let adapter = Arc::new(
BaseTableAdapter::try_new(tbl.base_table().clone())
.await
.unwrap(),
);
let adapter2 = Arc::new(
BaseTableAdapter::try_new(tbl2.base_table().clone())
.await
.unwrap(),
);
Self {
_tmp_dir: tmp_dir,
adapter,
adapter2,
}
}
async fn plan_to_stream(plan: LogicalPlan) -> SendableRecordBatchStream {
Self::plan_to_stream_with_config(plan, SessionConfig::default()).await
}
async fn plan_to_stream_with_config(
plan: LogicalPlan,
config: SessionConfig,
) -> SendableRecordBatchStream {
SessionContext::new_with_config(config)
.execute_logical_plan(plan)
.await
.unwrap()
.execute_stream()
.await
.unwrap()
}
async fn plan_to_explain(plan: LogicalPlan) -> String {
let mut explain_stream = SessionContext::new()
.execute_logical_plan(plan)
.await
.unwrap()
.explain(true, false)
.unwrap()
.execute_stream()
.await
.unwrap();
let batch = explain_stream.try_next().await.unwrap().unwrap();
assert!(explain_stream.try_next().await.unwrap().is_none());
let plan_descs = batch.columns()[0].as_string::<i32>();
let plans = batch.columns()[1].as_string::<i32>();
for (desc, plan) in plan_descs.iter().zip(plans.iter()) {
if desc.unwrap() == "physical_plan" {
return plan.unwrap().to_string();
}
}
panic!("No physical plan found in explain output");
}
async fn check_plan(plan: LogicalPlan, expected: &str) {
let physical_plan = Self::plan_to_explain(plan).await;
let mut lines_checked = 0;
for (actual_line, expected_line) in physical_plan.lines().zip(expected.lines()) {
lines_checked += 1;
let actual_trimmed = actual_line.trim();
let expected_trimmed = if let Some(ellipsis_pos) = expected_line.find("...") {
expected_line[0..ellipsis_pos].trim()
} else {
expected_line.trim()
};
assert_eq!(&actual_trimmed[..expected_trimmed.len()], expected_trimmed);
}
assert_eq!(lines_checked, expected.lines().count());
}
}
#[tokio::test]
async fn test_batch_size() {
let fixture = TestFixture::new().await;
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter2), None)
.unwrap() .unwrap()
.build() .build()
.unwrap(); .unwrap();
let mut stream = SessionContext::new() let config = SessionConfig::default().with_batch_size(100);
.execute_logical_plan(plan)
.await let stream = TestFixture::plan_to_stream_with_config(plan.clone(), config).await;
let batch_count = stream.count().await;
assert_eq!(batch_count, 10);
let config = SessionConfig::default().with_batch_size(250);
let stream = TestFixture::plan_to_stream_with_config(plan, config).await;
let batch_count = stream.count().await;
assert_eq!(batch_count, 4);
}
#[tokio::test]
async fn test_metadata_erased() {
let fixture = TestFixture::new().await;
assert!(fixture.adapter.schema().metadata().is_empty());
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None)
.unwrap() .unwrap()
.execute_stream() .build()
.await
.unwrap(); .unwrap();
let mut stream = TestFixture::plan_to_stream(plan).await;
while let Some(batch) = stream.try_next().await.unwrap() { while let Some(batch) = stream.try_next().await.unwrap() {
assert!(batch.schema().metadata().is_empty()); assert!(batch.schema().metadata().is_empty());
} }
} }
#[tokio::test]
async fn test_metadata_erased_with_filter() {
// Regression test: the metadata eraser previously failed to erase metadata when the scan included a filter.
let fixture = TestFixture::new().await;
assert!(fixture.adapter.schema().metadata().is_empty());
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter2), None)
.unwrap()
.filter(col("ints").lt(lit(10)))
.unwrap()
.build()
.unwrap();
let mut stream = TestFixture::plan_to_stream(plan).await;
while let Some(batch) = stream.try_next().await.unwrap() {
assert!(batch.schema().metadata().is_empty());
}
}
#[tokio::test]
async fn test_filter_pushdown() {
let fixture = TestFixture::new().await;
// Basic filter, not much different pushed down than run from DF
let plan =
LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter.clone()), None)
.unwrap()
.filter(col("i").gt_eq(lit(5)))
.unwrap()
.build()
.unwrap();
TestFixture::check_plan(
plan,
"MetadataEraserExec
RepartitionExec:...
CoalesceBatchesExec:...
FilterExec: i@0 >= 5
ProjectionExec:...
LanceScan:...",
)
.await;
// Filter utilizing scalar index, make sure it gets pushed down
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None)
.unwrap()
.filter(col("indexed").eq(lit(5)))
.unwrap()
.build()
.unwrap();
TestFixture::check_plan(plan, "").await;
}
} }