Compare commits

..

28 Commits

Author SHA1 Message Date
Lance Release
74004161ff [python] Bump version: 0.2.4 → 0.2.5 2023-09-19 16:43:06 +00:00
Lance Release
34ddb1de6d Updating package-lock.json 2023-09-19 13:48:20 +00:00
Lance Release
1029fc9cb0 Updating package-lock.json 2023-09-19 12:19:23 +00:00
Lance Release
31c5df6d99 Bump version: 0.2.5 → 0.2.6 2023-09-19 12:19:05 +00:00
Rob Meng
dbf37a0434 fix: upgrade lance to 0.7.5 and add tests for searching empty dataset (#505)
This PR upgrade lance to `0.7.5`, which include fixes for searching an
empty dataset.

This PR also adds two tests in node SDK to make sure searching empty
dataset do no throw

Co-authored-by: rmeng <rob@lancedb.com>
2023-09-18 22:12:11 -07:00
Chang She
f20f19b804 feat: improve pydantic 1.x compat (#503) 2023-09-18 19:01:30 -07:00
Chang She
55207ce844 feat: add lancedb.__version__ (#504) 2023-09-18 18:51:51 -07:00
Chang She
c21f9cdda0 ci: fix docs build (#496)
python/python.md contains typos in the class references

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-09-18 13:07:21 -07:00
Rob Meng
bc38abb781 refactor connection string processing (#500)
in #486 `connect` started converting path into uri. However, the PR
didn't handle relative path and appended `file://` to relative path.

This PR changes the parsing strat to be more rational. If a path is
provided instead of url, we do not try anythinng special.

engine and engine params may only be specified when a url with schema is
provided

Co-authored-by: rmeng <rob@lancedb.com>
2023-09-18 12:38:00 -07:00
Rob Meng
731f86e44c add health check to wait for all service ready before next step (#501)
aws integration tests are flaky because we didn't wait for the services
to become healthy. (we only waited for the localstack service, this PR
adds wait for sub services)
2023-09-18 15:17:45 -04:00
Chang She
31dad71c94 multi-modal embedding-function (#484) 2023-09-16 21:23:51 -04:00
Will Jones
9585f550b3 fix: increase S3 timeouts (#494)
Closes #493
2023-09-15 20:21:34 -07:00
Lance Release
8dc2315479 [python] Bump version: 0.2.3 → 0.2.4 2023-09-15 14:23:26 +00:00
Rob Meng
f6bfb5da11 chore: upgrade lance to 0.7.4 (#491) 2023-09-14 16:02:23 -04:00
Lance Release
661fcecf38 [python] Bump version: 0.2.2 → 0.2.3 2023-09-14 17:48:32 +00:00
Lance Release
07fe284810 Updating package-lock.json 2023-09-10 23:58:06 +00:00
Lance Release
800bb691c3 Updating package-lock.json 2023-09-09 19:45:58 +00:00
Lance Release
ec24e09add Bump version: 0.2.4 → 0.2.5 2023-09-09 19:45:43 +00:00
Rob Meng
0554db03b3 progagate uri query string to lance; add aws integration tests (#486)
# WARNING: specifying engine is NOT a publicly supported feature in
lancedb yet. THE API WILL CHANGE.

This PR exposes dynamodb based commit to `vectordb` and JS SDK (will do
python in another PR since it's on a different release track)

This PR also added aws integration test using `localstack`

## What?
This PR adds uri parameters to DB connection string. User may specify
`engine` in the connection string to let LanceDB know that the user
wants to use an external store when reading and writing a table. User
may also pass any parameters required by the commitStore in the
connection string, these parameters will be propagated to lance.

e.g.
```
vectordb.connect("s3://my-db-bucket?engine=ddb&ddbTableName=my-commit-table")
```
will automatically convert table path to
```
s3+ddb://my-db-bucket/my_table.lance?&ddbTableName=my-commit-table
```
2023-09-09 13:33:16 -04:00
Lei Xu
b315ea3978 [Python] Pydantic vector field with default value (#474)
Rename `lance.pydantic.vector` to `Vector` and deprecate `vector(dim)`
2023-09-08 22:35:31 -07:00
Ayush Chaurasia
aa7806cf0d [Python]Fix record_batch_generator (#483)
Should fix - https://github.com/lancedb/lancedb/issues/482
2023-09-08 21:18:50 +05:30
Lei Xu
6799613109 feat: upgrade lance to 0.7.3 (#481) 2023-09-07 17:01:45 -07:00
Lei Xu
0f26915d22 [Rust] schema coerce and vector column inference (#476)
Split the rust core from #466 for easy review and less merge conflicts.
2023-09-06 10:00:46 -07:00
Chang She
32163063dc Fix up docs (#477) 2023-09-05 22:29:50 -07:00
Chang She
9a9a73a65d [python] Use pydantic for embedding function persistence (#467)
1. Support persistent embedding function so users can just search using
query string
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty query (just apply select/where/limit).
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
2023-09-05 21:30:45 -07:00
Ayush Chaurasia
52fa7f5577 [Docs] Small typo fixes (#460) 2023-09-02 22:17:19 +05:30
Chang She
0cba0f4f92 [python] Temporary update feature (#457)
Combine delete and append to make a temporary update feature that is
only enabled for the local python lancedb.

The reason why this is temporary is because it first has to load the
data that matches the where clause into memory, which is technical
unbounded.

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-08-30 00:25:26 -07:00
Will Jones
8391ffee84 chore: make crate more discoverable (#443)
A few small changes to make the Rust crate more discoverable.
2023-08-25 08:59:14 -07:00
53 changed files with 2499 additions and 369 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.2.4 current_version = 0.2.6
commit = True commit = True
message = Bump version: {current_version} → {new_version} message = Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -9,6 +9,7 @@ on:
- node/** - node/**
- rust/ffi/node/** - rust/ffi/node/**
- .github/workflows/node.yml - .github/workflows/node.yml
- docker-compose.yml
env: env:
# Disable full debug symbol generation to speed up CI build and keep memory down # Disable full debug symbol generation to speed up CI build and keep memory down
@@ -107,3 +108,56 @@ jobs:
- name: Test - name: Test
run: | run: |
npm run test npm run test
aws-integtest:
timeout-minutes: 45
runs-on: "ubuntu-22.04"
defaults:
run:
shell: bash
working-directory: node
env:
AWS_ACCESS_KEY_ID: ACCESSKEY
AWS_SECRET_ACCESS_KEY: SECRETKEY
AWS_DEFAULT_REGION: us-west-2
# this one is for s3
AWS_ENDPOINT: http://localhost:4566
# this one is for dynamodb
DYNAMODB_ENDPOINT: http://localhost:4566
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
with:
node-version: 18
cache: 'npm'
cache-dependency-path: node/package-lock.json
- name: start local stack
run: docker compose -f ../docker-compose.yml up -d --wait
- name: create s3
run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
- name: create ddb
run: |
aws dynamodb create-table \
--table-name lancedb-integtest \
--attribute-definitions '[{"AttributeName": "base_uri", "AttributeType": "S"}, {"AttributeName": "version", "AttributeType": "N"}]' \
--key-schema '[{"AttributeName": "base_uri", "KeyType": "HASH"}, {"AttributeName": "version", "KeyType": "RANGE"}]' \
--provisioned-throughput '{"ReadCapacityUnits": 10, "WriteCapacityUnits": 10}' \
--endpoint-url $DYNAMODB_ENDPOINT
- uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Build
run: |
npm ci
npm run tsc
npm run build
npm run pack-build
npm install --no-save ./dist/lancedb-vectordb-*.tgz
# Remove index.node to test with dependency installed
rm index.node
- name: Test
run: npm run integration-test

View File

@@ -38,7 +38,7 @@ jobs:
- name: isort - name: isort
run: isort --check --diff --quiet . run: isort --check --diff --quiet .
- name: Run tests - name: Run tests
run: pytest -x -v --durations=30 tests run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest - name: doctest
run: pytest --doctest-modules lancedb run: pytest --doctest-modules lancedb
mac: mac:
@@ -65,4 +65,34 @@ jobs:
- name: Black - name: Black
run: black --check --diff --no-color --quiet . run: black --check --diff --no-color --quiet .
- name: Run tests - name: Run tests
run: pytest -x -v --durations=30 tests run: pytest -m "not slow" -x -v --durations=30 tests
pydantic1x:
timeout-minutes: 30
runs-on: "ubuntu-22.04"
defaults:
run:
shell: bash
working-directory: python
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install lancedb
run: |
pip install "pydantic<2"
pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock black isort
- name: Black
run: black --check --diff --no-color --quiet .
- name: isort
run: isort --check --diff --quiet .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest
run: pytest --doctest-modules lancedb

View File

@@ -1,16 +1,25 @@
[workspace] [workspace]
members = [ members = ["rust/ffi/node", "rust/vectordb"]
"rust/vectordb", # Python package needs to be built by maturin.
"rust/ffi/node" exclude = ["python"]
]
resolver = "2" resolver = "2"
[workspace.dependencies] [workspace.dependencies]
lance = "=0.6.5" lance = { "version" = "=0.7.5", "features" = ["dynamodb"] }
lance-linalg = { "version" = "=0.7.5" }
# Note that this one does not include pyarrow
arrow = { version = "43.0.0", optional = false }
arrow-array = "43.0" arrow-array = "43.0"
arrow-data = "43.0" arrow-data = "43.0"
arrow-schema = "43.0"
arrow-ipc = "43.0" arrow-ipc = "43.0"
half = { "version" = "=2.2.1", default-features = false } arrow-ord = "43.0"
arrow-schema = "43.0"
arrow-arith = "43.0"
arrow-cast = "43.0"
half = { "version" = "=2.2.1", default-features = false, features = [
"num-traits"
] }
log = "0.4"
object_store = "0.6.1" object_store = "0.6.1"
snafu = "0.7.4" snafu = "0.7.4"
url = "2"

18
docker-compose.yml Normal file
View File

@@ -0,0 +1,18 @@
version: "3.9"
services:
localstack:
image: localstack/localstack:0.14
ports:
- 4566:4566
environment:
- SERVICES=s3,dynamodb
- DEBUG=1
- LS_LOG=trace
- DOCKER_HOST=unix:///var/run/docker.sock
- AWS_ACCESS_KEY_ID=ACCESSKEY
- AWS_SECRET_ACCESS_KEY=SECRETKEY
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
interval: 5s
retries: 3
start_period: 10s

View File

@@ -67,6 +67,11 @@ nav:
- Home: - Home:
- 🏢 Home: index.md - 🏢 Home: index.md
- 💡 Basics: basic.md - 💡 Basics: basic.md
- 📚 Guides:
- Tables: guides/tables.md
- Vector Search: search.md
- SQL filters: sql.md
- Indexing: ann_indexes.md
- 🧬 Embeddings: embedding.md - 🧬 Embeddings: embedding.md
- 🔍 Python full-text search: fts.md - 🔍 Python full-text search: fts.md
- 🔌 Integrations: - 🔌 Integrations:
@@ -91,12 +96,12 @@ nav:
- Serverless Website Chatbot: examples/serverless_website_chatbot.md - Serverless Website Chatbot: examples/serverless_website_chatbot.md
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
- 📚 Guides:
- Tables: guides/tables.md
- Vector Search: search.md
- SQL filters: sql.md
- Indexing: ann_indexes.md
- Basics: basic.md - Basics: basic.md
- Guides:
- Tables: guides/tables.md
- Vector Search: search.md
- SQL filters: sql.md
- Indexing: ann_indexes.md
- Embeddings: embedding.md - Embeddings: embedding.md
- Python full-text search: fts.md - Python full-text search: fts.md
- Integrations: - Integrations:
@@ -121,12 +126,6 @@ nav:
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
- Serverless Chatbot from any website: examples/serverless_website_chatbot.md - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
- Guides:
- Tables: guides/tables.md
- Vector Search: search.md
- SQL filters: sql.md
- Indexing: ann_indexes.md
- API references: - API references:
- Python API: python/python.md - Python API: python/python.md
- Javascript API: javascript/modules.md - Javascript API: javascript/modules.md

View File

@@ -49,11 +49,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
db.create_table("table2", data) db.create_table("table2", data)
db["table2"].head() db["table2"].head()
``` ```
!!! info "Note" !!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly. Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
```python ```python
custom_schema = pa.schema([ custom_schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -66,7 +66,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
### From PyArrow Tables ### From PyArrow Tables
You can also create LanceDB tables directly from pyarrow tables You can also create LanceDB tables directly from pyarrow tables
```python ```python
table = pa.Table.from_arrays( table = pa.Table.from_arrays(
[ [
@@ -84,18 +84,28 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
``` ```
### From Pydantic Models ### From Pydantic Models
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method. When you create an empty table without data, you must specify the table schema.
LanceDB supports creating tables by specifying a pyarrow schema or a specialized
pydantic model called `LanceModel`.
For example, the following Content model specifies a table with 5 columns:
movie_id, vector, genres, title, and imdb_id. When you create a table, you can
pass the class as the value of the `schema` parameter to `create_table`.
The `vector` column is a `Vector` type, which is a specialized pydantic type that
can be configured with the vector dimensions. It is also important to note that
LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
(which itself derives from `pydantic.BaseModel`).
```python ```python
from lancedb.pydantic import vector, LanceModel from lancedb.pydantic import Vector, LanceModel
class Content(LanceModel): class Content(LanceModel):
movie_id: int movie_id: int
vector: vector(128) vector: Vector(128)
genres: str genres: str
title: str title: str
imdb_id: int imdb_id: int
@property @property
def imdb_url(self) -> str: def imdb_url(self) -> str:
return f"https://www.imdb.com/title/tt{self.imdb_id}" return f"https://www.imdb.com/title/tt{self.imdb_id}"
@@ -103,7 +113,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
import pyarrow as pa import pyarrow as pa
db = lancedb.connect("~/.lancedb") db = lancedb.connect("~/.lancedb")
table_name = "movielens_small" table_name = "movielens_small"
table = db.create_table(table_name, schema=Content.to_arrow_schema()) table = db.create_table(table_name, schema=Content)
``` ```
### Using Iterators / Writing Large Datasets ### Using Iterators / Writing Large Datasets
@@ -113,7 +123,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types. LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types.
Here's an example using using `RecordBatch` iterator for creating tables. Here's an example using using `RecordBatch` iterator for creating tables.
```python ```python
import pyarrow as pa import pyarrow as pa
@@ -142,11 +152,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
## Creating Empty Table ## Creating Empty Table
You can also create empty tables in python. Initialize it with schema and later ingest data into it. You can also create empty tables in python. Initialize it with schema and later ingest data into it.
```python ```python
import lancedb import lancedb
import pyarrow as pa import pyarrow as pa
schema = pa.schema( schema = pa.schema(
[ [
pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -168,8 +178,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, vector
class Model(LanceModel): class Model(LanceModel):
vector: vector(2) vector: Vector(2)
tbl = db.create_table("table5", schema=Model.to_arrow_schema()) tbl = db.create_table("table5", schema=Model.to_arrow_schema())
``` ```
@@ -249,7 +259,7 @@ After a table has been created, you can always add more data to it using
You can also add a large dataset batch in one go using Iterator of any supported data types. You can also add a large dataset batch in one go using Iterator of any supported data types.
### Adding to table using Iterator ### Adding to table using Iterator
```python ```python
import pandas as pd import pandas as pd
@@ -261,10 +271,10 @@ After a table has been created, you can always add more data to it using
"item": ["foo", "bar"], "item": ["foo", "bar"],
"price": [10.0, 20.0], "price": [10.0, 20.0],
}) })
tbl.add(make_batches()) tbl.add(make_batches())
``` ```
The other arguments accepted: The other arguments accepted:
| Name | Type | Description | Default | | Name | Type | Description | Default |
@@ -274,7 +284,7 @@ After a table has been created, you can always add more data to it using
| on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop | | on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop |
| fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 | | fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 |
=== "Javascript/Typescript" === "Javascript/Typescript"
```javascript ```javascript
@@ -312,7 +322,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector # x vector
# 0 1 [1.0, 2.0] # 0 1 [1.0, 2.0]
# 1 3 [5.0, 6.0] # 1 3 [5.0, 6.0]
``` ```
### Delete from a list of values ### Delete from a list of values
@@ -325,7 +335,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector # x vector
# 0 3 [5.0, 6.0] # 0 3 [5.0, 6.0]
``` ```
=== "Javascript/Typescript" === "Javascript/Typescript"
```javascript ```javascript

View File

@@ -1,6 +1,6 @@
# LanceDB # LanceDB
LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings. LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.
![Illustration](/lancedb/assets/ecosystem-illustration.png) ![Illustration](/lancedb/assets/ecosystem-illustration.png)

View File

@@ -249,11 +249,11 @@
} }
], ],
"source": [ "source": [
"from lancedb.pydantic import vector, LanceModel\n", "from lancedb.pydantic import Vector, LanceModel\n",
"\n", "\n",
"class Content(LanceModel):\n", "class Content(LanceModel):\n",
" movie_id: int\n", " movie_id: int\n",
" vector: vector(128)\n", " vector: Vector(128)\n",
" genres: str\n", " genres: str\n",
" title: str\n", " title: str\n",
" imdb_id: int\n", " imdb_id: int\n",
@@ -359,7 +359,7 @@
"import pandas as pd\n", "import pandas as pd\n",
"\n", "\n",
"class PydanticSchema(LanceModel):\n", "class PydanticSchema(LanceModel):\n",
" vector: vector(2)\n", " vector: Vector(2)\n",
" item: str\n", " item: str\n",
" price: float\n", " price: float\n",
"\n", "\n",
@@ -394,10 +394,10 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import lancedb\n", "import lancedb\n",
"from lancedb.pydantic import LanceModel, vector\n", "from lancedb.pydantic import LanceModel, Vector\n",
"\n", "\n",
"class Model(LanceModel):\n", "class Model(LanceModel):\n",
" vector: vector(2)\n", " vector: Vector(2)\n",
"\n", "\n",
"tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())" "tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())"
] ]

View File

@@ -13,10 +13,10 @@ via [pydantic_to_schema()](python.md##lancedb.pydantic.pydantic_to_schema) metho
## Vector Field ## Vector Field
LanceDB provides a [`vector(dim)`](python.md#lancedb.pydantic.vector) method to define a LanceDB provides a [`Vector(dim)`](python.md#lancedb.pydantic.Vector) method to define a
vector Field in a Pydantic Model. vector Field in a Pydantic Model.
::: lancedb.pydantic.vector ::: lancedb.pydantic.Vector
## Type Conversion ## Type Conversion
@@ -33,4 +33,4 @@ Current supported type conversions:
| `str` | `pyarrow.utf8()` | | `str` | `pyarrow.utf8()` |
| `list` | `pyarrow.List` | | `list` | `pyarrow.List` |
| `BaseModel` | `pyarrow.Struct` | | `BaseModel` | `pyarrow.Struct` |
| `vector(n)` | `pyarrow.FixedSizeList(float32, n)` | | `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |

View File

@@ -26,9 +26,19 @@ pip install lancedb
## Embeddings ## Embeddings
::: lancedb.embeddings.with_embeddings ::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
::: lancedb.embeddings.EmbeddingFunction ::: lancedb.embeddings.functions.EmbeddingFunction
::: lancedb.embeddings.functions.TextEmbeddingFunction
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
::: lancedb.embeddings.functions.OpenAIEmbeddings
::: lancedb.embeddings.functions.OpenClipEmbeddings
::: lancedb.embeddings.with_embeddings
## Context ## Context

View File

@@ -8,7 +8,8 @@ excluded_globs = [
"../src/embedding.md", "../src/embedding.md",
"../src/examples/*.md", "../src/examples/*.md",
"../src/integrations/voxel51.md", "../src/integrations/voxel51.md",
"../src/guides/tables.md" "../src/guides/tables.md",
"../src/python/duckdb.md",
] ]
python_prefix = "py" python_prefix = "py"

105
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.2.4", "version": "0.2.6",
"lockfileVersion": 2, "lockfileVersion": 2,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.2.4", "version": "0.2.6",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -31,6 +31,7 @@
"@types/node": "^18.16.2", "@types/node": "^18.16.2",
"@types/sinon": "^10.0.15", "@types/sinon": "^10.0.15",
"@types/temp": "^0.9.1", "@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1", "@typescript-eslint/eslint-plugin": "^5.59.1",
"cargo-cp-artifact": "^0.1", "cargo-cp-artifact": "^0.1",
"chai": "^4.3.7", "chai": "^4.3.7",
@@ -48,14 +49,15 @@
"ts-node-dev": "^2.0.0", "ts-node-dev": "^2.0.0",
"typedoc": "^0.24.7", "typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3", "typedoc-plugin-markdown": "^3.15.3",
"typescript": "*" "typescript": "*",
"uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.2.4", "@lancedb/vectordb-darwin-arm64": "0.2.6",
"@lancedb/vectordb-darwin-x64": "0.2.4", "@lancedb/vectordb-darwin-x64": "0.2.6",
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4", "@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
"@lancedb/vectordb-linux-x64-gnu": "0.2.4", "@lancedb/vectordb-linux-x64-gnu": "0.2.6",
"@lancedb/vectordb-win32-x64-msvc": "0.2.4" "@lancedb/vectordb-win32-x64-msvc": "0.2.6"
} }
}, },
"node_modules/@apache-arrow/ts": { "node_modules/@apache-arrow/ts": {
@@ -315,9 +317,9 @@
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": { "node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==", "integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -327,9 +329,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-darwin-x64": { "node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==", "integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -339,9 +341,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-gnu": { "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==", "integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -351,9 +353,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-gnu": { "node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==", "integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -363,9 +365,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-win32-x64-msvc": { "node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==", "integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -596,6 +598,12 @@
"@types/node": "*" "@types/node": "*"
} }
}, },
"node_modules/@types/uuid": {
"version": "9.0.3",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
"dev": true
},
"node_modules/@typescript-eslint/eslint-plugin": { "node_modules/@typescript-eslint/eslint-plugin": {
"version": "5.59.1", "version": "5.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
@@ -4451,6 +4459,15 @@
"punycode": "^2.1.0" "punycode": "^2.1.0"
} }
}, },
"node_modules/uuid": {
"version": "9.0.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
"dev": true,
"bin": {
"uuid": "dist/bin/uuid"
}
},
"node_modules/v8-compile-cache-lib": { "node_modules/v8-compile-cache-lib": {
"version": "3.0.1", "version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
@@ -4852,33 +4869,33 @@
} }
}, },
"@lancedb/vectordb-darwin-arm64": { "@lancedb/vectordb-darwin-arm64": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==", "integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-darwin-x64": { "@lancedb/vectordb-darwin-x64": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==", "integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-linux-arm64-gnu": { "@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==", "integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-linux-x64-gnu": { "@lancedb/vectordb-linux-x64-gnu": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==", "integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-win32-x64-msvc": { "@lancedb/vectordb-win32-x64-msvc": {
"version": "0.2.4", "version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==", "integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
"optional": true "optional": true
}, },
"@neon-rs/cli": { "@neon-rs/cli": {
@@ -5093,6 +5110,12 @@
"@types/node": "*" "@types/node": "*"
} }
}, },
"@types/uuid": {
"version": "9.0.3",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
"dev": true
},
"@typescript-eslint/eslint-plugin": { "@typescript-eslint/eslint-plugin": {
"version": "5.59.1", "version": "5.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
@@ -7844,6 +7867,12 @@
"punycode": "^2.1.0" "punycode": "^2.1.0"
} }
}, },
"uuid": {
"version": "9.0.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
"dev": true
},
"v8-compile-cache-lib": { "v8-compile-cache-lib": {
"version": "3.0.1", "version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.2.4", "version": "0.2.6",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
@@ -9,6 +9,7 @@
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json", "build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
"build-release": "npm run build -- --release", "build-release": "npm run build -- --release",
"test": "npm run tsc && mocha -recursive dist/test", "test": "npm run tsc && mocha -recursive dist/test",
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
"lint": "eslint native.js src --ext .js,.ts", "lint": "eslint native.js src --ext .js,.ts",
"clean": "rm -rf node_modules *.node dist/", "clean": "rm -rf node_modules *.node dist/",
"pack-build": "neon pack-build", "pack-build": "neon pack-build",
@@ -34,6 +35,7 @@
"@types/node": "^18.16.2", "@types/node": "^18.16.2",
"@types/sinon": "^10.0.15", "@types/sinon": "^10.0.15",
"@types/temp": "^0.9.1", "@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1", "@typescript-eslint/eslint-plugin": "^5.59.1",
"cargo-cp-artifact": "^0.1", "cargo-cp-artifact": "^0.1",
"chai": "^4.3.7", "chai": "^4.3.7",
@@ -51,7 +53,8 @@
"ts-node-dev": "^2.0.0", "ts-node-dev": "^2.0.0",
"typedoc": "^0.24.7", "typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3", "typedoc-plugin-markdown": "^3.15.3",
"typescript": "*" "typescript": "*",
"uuid": "^9.0.0"
}, },
"dependencies": { "dependencies": {
"@apache-arrow/ts": "^12.0.0", "@apache-arrow/ts": "^12.0.0",
@@ -78,10 +81,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.2.4", "@lancedb/vectordb-darwin-arm64": "0.2.6",
"@lancedb/vectordb-darwin-x64": "0.2.4", "@lancedb/vectordb-darwin-x64": "0.2.6",
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4", "@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
"@lancedb/vectordb-linux-x64-gnu": "0.2.4", "@lancedb/vectordb-linux-x64-gnu": "0.2.6",
"@lancedb/vectordb-win32-x64-msvc": "0.2.4" "@lancedb/vectordb-win32-x64-msvc": "0.2.6"
} }
} }

View File

@@ -0,0 +1,43 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import { v4 as uuidv4 } from 'uuid'
import * as lancedb from '../index'
const assert = chai.assert
chai.use(chaiAsPromised)
describe('LanceDB AWS Integration test', function () {
it('s3+ddb schema is processed correctly', async function () {
this.timeout(15000)
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
// THE API WILL CHANGE
const conn = await lancedb.connect('s3://lancedb-integtest?engine=ddb&ddbTableName=lancedb-integtest')
const data = [{ vector: Array(128).fill(1.0) }]
const tableName = uuidv4()
let table = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
const futs = [table.add(data), table.add(data), table.add(data), table.add(data), table.add(data)]
await Promise.allSettled(futs)
table = await conn.openTable(tableName)
assert.equal(await table.countRows(), 6)
})
})

View File

@@ -19,7 +19,7 @@ import * as chaiAsPromised from 'chai-as-promised'
import * as lancedb from '../index' import * as lancedb from '../index'
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index' import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow' import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
const expect = chai.expect const expect = chai.expect
const assert = chai.assert const assert = chai.assert
@@ -258,6 +258,36 @@ describe('LanceDB client', function () {
}) })
}) })
describe('when searching an empty dataset', function () {
it('should not fail', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
const result = await table.search(Array(128).fill(0.1)).execute()
assert.isEmpty(result)
})
})
describe('when searching an empty-after-delete dataset', function () {
it('should not fail', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
await table.add([{ vector: Array(128).fill(0.1) }])
await table.delete('vector IS NOT NULL')
const result = await table.search(Array(128).fill(0.1)).execute()
assert.isEmpty(result)
})
})
describe('when creating a vector index', function () { describe('when creating a vector index', function () {
it('overwrite all records in a table', async function () { it('overwrite all records in a table', async function () {
const uri = await createTestDB(32, 300) const uri = await createTestDB(32, 300)

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.2.2 current_version = 0.2.5
commit = True commit = True
message = [python] Bump version: {current_version} → {new_version} message = [python] Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -11,12 +11,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib.metadata
from typing import Optional from typing import Optional
from .db import URI, DBConnection, LanceDBConnection from .db import URI, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection from .remote.db import RemoteDBConnection
from .schema import vector from .schema import vector
__version__ = importlib.metadata.version("lancedb")
def connect( def connect(
uri: URI, uri: URI,
@@ -31,9 +34,13 @@ def connect(
---------- ----------
uri: str or Path uri: str or Path
The uri of the database. The uri of the database.
api_token: str, optional api_key: str, optional
If presented, connect to LanceDB cloud. If presented, connect to LanceDB cloud.
Otherwise, connect to a database on file system or cloud storage. Otherwise, connect to a database on file system or cloud storage.
region: str, default "us-west-2"
The region to use for LanceDB Cloud.
host_override: str, optional
The override url for LanceDB Cloud.
Examples Examples
-------- --------

View File

@@ -1,7 +1,10 @@
import os import os
import numpy as np
import pytest import pytest
from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
# import lancedb so we don't have to in every example # import lancedb so we don't have to in every example
@@ -14,3 +17,24 @@ def doctest_setup(monkeypatch, tmpdir):
monkeypatch.setitem(os.environ, "COLUMNS", "80") monkeypatch.setitem(os.environ, "COLUMNS", "80")
# Work in a temporary directory # Work in a temporary directory
monkeypatch.chdir(tmpdir) monkeypatch.chdir(tmpdir)
registry = EmbeddingFunctionRegistry.get_instance()
@registry.register("test")
class MockTextEmbeddingFunction(TextEmbeddingFunction):
"""
Return the hash of the first 10 characters
"""
def generate_embeddings(self, texts):
return [self._compute_one_embedding(row) for row in texts]
def _compute_one_embedding(self, row):
emb = np.array([float(hash(c)) for c in row[:10]])
emb /= np.linalg.norm(emb)
return emb
def ndims(self):
return 10

View File

@@ -16,12 +16,13 @@ from __future__ import annotations
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional from typing import List, Optional, Union
import pyarrow as pa import pyarrow as pa
from pyarrow import fs from pyarrow import fs
from .common import DATA, URI from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel from .pydantic import LanceModel
from .table import LanceTable, Table from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme from .util import fs_from_uri, get_uri_location, get_uri_scheme
@@ -40,7 +41,7 @@ class DBConnection(ABC):
self, self,
name: str, name: str,
data: Optional[DATA] = None, data: Optional[DATA] = None,
schema: Optional[pa.Schema, LanceModel] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create", mode: str = "create",
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
@@ -285,10 +286,11 @@ class LanceDBConnection(DBConnection):
self, self,
name: str, name: str,
data: Optional[DATA] = None, data: Optional[DATA] = None,
schema: Optional[pa.Schema, LanceModel] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create", mode: str = "create",
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> LanceTable: ) -> LanceTable:
"""Create a table in the database. """Create a table in the database.
@@ -307,6 +309,7 @@ class LanceDBConnection(DBConnection):
mode=mode, mode=mode,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
embedding_functions=embedding_functions,
) )
return tbl return tbl

View File

@@ -0,0 +1,24 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .functions import (
EmbeddingFunction,
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
OpenAIEmbeddings,
OpenClipEmbeddings,
SentenceTransformerEmbeddings,
TextEmbeddingFunction,
)
from .utils import with_embeddings

View File

@@ -0,0 +1,577 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import importlib
import io
import json
import os
import socket
import urllib.error
import urllib.parse as urlparse
import urllib.request
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel, Field, PrivateAttr
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing.
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.
Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()
>>> @registry.register("my-embedding-function")
... class MyEmbeddingFunction(EmbeddingFunction):
... def ndims(self) -> int:
... return 128
...
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
... return self.compute_source_embeddings(query, *args, **kwargs)
...
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
...
>>> registry.get("my-embedding-function")
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
"""
@classmethod
def get_instance(cls):
return __REGISTRY__
def __init__(self):
self._functions = {}
def register(self, alias: str = None):
"""
This creates a decorator that can be used to register
an EmbeddingFunction.
Parameters
----------
alias : Optional[str]
a human friendly name for the embedding function. If not
provided, the class name will be used.
"""
# This is a decorator for a class that inherits from BaseModel
# It adds the class to the registry
def decorator(cls):
if not issubclass(cls, EmbeddingFunction):
raise TypeError("Must be a subclass of EmbeddingFunction")
if cls.__name__ in self._functions:
raise KeyError(f"{cls.__name__} was already registered")
key = alias or cls.__name__
self._functions[key] = cls
cls.__embedding_function_registry_alias__ = alias
return cls
return decorator
def reset(self):
"""
Reset the registry to its initial state
"""
self._functions = {}
def get(self, name: str):
"""
Fetch an embedding function class by name
Parameters
----------
name : str
The name of the embedding function to fetch
Either the alias or the class name if no alias was provided
during registration
"""
return self._functions[name]
def parse_functions(
self, metadata: Optional[Dict[bytes, bytes]]
) -> Dict[str, "EmbeddingFunctionConfig"]:
"""
Parse the metadata from an arrow table and
return a mapping of the vector column to the
embedding function and source column
Parameters
----------
metadata : Optional[Dict[bytes, bytes]]
The metadata from an arrow table. Note that
the keys and values are bytes (pyarrow api)
Returns
-------
functions : dict
A mapping of vector column name to embedding function.
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
return {
obj["vector_column"]: EmbeddingFunctionConfig(
vector_column=obj["vector_column"],
source_column=obj["source_column"],
function=self.get(obj["name"])(**obj["model"]),
)
for obj in raw_list
}
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
"""
Convert the given embedding function and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
func = conf.function
name = getattr(
func, "__embedding_function_registry_alias__", func.__class__.__name__
)
json_data = func.safe_model_dump()
return {
"name": name,
"model": json_data,
"source_column": conf.source_column,
"vector_column": conf.vector_column,
}
def get_table_metadata(self, func_list):
"""
Convert a list of embedding functions and source / vector configs
into a config dictionary that can be serialized into arrow metadata
"""
if func_list is None or len(func_list) == 0:
return None
json_data = [self.function_to_metadata(func) for func in func_list]
# Note that metadata dictionary values must be bytes
# so we need to json dump then utf8 encode
metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata}
# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
class EmbeddingFunction(BaseModel, ABC):
"""
An ABC for embedding functions.
All concrete embedding functions must implement the following:
1. compute_query_embeddings() which takes a query and returns a list of embeddings
2. get_source_embeddings() which returns a list of embeddings for the source column
For text data, the two will be the same. For multi-modal data, the source column
might be images and the vector column might be text.
3. ndims method which returns the number of dimensions of the vector column
"""
_ndims: int = PrivateAttr()
@classmethod
def create(cls, **kwargs):
"""
Create an instance of the embedding function
"""
return cls(**kwargs)
@abstractmethod
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query
"""
pass
@abstractmethod
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database
"""
pass
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(texts, str):
texts = [texts]
elif isinstance(texts, pa.Array):
texts = texts.to_pylist()
elif isinstance(texts, pa.ChunkedArray):
texts = texts.combine_chunks().to_pylist()
return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION
if PYDANTIC_VERSION.major < 2:
return dict(self)
return self.model_dump()
@abstractmethod
def ndims(self):
"""
Return the dimensions of the vector column
"""
pass
def SourceField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the source column for this embedding function
"""
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
def VectorField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the target vector column for this embedding function
"""
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
class EmbeddingFunctionConfig(BaseModel):
"""
This model encapsulates the configuration for a embedding function
in a lancedb table. It holds the embedding function, the source column,
and the vector column
"""
vector_column: str
source_column: str
function: EmbeddingFunction
class TextEmbeddingFunction(EmbeddingFunction):
"""
A callable ABC for embedding functions that take text as input
"""
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, *args, **kwargs)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
pass
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the sentence-transformers library
https://huggingface.co/sentence-transformers
"""
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
@property
def embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
"""
return self.__class__.get_embedding_model(self.name, self.device)
def ndims(self):
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
return self.embedding_model.encode(
list(texts),
convert_to_numpy=True,
normalize_embeddings=self.normalize,
).tolist()
@classmethod
@cached(cache={})
def get_embedding_model(cls, name, device):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
Parameters
----------
name : str
The name of the model to load
device : str
The device to load the model on
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = cls.safe_import(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(name, device=device)
@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the OpenAI API
https://platform.openai.com/docs/guides/embeddings
"""
name: str = "text-embedding-ada-002"
def ndims(self):
# TODO don't hardcode this
return 1536
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
openai = self.safe_import("openai")
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
return [v["embedding"] for v in rs]
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the OpenClip API
For multi-modal text-to-image search
https://github.com/mlfoundations/open_clip
"""
name: str = "ViT-B-32"
pretrained: str = "laion2b_s34b_b79k"
device: str = "cpu"
batch_size: int = 64
normalize: bool = True
_model = PrivateAttr()
_preprocess = PrivateAttr()
_tokenizer = PrivateAttr()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
model.to(self.device)
self._model, self._preprocess = model, preprocess
self._tokenizer = open_clip.get_tokenizer(self.name)
self._ndims = None
def ndims(self):
if self._ndims is None:
self._ndims = self.generate_text_embeddings("foo").shape[0]
return self._ndims
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
with torch.no_grad():
text_features = self._model.encode_text(text.to(self.device))
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given images
"""
images = self.sanitize_input(images)
embeddings = []
for i in range(0, len(images), self.batch_size):
j = min(i + self.batch_size, len(images))
batch = images[i:j]
embeddings.extend(self._parallel_get(batch))
return embeddings
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
"""
Issue concurrent requests to retrieve the image data
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)
for image in images
]
return [future.result() for future in futures]
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
with torch.no_grad():
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
"""
encode a single image tensor and optionally normalize the output
"""
image_features = self._model.encode_image(image_tensor)
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()
def url_retrieve(url: str):
"""
Parameters
----------
url: str
URL to download from
"""
try:
with urllib.request.urlopen(url) as conn:
return conn.read()
except (socket.gaierror, urllib.error.URLError) as err:
raise ConnectionError("could not download {} due to {}".format(url, err))

View File

@@ -1,4 +1,4 @@
# Copyright 2023 LanceDB Developers # Copyright (c) 2023. LanceDB Developers
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ import pyarrow as pa
from lance.vector import vec_to_table from lance.vector import vec_to_table
from retry import retry from retry import retry
from .util import safe_import_pandas from ..util import safe_import_pandas
pd = safe_import_pandas() pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"] DATA = Union[pa.Table, "pd.DataFrame"]
@@ -58,7 +58,7 @@ def with_embeddings(
pa.Table pa.Table
The input table with a new column called "vector" containing the embeddings. The input table with a new column called "vector" containing the embeddings.
""" """
func = EmbeddingFunction(func) func = FunctionWrapper(func)
if wrap_api: if wrap_api:
func = func.retry().rate_limit() func = func.retry().rate_limit()
func = func.batch_size(batch_size) func = func.batch_size(batch_size)
@@ -71,7 +71,11 @@ def with_embeddings(
return data.append_column("vector", table["vector"]) return data.append_column("vector", table["vector"])
class EmbeddingFunction: class FunctionWrapper:
"""
A wrapper for embedding functions that adds rate limiting, retries, and batching.
"""
def __init__(self, func: Callable): def __init__(self, func: Callable):
self.func = func self.func = func
self.rate_limiter_kwargs = {} self.rate_limiter_kwargs = {}

View File

@@ -26,6 +26,8 @@ import pyarrow as pa
import pydantic import pydantic
import semver import semver
from .embeddings import EmbeddingFunctionRegistry
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try: try:
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
@@ -46,7 +48,19 @@ class FixedSizeListMixin(ABC):
raise NotImplementedError raise NotImplementedError
def vector( def vector(dim: int, value_type: pa.DataType = pa.float32()):
# TODO: remove in future release
from warnings import warn
warn(
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
"This function will be removed in future release",
DeprecationWarning,
)
return Vector(dim, value_type)
def Vector(
dim: int, value_type: pa.DataType = pa.float32() dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]: ) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type. """Pydantic Vector Type.
@@ -65,12 +79,12 @@ def vector(
-------- --------
>>> import pydantic >>> import pydantic
>>> from lancedb.pydantic import vector >>> from lancedb.pydantic import Vector
... ...
>>> class MyModel(pydantic.BaseModel): >>> class MyModel(pydantic.BaseModel):
... id: int ... id: int
... url: str ... url: str
... embeddings: vector(768) ... embeddings: Vector(768)
>>> schema = pydantic_to_schema(MyModel) >>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([ >>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False), ... pa.field("id", pa.int64(), False),
@@ -114,7 +128,7 @@ def vector(
def validate(cls, v): def validate(cls, v):
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim: if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
raise TypeError("A list of numbers or numpy.ndarray is needed") raise TypeError("A list of numbers or numpy.ndarray is needed")
return v return cls(v)
if PYDANTIC_VERSION < (2, 0): if PYDANTIC_VERSION < (2, 0):
@@ -224,27 +238,18 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
>>> from typing import List, Optional >>> from typing import List, Optional
>>> import pydantic >>> import pydantic
>>> from lancedb.pydantic import pydantic_to_schema >>> from lancedb.pydantic import pydantic_to_schema
...
>>> class InnerModel(pydantic.BaseModel):
... a: str
... b: Optional[float]
>>>
>>> class FooModel(pydantic.BaseModel): >>> class FooModel(pydantic.BaseModel):
... id: int ... id: int
... s: Optional[str] = None ... s: str
... vec: List[float] ... vec: List[float]
... li: List[int] ... li: List[int]
... inner: InnerModel ...
>>> schema = pydantic_to_schema(FooModel) >>> schema = pydantic_to_schema(FooModel)
>>> assert schema == pa.schema([ >>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False), ... pa.field("id", pa.int64(), False),
... pa.field("s", pa.utf8(), True), ... pa.field("s", pa.utf8(), False),
... pa.field("vec", pa.list_(pa.float64()), False), ... pa.field("vec", pa.list_(pa.float64()), False),
... pa.field("li", pa.list_(pa.int64()), False), ... pa.field("li", pa.list_(pa.int64()), False),
... pa.field("inner", pa.struct([
... pa.field("a", pa.utf8(), False),
... pa.field("b", pa.float64(), True),
... ]), False),
... ]) ... ])
""" """
fields = _pydantic_model_to_fields(model) fields = _pydantic_model_to_fields(model)
@@ -258,11 +263,11 @@ class LanceModel(pydantic.BaseModel):
Examples Examples
-------- --------
>>> import lancedb >>> import lancedb
>>> from lancedb.pydantic import LanceModel, vector >>> from lancedb.pydantic import LanceModel, Vector
>>> >>>
>>> class TestModel(LanceModel): >>> class TestModel(LanceModel):
... name: str ... name: str
... vector: vector(2) ... vector: Vector(2)
... ...
>>> db = lancedb.connect("/tmp") >>> db = lancedb.connect("/tmp")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema()) >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
@@ -278,13 +283,58 @@ class LanceModel(pydantic.BaseModel):
""" """
Get the Arrow Schema for this model. Get the Arrow Schema for this model.
""" """
return pydantic_to_schema(cls) schema = pydantic_to_schema(cls)
functions = cls.parse_embedding_functions()
if len(functions) > 0:
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
functions
)
schema = schema.with_metadata(metadata)
return schema
@classmethod @classmethod
def field_names(cls) -> List[str]: def field_names(cls) -> List[str]:
""" """
Get the field names of this model. Get the field names of this model.
""" """
return list(cls.safe_get_fields().keys())
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2: if PYDANTIC_VERSION.major < 2:
return list(cls.__fields__.keys()) return cls.__fields__
return list(cls.model_fields.keys()) return cls.model_fields
@classmethod
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
"""
Parse the embedding functions from this model.
"""
from .embeddings import EmbeddingFunctionConfig
vec_and_function = []
for name, field_info in cls.safe_get_fields().items():
func = get_extras(field_info, "vector_column_for")
if func is not None:
vec_and_function.append([name, func])
configs = []
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
src_func = get_extras(field_info, "source_column_for")
if src_func == func:
configs.append(
EmbeddingFunctionConfig(
source_column=source, vector_column=vec, function=func
)
)
return configs
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
if PYDANTIC_VERSION.major >= 2:
return (field_info.json_schema_extra or {}).get(key)
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)

View File

@@ -13,6 +13,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union from typing import List, Literal, Optional, Type, Union
import numpy as np import numpy as np
@@ -54,7 +55,163 @@ class Query(pydantic.BaseModel):
refine_factor: Optional[int] = None refine_factor: Optional[int] = None
class LanceQueryBuilder: class LanceQueryBuilder(ABC):
@classmethod
def create(
cls,
table: "lancedb.table.Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
if isinstance(query, list):
query = np.array(query, dtype=np.float32)
elif isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name)
@classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name):
# If query_type is fts, then query must be a string.
# otherwise raise TypeError
if query_type == "fts":
if not isinstance(query, str):
raise TypeError(f"'fts' queries must be a string: {type(query)}")
return query, query_type
elif query_type == "vector":
if not isinstance(query, (list, np.ndarray)):
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
return query, "vector"
else:
return query, "fts"
else:
raise ValueError(
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)
def __init__(self, table: "lancedb.table.Table"):
self._table = table
self._limit = 10
self._columns = None
self._where = None
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_arrow().to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
"""
raise NotImplementedError
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
def limit(self, limit: int) -> LanceVectorQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceVectorQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
return self
class LanceVectorQueryBuilder(LanceQueryBuilder):
""" """
A builder for nearest neighbor queries for LanceDB. A builder for nearest neighbor queries for LanceDB.
@@ -80,68 +237,17 @@ class LanceQueryBuilder:
def __init__( def __init__(
self, self,
table: "lancedb.table.Table", table: "lancedb.table.Table",
query: Union[np.ndarray, str], query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str = VECTOR_COLUMN_NAME, vector_column: str = VECTOR_COLUMN_NAME,
): ):
super().__init__(table)
self._query = query
self._metric = "L2" self._metric = "L2"
self._nprobes = 20 self._nprobes = 20
self._refine_factor = None self._refine_factor = None
self._table = table
self._query = query
self._limit = 10
self._columns = None
self._where = None
self._vector_column = vector_column self._vector_column = vector_column
def limit(self, limit: int) -> LanceQueryBuilder: def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
Parameters Parameters
@@ -151,13 +257,13 @@ class LanceQueryBuilder:
Returns Returns
------- -------
LanceQueryBuilder LanceVectorQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._metric = metric self._metric = metric
return self return self
def nprobes(self, nprobes: int) -> LanceQueryBuilder: def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
"""Set the number of probes to use. """Set the number of probes to use.
Higher values will yield better recall (more likely to find vectors if Higher values will yield better recall (more likely to find vectors if
@@ -173,13 +279,13 @@ class LanceQueryBuilder:
Returns Returns
------- -------
LanceQueryBuilder LanceVectorQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._nprobes = nprobes self._nprobes = nprobes
return self return self
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder: def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled. """Set the refine factor to use, increasing the number of vectors sampled.
As an example, a refine factor of 2 will sample 2x as many vectors as As an example, a refine factor of 2 will sample 2x as many vectors as
@@ -195,22 +301,12 @@ class LanceQueryBuilder:
Returns Returns
------- -------
LanceQueryBuilder LanceVectorQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._refine_factor = refine_factor self._refine_factor = refine_factor
return self return self
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_arrow().to_pandas()
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
""" """
Execute the query and return the results as an Execute the query and return the results as an
@@ -233,25 +329,12 @@ class LanceQueryBuilder:
) )
return self._table._execute_query(query) return self._table._execute_query(query)
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "lancedb.table.Table", query: str):
super().__init__(table)
self._query = query
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
try: try:
import tantivy import tantivy
@@ -275,3 +358,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores) output_tbl = output_tbl.append_column("score", scores)
return output_tbl return output_tbl
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
limit=self._limit,
)

View File

@@ -18,10 +18,9 @@ from urllib.parse import urlparse
import pyarrow as pa import pyarrow as pa
from lancedb.common import DATA from ..common import DATA
from lancedb.db import DBConnection from ..db import DBConnection
from lancedb.table import Table, _sanitize_data from ..table import Table, _sanitize_data
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient

View File

@@ -20,7 +20,7 @@ from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE from .client import ARROW_STREAM_CONTENT_TYPE
@@ -73,7 +73,11 @@ class RemoteTable(Table):
fill_value: float = 0.0, fill_value: float = 0.0,
) -> int: ) -> int:
data = _sanitize_data( data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
payload = to_ipc_binary(data) payload = to_ipc_binary(data)
@@ -89,9 +93,9 @@ class RemoteTable(Table):
) )
def search( def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder: ) -> LanceVectorQueryBuilder:
return LanceQueryBuilder(self, query, vector_column) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)

View File

@@ -17,7 +17,7 @@ import inspect
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import Iterable, List, Optional, Union from typing import Any, Iterable, List, Optional, Union
import lance import lance
import numpy as np import numpy as np
@@ -28,54 +28,89 @@ from lance.dataset import ReaderLike
from lance.vector import vec_to_table from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionRegistry
from .embeddings.functions import EmbeddingFunctionConfig
from .pydantic import LanceModel from .pydantic import LanceModel
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas from .util import fs_from_uri, safe_import_pandas
pd = safe_import_pandas() pd = safe_import_pandas()
def _sanitize_data(data, schema, on_bad_vectors, fill_value): def _sanitize_data(
data,
schema: Optional[pa.Schema],
metadata: Optional[dict],
on_bad_vectors: str,
fill_value: Any,
):
if isinstance(data, list): if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels # convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel): if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema() schema = data[0].__class__.to_arrow_schema()
data = [dict(d) for d in data] data = [dict(d) for d in data]
data = pa.Table.from_pylist(data) data = pa.Table.from_pylist(data)
data = _sanitize_schema( elif isinstance(data, dict):
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if isinstance(data, dict):
data = vec_to_table(data) data = vec_to_table(data)
if pd is not None and isinstance(data, pd.DataFrame): elif pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False) data = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
if isinstance(data, pa.Table):
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
data = _sanitize_schema( data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
) )
# Do not serialize Pandas metadata elif isinstance(data, Iterable):
metadata = data.schema.metadata if data.schema.metadata is not None else {} data = _to_record_batch_generator(
metadata = {k: v for k, v in metadata.items() if k != b"pandas"} data, schema, metadata, on_bad_vectors, fill_value
schema = data.schema.with_metadata(metadata) )
data = pa.Table.from_arrays(data.columns, schema=schema) else:
if isinstance(data, Iterable):
data = _to_record_batch_generator(data, schema, on_bad_vectors, fill_value)
if not isinstance(data, (pa.Table, Iterable)):
raise TypeError(f"Unsupported data type: {type(data)}") raise TypeError(f"Unsupported data type: {type(data)}")
return data return data
def _to_record_batch_generator(data: Iterable, schema, on_bad_vectors, fill_value): def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
"""
Use the embedding function to automatically embed the source column and add the
vector column to the table.
"""
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_column, conf in functions.items():
func = conf.function
if vector_column not in data.column_names:
col_data = func.compute_source_embeddings(data[conf.source_column])
if schema is not None:
dtype = schema.field(vector_column).type
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
data = data.append_column(
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
)
return data
def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value
):
for batch in data: for batch in data:
if not isinstance(batch, pa.RecordBatch): if not isinstance(batch, pa.RecordBatch):
table = _sanitize_data(batch, schema, on_bad_vectors, fill_value) table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches(): for batch in table.to_batches():
yield batch yield batch
yield batch else:
yield batch
class Table(ABC): class Table(ABC):
""" """
A [Table](Table) is a collection of Records in a LanceDB [Database](Database). A Table is a collection of Records in a LanceDB Database.
Examples Examples
-------- --------
@@ -196,17 +231,30 @@ class Table(ABC):
@abstractmethod @abstractmethod
def search( def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. of the given query vector.
Parameters Parameters
---------- ----------
query: list, np.ndarray query: str, list, np.ndarray, PIL.Image.Image, default None
The query vector. The query to search for. If None then
vector_column: str, default "vector" the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector"
The name of the vector column to search. The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns Returns
------- -------
@@ -325,14 +373,14 @@ class LanceTable(Table):
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}]) >>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version >>> table.version
1 2
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version >>> table.version
2 3
>>> table.checkout(1) >>> table.checkout(2)
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
@@ -361,19 +409,19 @@ class LanceTable(Table):
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}]) >>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version >>> table.version
1 2
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version >>> table.version
2 3
>>> table.restore(1) >>> table.restore(2)
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
>>> len(table.list_versions()) >>> len(table.list_versions())
3 4
""" """
max_ver = max([v["version"] for v in self._dataset.versions()]) max_ver = max([v["version"] for v in self._dataset.versions()])
if version is None: if version is None:
@@ -480,6 +528,9 @@ class LanceTable(Table):
fill_value: float = 0.0, fill_value: float = 0.0,
): ):
"""Add data to the table. """Add data to the table.
If vector columns are missing and the table
has embedding functions, then the vector columns
are automatically computed and added.
Parameters Parameters
---------- ----------
@@ -501,7 +552,11 @@ class LanceTable(Table):
""" """
# TODO: manage table listing and metadata separately # TODO: manage table listing and metadata separately
data = _sanitize_data( data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode) lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset() self._reset_dataset()
@@ -511,7 +566,7 @@ class LanceTable(Table):
other_table: Union[LanceTable, ReaderLike], other_table: Union[LanceTable, ReaderLike],
left_on: str, left_on: str,
right_on: Optional[str] = None, right_on: Optional[str] = None,
schema: Optional[pa.Schema, LanceModel] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None,
): ):
"""Merge another table into this table. """Merge another table into this table.
@@ -569,18 +624,46 @@ class LanceTable(Table):
) )
self._reset_dataset() self._reset_dataset()
@cached_property
def embedding_functions(self) -> dict:
"""
Get the embedding functions for the table
Returns
-------
funcs: dict
A mapping of the vector column to the embedding function
or empty dict if not configured.
"""
return EmbeddingFunctionRegistry.get_instance().parse_functions(
self.schema.metadata
)
def search( def search(
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. of the given query vector.
Parameters Parameters
---------- ----------
query: list, np.ndarray query: str, list, np.ndarray, a PIL Image or None
The query vector. The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector" vector_column_name: str, default "vector"
The name of the vector column to search. The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If the query is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts"
Returns Returns
------- -------
@@ -590,17 +673,9 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if isinstance(query, str): return LanceQueryBuilder.create(
# fts self, query, query_type, vector_column_name=vector_column_name
return LanceFtsQueryBuilder(self, query, vector_column_name) )
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query, vector_column_name)
@classmethod @classmethod
def create( def create(
@@ -612,6 +687,7 @@ class LanceTable(Table):
mode="create", mode="create",
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionConfig] = None,
): ):
""" """
Create a new table. Create a new table.
@@ -649,20 +725,58 @@ class LanceTable(Table):
One of "error", "drop", "fill". One of "error", "drop", "fill".
fill_value: float, default 0. fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
embedding_functions: list of EmbeddingFunctionModel, default None
The embedding functions to use when creating the table.
""" """
tbl = LanceTable(db, name) tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel): if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema() schema = schema.to_arrow_schema()
metadata = None
if embedding_functions is not None:
# If we passed in embedding functions explicitly
# then we'll override any schema metadata that
# may was implicitly specified by the LanceModel schema
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
if data is not None: if data is not None:
data = _sanitize_data( data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data,
schema,
metadata=metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
else:
if schema is None: if schema is None:
if data is None:
raise ValueError("Either data or schema must be provided") raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema) elif hasattr(data, "schema"):
lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode) schema = data.schema
return LanceTable(db, name) elif isinstance(data, Iterable):
if metadata:
raise TypeError(
(
"Persistent embedding functions not yet "
"supported for generator data input"
)
)
if metadata:
schema = schema.with_metadata(metadata)
empty = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
table = LanceTable(db, name)
if data is not None:
table.add(data)
return table
@classmethod @classmethod
def open(cls, db, name): def open(cls, db, name):
@@ -678,6 +792,56 @@ class LanceTable(Table):
def delete(self, where: str): def delete(self, where: str):
self._dataset.delete(where) self._dataset.delete(where)
def update(self, where: str, values: dict):
"""
EXPERIMENTAL: Update rows in the table (not threadsafe).
This can be used to update zero to all rows depending on how many
rows match the where clause.
Parameters
----------
where: str
The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict
The values to update. The keys are the column names and the values
are the values to set.
Examples
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 3 [5.0, 6.0]
2 2 [10.0, 10.0]
"""
orig_data = self._dataset.to_table(filter=where).combine_chunks()
if len(orig_data) == 0:
return
for col, val in values.items():
i = orig_data.column_names.index(col)
if i < 0:
raise ValueError(f"Column {col} does not exist")
orig_data = orig_data.set_column(
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
)
self.delete(where)
self.add(orig_data, mode="append")
self._reset_dataset()
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance() ds = self.to_lance()
return ds.to_table( return ds.to_table(
@@ -720,22 +884,38 @@ def _sanitize_schema(
return data return data
# cast the columns to the expected types # cast the columns to the expected types
data = data.combine_chunks() data = data.combine_chunks()
data = _sanitize_vector_column( for field in schema:
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_float32(field.type.value_type)
and field.type.list_size >= 10
)
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
if field.name in data.column_names and (
likely_vector_col or is_default_vector_col
):
data = _sanitize_vector_column(
data,
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
# just check the vector column
if VECTOR_COLUMN_NAME in data.column_names:
return _sanitize_vector_column(
data, data,
vector_column_name=VECTOR_COLUMN_NAME, vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema return data
)
# just check the vector column
return _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
def _sanitize_vector_column( def _sanitize_vector_column(
@@ -759,8 +939,6 @@ def _sanitize_vector_column(
fill_value: float, default 0.0 fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
if vector_column_name not in data.column_names:
raise ValueError(f"Missing vector column: {vector_column_name}")
# ChunkedArray is annoying to work with, so we combine chunks here # ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks() vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_list(data[vector_column_name].type): if pa.types.is_list(data[vector_column_name].type):

View File

@@ -70,7 +70,11 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
Get a PyArrow FileSystem from a URI, handling extra environment variables. Get a PyArrow FileSystem from a URI, handling extra environment variables.
""" """
if get_uri_scheme(uri) == "s3": if get_uri_scheme(uri) == "s3":
fs = pa_fs.S3FileSystem(endpoint_override=os.environ.get("AWS_ENDPOINT")) fs = pa_fs.S3FileSystem(
endpoint_override=os.environ.get("AWS_ENDPOINT"),
request_timeout=30,
connect_timeout=30,
)
path = get_uri_location(uri) path = get_uri_location(uri)
return fs, path return fs, path

View File

@@ -1,15 +1,16 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.2.2" version = "0.2.5"
dependencies = [ dependencies = [
"pylance==0.6.5", "pylance==0.7.4",
"ratelimiter", "ratelimiter",
"retry", "retry",
"tqdm", "tqdm",
"aiohttp", "aiohttp",
"pydantic", "pydantic",
"attr", "attr",
"semver>=3.0" "semver>=3.0",
"cachetools"
] ]
description = "lancedb" description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -43,9 +44,11 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies] [project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"] tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
dev = ["ruff", "pre-commit", "black"] dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
[build-system] [build-system]
requires = ["setuptools", "wheel"] requires = ["setuptools", "wheel"]
@@ -53,3 +56,10 @@ build-backend = "setuptools.build_meta"
[tool.isort] [tool.isort]
profile = "black" profile = "black"
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio"
]

View File

@@ -17,7 +17,7 @@ import pyarrow as pa
import pytest import pytest
import lancedb import lancedb
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, Vector
def test_basic(tmp_path): def test_basic(tmp_path):
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
def test_ingest_iterator(tmp_path): def test_ingest_iterator(tmp_path):
class PydanticSchema(LanceModel): class PydanticSchema(LanceModel):
vector: vector(2) vector: Vector(2)
item: str item: str
price: float price: float
@@ -136,15 +136,14 @@ def test_ingest_iterator(tmp_path):
def run_tests(schema): def run_tests(schema):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite") tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl.to_pandas() tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0 assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0 assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
tbl_len = len(tbl) tbl_len = len(tbl)
tbl.add(make_batches()) tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2 assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 2 assert len(tbl.list_versions()) == 3
db.drop_database() db.drop_database()
run_tests(arrow_schema) run_tests(arrow_schema)

View File

@@ -12,10 +12,16 @@
# limitations under the License. # limitations under the License.
import sys import sys
import lance
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
from lancedb.embeddings import with_embeddings from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
def mock_embed_func(input_data): def mock_embed_func(input_data):
@@ -40,3 +46,40 @@ def test_with_embeddings():
assert data.column_names == ["text", "price", "vector"] assert data.column_names == ["text", "price", "vector"]
assert data.column("text").to_pylist() == ["foo", "bar"] assert data.column("text").to_pylist() == ["foo", "bar"]
assert data.column("price").to_pylist() == [10.0, 20.0] assert data.column("price").to_pylist() == [10.0, 20.0]
def test_embedding_function(tmp_path):
registry = EmbeddingFunctionRegistry.get_instance()
# let's create a table
table = pa.table(
{
"text": pa.array(["hello world", "goodbye world"]),
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
table = table.replace_schema_metadata(metadata)
# Write it to disk
lance.write_dataset(table, tmp_path / "test.lance")
# Load this back
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
configs = registry.parse_functions(ds.schema.metadata)
conf = configs["vector"]
func = conf.function
actual = func.compute_query_embeddings("hello world")
# And we make sure we can call it
expected = func.compute_query_embeddings("hello world")
assert np.allclose(actual, expected)

View File

@@ -0,0 +1,125 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import numpy as np
import pandas as pd
import pytest
import requests
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connection to external api
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get(alias).create()
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words)
table.add(
pd.DataFrame(
{
"text": [
"hello world",
"goodbye world",
"fizz",
"buzz",
"foo",
"bar",
"baz",
]
}
)
)
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
vec = func.compute_query_embeddings(query)[0]
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
assert actual.text == expected.text
assert actual.text == "hello world"
@pytest.mark.slow
def test_openclip(tmp_path):
from PIL import Image
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("open-clip").create()
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
)
# text search
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
assert actual.label == "dog"
frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == frombytes.label
assert np.allclose(actual.vector, frombytes.vector)
# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(io.BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
assert actual.label == "dog"
other = (
table.search(query_image, vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == other.label
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
assert np.allclose(
arrow_table["vector"].combine_chunks().values.to_numpy(),
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
)

View File

@@ -19,8 +19,9 @@ from typing import List, Optional
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import pytest import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
@pytest.mark.skipif( @pytest.mark.skipif(
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
def test_fixed_size_list_field(): def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel):
vec: vector(16) vec: Vector(16)
li: List[int] li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3]) data = TestModel(vec=list(range(16)), li=[1, 2, 3])
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
def test_fixed_size_list_validation(): def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel):
vec: vector(8) vec: Vector(8)
with pytest.raises(pydantic.ValidationError): with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9)) TestModel(vec=range(9))
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
def test_lance_model(): def test_lance_model():
class TestModel(LanceModel): class TestModel(LanceModel):
vec: vector(16) vector: Vector(16) = Field(default=[0.0] * 16)
li: List[int] li: List[int] = Field(default=[1, 2, 3])
schema = pydantic_to_schema(TestModel) schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema() assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vec", "li"] assert TestModel.field_names() == ["vector", "li"]
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])

View File

@@ -20,8 +20,8 @@ import pyarrow as pa
import pytest import pytest
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceQueryBuilder, Query from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable from lancedb.table import LanceTable
@@ -67,12 +67,12 @@ def table(tmp_path) -> MockTable:
def test_cast(table): def test_cast(table):
class TestModel(LanceModel): class TestModel(LanceModel):
vector: vector(2) vector: Vector(2)
id: int id: int
str_field: str str_field: str
float_field: float float_field: float
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1) q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel) results = q.to_pydantic(TestModel)
assert len(results) == 1 assert len(results) == 1
r0 = results[0] r0 = results[0]
@@ -84,13 +84,15 @@ def test_cast(table):
def test_query_builder(table): def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df() df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
)
assert df["id"].values[0] == 1 assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2]) assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table): def test_query_builder_with_filter(table):
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df() df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2 assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4]) assert all(df["vector"].values[0] == [3, 4])
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
def test_query_builder_with_metric(table): def test_query_builder_with_metric(table):
query = [4, 8] query = [4, 8]
vector_column_name = "vector" vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df() df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df() df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
)
tm.assert_frame_equal(df_default, df_l2) tm.assert_frame_equal(df_default, df_l2)
df_cosine = ( df_cosine = (
LanceQueryBuilder(table, query, vector_column_name) LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine") .metric("cosine")
.limit(1) .limit(1)
.to_df() .to_df()
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
query = [4, 8] query = [4, 8]
vector_column_name = "foo_vector" vector_column_name = "foo_vector"
builder = ( builder = (
LanceQueryBuilder(table, query, vector_column_name) LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine") .metric("cosine")
.where("b < 10") .where("b < 10")
.select(["b"]) .select(["b"])

View File

@@ -22,8 +22,10 @@ import pandas as pd
import pyarrow as pa import pyarrow as pa
import pytest import pytest
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable from lancedb.table import LanceTable
@@ -139,7 +141,7 @@ def test_add(db):
def test_add_pydantic_model(db): def test_add_pydantic_model(db):
class TestModel(LanceModel): class TestModel(LanceModel):
vector: vector(16) vector: Vector(16)
li: List[int] li: List[int]
data = TestModel(vector=list(range(16)), li=[1, 2, 3]) data = TestModel(vector=list(range(16)), li=[1, 2, 3])
@@ -178,16 +180,16 @@ def test_versioning(db):
], ],
) )
assert len(table.list_versions()) == 1
assert table.version == 1
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 2 assert len(table.list_versions()) == 2
assert table.version == 2 assert table.version == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 3 assert len(table) == 3
table.checkout(1) table.checkout(2)
assert table.version == 1 assert table.version == 2
assert len(table) == 2 assert len(table) == 2
@@ -278,21 +280,21 @@ def test_restore(db):
data=[{"vector": [1.1, 0.9], "type": "vector"}], data=[{"vector": [1.1, 0.9], "type": "vector"}],
) )
table.add([{"vector": [0.5, 0.2], "type": "vector"}]) table.add([{"vector": [0.5, 0.2], "type": "vector"}])
table.restore(1) table.restore(2)
assert len(table.list_versions()) == 3 assert len(table.list_versions()) == 4
assert len(table) == 1 assert len(table) == 1
expected = table.to_arrow() expected = table.to_arrow()
table.checkout(1) table.checkout(2)
table.restore() table.restore()
assert len(table.list_versions()) == 4 assert len(table.list_versions()) == 5
assert table.to_arrow() == expected assert table.to_arrow() == expected
table.restore(4) # latest version should be no-op table.restore(5) # latest version should be no-op
assert len(table.list_versions()) == 4 assert len(table.list_versions()) == 5
with pytest.raises(ValueError): with pytest.raises(ValueError):
table.restore(5) table.restore(6)
with pytest.raises(ValueError): with pytest.raises(ValueError):
table.restore(0) table.restore(0)
@@ -306,7 +308,7 @@ def test_merge(db, tmp_path):
) )
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]}) other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
table.merge(other_table, left_on="id") table.merge(other_table, left_on="id")
assert len(table.list_versions()) == 2 assert len(table.list_versions()) == 3
expected = pa.table( expected = pa.table(
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]}, {"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
schema=table.schema, schema=table.schema,
@@ -316,3 +318,126 @@ def test_merge(db, tmp_path):
other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance") other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
table.restore(1) table.restore(1)
table.merge(other_dataset, left_on="id") table.merge(other_dataset, left_on="id")
def test_delete(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 2
table.delete("id=0")
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
def test_update(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 4
assert table.version == 4
assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: Vector(10)
func = MockTextEmbeddingFunction()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func
)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[conf],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func.compute_query_embeddings(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_add_with_embedding_function(db):
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
table = LanceTable.create(db, "my_table", schema=MyTable)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = emb.compute_query_embeddings(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
def test_empty_query(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
val = df.id.iloc[0]
assert val == 1

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "vectordb-node" name = "vectordb-node"
version = "0.2.4" version = "0.2.6"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"
edition = "2018" edition = "2018"
@@ -18,6 +18,7 @@ once_cell = "1"
futures = "0.3" futures = "0.3"
half = { workspace = true } half = { workspace = true }
lance = { workspace = true } lance = { workspace = true }
lance-linalg = { workspace = true }
vectordb = { path = "../../vectordb" } vectordb = { path = "../../vectordb" }
tokio = { version = "1.23", features = ["rt-multi-thread"] } tokio = { version = "1.23", features = ["rt-multi-thread"] }
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] } neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }

View File

@@ -28,7 +28,9 @@ fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
record_batch record_batch
.column_by_name(VECTOR_COLUMN_NAME) .column_by_name(VECTOR_COLUMN_NAME)
.map(|_| ()) .map(|_| ())
.context(MissingColumnSnafu { name: VECTOR_COLUMN_NAME }) .context(MissingColumnSnafu {
name: VECTOR_COLUMN_NAME,
})
} }
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> { pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {

View File

@@ -14,7 +14,7 @@
use lance::index::vector::ivf::IvfBuildParams; use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams; use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::MetricType; use lance_linalg::distance::MetricType;
use neon::context::FunctionContext; use neon::context::FunctionContext;
use neon::prelude::*; use neon::prelude::*;
use std::convert::TryFrom; use std::convert::TryFrom;

View File

@@ -183,11 +183,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let aws_region = get_aws_region(&mut cx, 4)?; let aws_region = get_aws_region(&mut cx, 4)?;
let params = ReadParams { let params = ReadParams {
store_options: Some(ObjectStoreParams { store_options: Some(ObjectStoreParams::with_aws_credentials(
aws_credentials: aws_creds, aws_creds, aws_region,
aws_region, )),
..ObjectStoreParams::default()
}),
..ReadParams::default() ..ReadParams::default()
}; };

View File

@@ -3,7 +3,7 @@ use std::ops::Deref;
use arrow_array::Float32Array; use arrow_array::Float32Array;
use futures::{TryFutureExt, TryStreamExt}; use futures::{TryFutureExt, TryStreamExt};
use lance::index::vector::MetricType; use lance_linalg::distance::MetricType;
use neon::context::FunctionContext; use neon::context::FunctionContext;
use neon::handle::Handle; use neon::handle::Handle;
use neon::prelude::*; use neon::prelude::*;

View File

@@ -43,7 +43,8 @@ impl JsTable {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?; .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx); let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let buffer = cx.argument::<JsBuffer>(1)?; let buffer = cx.argument::<JsBuffer>(1)?;
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
// Write mode // Write mode
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() { let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
@@ -65,11 +66,9 @@ impl JsTable {
let aws_region = get_aws_region(&mut cx, 6)?; let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams { let params = WriteParams {
store_params: Some(ObjectStoreParams { store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_credentials: aws_creds, aws_creds, aws_region,
aws_region, )),
..ObjectStoreParams::default()
}),
mode: mode, mode: mode,
..WriteParams::default() ..WriteParams::default()
}; };
@@ -92,7 +91,8 @@ impl JsTable {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?; let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx); let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let channel = cx.channel(); let channel = cx.channel();
let mut table = js_table.table.clone(); let mut table = js_table.table.clone();
@@ -108,11 +108,9 @@ impl JsTable {
let aws_region = get_aws_region(&mut cx, 5)?; let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams { let params = WriteParams {
store_params: Some(ObjectStoreParams { store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_credentials: aws_creds, aws_creds, aws_region,
aws_region, )),
..ObjectStoreParams::default()
}),
mode: write_mode, mode: write_mode,
..WriteParams::default() ..WriteParams::default()
}; };

View File

@@ -1,21 +1,30 @@
[package] [package]
name = "vectordb" name = "vectordb"
version = "0.2.4" version = "0.2.6"
edition = "2021" edition = "2021"
description = "Serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
keywords = ["lancedb", "lance", "database", "search"]
categories = ["database-implementations"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
arrow = { workspace = true }
arrow-array = { workspace = true } arrow-array = { workspace = true }
arrow-data = { workspace = true } arrow-data = { workspace = true }
arrow-schema = { workspace = true } arrow-schema = { workspace = true }
arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
object_store = { workspace = true } object_store = { workspace = true }
snafu = { workspace = true } snafu = { workspace = true }
half = { workspace = true } half = { workspace = true }
lance = { workspace = true } lance = { workspace = true }
lance-linalg = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] } tokio = { version = "1.23", features = ["rt-multi-thread"] }
log = { workspace = true }
num-traits = "0"
url = { workspace = true }
[dev-dependencies] [dev-dependencies]
tempfile = "3.5.0" tempfile = "3.5.0"

3
rust/vectordb/README.md Normal file
View File

@@ -0,0 +1,3 @@
# LanceDB Rust
Rust client for LanceDB, a serverless vector database. Read more at: https://lancedb.com/

View File

@@ -0,0 +1,15 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use lance::arrow::*;

18
rust/vectordb/src/data.rs Normal file
View File

@@ -0,0 +1,18 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Data types, schema coercion, and data cleaning and etc.
pub mod inspect;
pub mod sanitize;

View File

@@ -0,0 +1,180 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use arrow::compute::kernels::{aggregate::bool_and, length::length};
use arrow_array::{
cast::AsArray,
types::{ArrowPrimitiveType, Int32Type, Int64Type},
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
};
use arrow_ord::comparison::eq_dyn_scalar;
use arrow_schema::DataType;
use num_traits::{ToPrimitive, Zero};
use crate::error::{Error, Result};
pub(crate) fn infer_dimension<T: ArrowPrimitiveType>(
list_arr: &GenericListArray<T::Native>,
) -> Result<Option<T::Native>>
where
T::Native: OffsetSizeTrait + ToPrimitive,
{
let len_arr = length(list_arr)?;
if len_arr.is_empty() {
return Ok(Some(Zero::zero()));
}
let dim = len_arr.as_primitive::<T>().value(0);
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
Ok(None)
} else {
Ok(Some(dim))
}
}
/// Infer the vector columns from a dataset.
///
/// Parameters
/// ----------
/// - reader: RecordBatchReader
/// - strict: if set true, only fixed_size_list<float> is considered as vector column. If set to false,
/// a list<float> column with same length is also considered as vector column.
pub fn infer_vector_columns(
reader: impl RecordBatchReader + Send,
strict: bool,
) -> Result<Vec<String>> {
let mut columns = vec![];
let mut columns_to_infer: HashMap<String, Option<i64>> = HashMap::new();
for field in reader.schema().fields() {
match field.data_type() {
DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => {
columns.push(field.name().to_string());
}
DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().to_string(), None);
}
DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().to_string(), None);
}
_ => {}
}
}
for batch in reader {
let batch = batch?;
let col_names = columns_to_infer.keys().cloned().collect::<Vec<_>>();
for col_name in col_names {
let col = batch.column_by_name(&col_name).ok_or(Error::Schema {
message: format!("Column {} not found", col_name),
})?;
if let Some(dim) = match *col.data_type() {
DataType::List(_) => {
infer_dimension::<Int32Type>(col.as_list::<i32>())?.map(|d| d as i64)
}
DataType::LargeList(_) => infer_dimension::<Int64Type>(col.as_list::<i64>())?,
_ => {
return Err(Error::Schema {
message: format!("Column {} is not a list", col_name),
})
}
} {
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
if prev_dim != &dim {
columns_to_infer.remove(&col_name);
}
} else {
columns_to_infer.insert(col_name, Some(dim));
}
} else {
columns_to_infer.remove(&col_name);
}
}
}
columns.extend(columns_to_infer.keys().cloned());
Ok(columns)
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{
types::{Float32Type, Float64Type},
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use std::{sync::Arc, vec};
#[test]
fn test_infer_vector_columns() {
let schema = Arc::new(Schema::new(vec![
Field::new("f", DataType::Float32, false),
Field::new("s", DataType::Utf8, false),
Field::new(
"l1",
DataType::List(Arc::new(Field::new("item", DataType::Float32, true))),
false,
),
Field::new(
"l2",
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
),
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 32),
true,
),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
Arc::new(ListArray::from_iter_primitive::<Float32Type, _, _>(
(0..3).map(|_| Some(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)])),
)),
// Var-length list
Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(vec![Some(1.0_f64)]),
Some(vec![Some(2.0_f64), Some(3.0_f64)]),
Some(vec![Some(4.0_f64), Some(5.0_f64), Some(6.0_f64)]),
])),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
vec![
Some(vec![Some(1.0); 32]),
Some(vec![Some(2.0); 32]),
Some(vec![Some(3.0); 32]),
],
32,
),
),
],
)
.unwrap();
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
let cols = infer_vector_columns(reader, false).unwrap();
assert_eq!(cols, vec!["fl", "l1"]);
let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
let cols = infer_vector_columns(reader, true).unwrap();
assert_eq!(cols, vec!["fl"]);
}
}

View File

@@ -0,0 +1,284 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{iter::repeat_with, sync::Arc};
use arrow_array::{
cast::AsArray,
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
RecordBatchReader,
};
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use half::f16;
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
use log::warn;
use num_traits::cast::AsPrimitive;
use super::inspect::infer_dimension;
use crate::error::Result;
fn cast_array<I: ArrowNumericType, O: ArrowNumericType>(
arr: &PrimitiveArray<I>,
) -> Arc<PrimitiveArray<O>>
where
I::Native: AsPrimitive<O::Native>,
{
Arc::new(PrimitiveArray::<O>::from_iter_values(
arr.values().iter().map(|v| (*v).as_()),
))
}
fn cast_float_array<I: ArrowNumericType>(
arr: &PrimitiveArray<I>,
dt: &DataType,
) -> std::result::Result<Arc<dyn Array>, ArrowError>
where
I::Native: AsPrimitive<f64> + AsPrimitive<f32> + AsPrimitive<f16>,
{
match dt {
DataType::Float16 => Ok(cast_array::<I, Float16Type>(arr)),
DataType::Float32 => Ok(cast_array::<I, Float32Type>(arr)),
DataType::Float64 => Ok(cast_array::<I, Float64Type>(arr)),
_ => Err(ArrowError::SchemaError(format!(
"Incompatible change field: unable to coerce {:?} to {:?}",
arr.data_type(),
dt
))),
}
}
fn coerce_array(
array: &Arc<dyn Array>,
field: &Field,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
if array.data_type() == field.data_type() {
return Ok(array.clone());
}
match (array.data_type(), field.data_type()) {
// Normal cast-able types.
(adt, dt) if can_cast_types(adt, dt) => cast(&array, dt),
// Casting between f16/f32/f64 can be lossy.
(adt, dt) if (adt.is_floating() || dt.is_floating()) => {
if adt.byte_width() > dt.byte_width() {
warn!(
"Coercing field {} {:?} to {:?} might lose precision",
field.name(),
adt,
dt
);
}
match adt {
DataType::Float16 => cast_float_array(array.as_primitive::<Float16Type>(), dt),
DataType::Float32 => cast_float_array(array.as_primitive::<Float32Type>(), dt),
DataType::Float64 => cast_float_array(array.as_primitive::<Float64Type>(), dt),
_ => unreachable!(),
}
}
(adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt {
// Cast a float fixed size array with same dimension to the expected type.
DataType::FixedSizeList(_, dim) if dim == exp_dim => {
let actual_sub = array.as_fixed_size_list();
let values = coerce_array(actual_sub.values(), exp_field)?;
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
values.clone(),
*dim,
)?) as Arc<dyn Array>)
}
DataType::List(_) | DataType::LargeList(_) => {
let Some(dim) = (match adt {
DataType::List(_) => infer_dimension::<Int32Type>(array.as_list::<i32>())
.map_err(|e| {
ArrowError::SchemaError(format!(
"failed to infer dimension from list: {}",
e
))
})?
.map(|d| d as i64),
DataType::LargeList(_) => infer_dimension::<Int64Type>(array.as_list::<i64>())
.map_err(|e| {
ArrowError::SchemaError(format!(
"failed to infer dimension from large list: {}",
e
))
})?,
_ => unreachable!(),
}) else {
return Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
field,
array.data_type()
)));
};
if dim != *exp_dim as i64 {
return Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: expected dimension {} but got {}",
exp_dim, dim
)));
}
let values = coerce_array(array, exp_field)?;
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
values.clone(),
*exp_dim,
)?) as Arc<dyn Array>)
}
_ => Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
field,
array.data_type()
)))?,
},
_ => Err(ArrowError::SchemaError(format!(
"Incompatible change field {}: unable to coerce {:?} to {:?}",
field.name(),
array.data_type(),
field.data_type()
)))?,
}
}
fn coerce_schema_batch(
batch: RecordBatch,
schema: Arc<Schema>,
) -> std::result::Result<RecordBatch, ArrowError> {
if batch.schema() == schema {
return Ok(batch);
}
let columns = schema
.fields()
.iter()
.map(|field| {
batch
.column_by_name(field.name())
.ok_or_else(|| {
ArrowError::SchemaError(format!("Column {} not found", field.name()))
})
.and_then(|c| coerce_array(c, field))
})
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
RecordBatch::try_new(schema, columns)
}
/// Coerce the reader (input data) to match the given [Schema].
///
pub fn coerce_schema(
reader: impl RecordBatchReader + Send + 'static,
schema: Arc<Schema>,
) -> Result<Box<dyn RecordBatchReader + Send>> {
if reader.schema() == schema {
return Ok(Box::new(RecordBatchIterator::new(reader, schema)));
}
let s = schema.clone();
let batches = reader
.zip(repeat_with(move || s.clone()))
.map(|(batch, s)| coerce_schema_batch(batch?, s));
Ok(Box::new(RecordBatchIterator::new(batches, schema)))
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use arrow_array::{
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::Field;
use half::f16;
use lance::arrow::FixedSizeListArrayExt;
#[test]
fn test_coerce_list_to_fixed_size_list() {
let schema = Arc::new(Schema::new(vec![
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 64),
true,
),
Field::new("s", DataType::Utf8, true),
Field::new("f", DataType::Float16, true),
Field::new("i", DataType::Int32, true),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(
FixedSizeListArray::try_new_from_values(
Float32Array::from_iter_values((0..256).map(|v| v as f32)),
64,
)
.unwrap(),
),
Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some("from"),
Some("lance"),
])),
Arc::new(Float16Array::from_iter_values(
(0..4).map(|v| f16::from_f32(v as f32)),
)),
Arc::new(Int32Array::from_iter_values(0..4)),
],
)
.unwrap();
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
let expected_schema = Arc::new(Schema::new(vec![
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 64),
true,
),
Field::new("s", DataType::Utf8, true),
Field::new("f", DataType::Float64, true),
Field::new("i", DataType::Int8, true),
]));
let stream = coerce_schema(reader, expected_schema.clone()).unwrap();
let batches = stream.collect::<Vec<_>>();
assert_eq!(batches.len(), 1);
let batch = batches[0].as_ref().unwrap();
assert_eq!(batch.schema(), expected_schema);
let expected = RecordBatch::try_new(
expected_schema,
vec![
Arc::new(
FixedSizeListArray::try_new_from_values(
Float16Array::from_iter_values((0..256).map(|v| f16::from_f32(v as f32))),
64,
)
.unwrap(),
),
Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some("from"),
Some("lance"),
])),
Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))),
Arc::new(Int8Array::from_iter_values(0..4)),
],
)
.unwrap();
assert_eq!(batch, &expected);
}
}

View File

@@ -27,12 +27,14 @@ pub const LANCE_FILE_EXTENSION: &str = "lance";
pub struct Database { pub struct Database {
object_store: ObjectStore, object_store: ObjectStore,
query_string: Option<String>,
pub(crate) uri: String, pub(crate) uri: String,
pub(crate) base_path: object_store::path::Path, pub(crate) base_path: object_store::path::Path,
} }
const LANCE_EXTENSION: &str = "lance"; const LANCE_EXTENSION: &str = "lance";
const ENGINE: &str = "engine";
/// A connection to LanceDB /// A connection to LanceDB
impl Database { impl Database {
@@ -46,12 +48,73 @@ impl Database {
/// ///
/// * A [Database] object. /// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Database> { pub async fn connect(uri: &str) -> Result<Database> {
let (object_store, base_path) = ObjectStore::from_uri(uri).await?; let parse_res = url::Url::parse(uri);
if object_store.is_local() {
Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?; match parse_res {
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
Ok(mut url) => {
// iter thru the query params and extract the commit store param
let mut engine = None;
let mut filtered_querys = vec![];
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
// THE API WILL CHANGE
for (key, value) in url.query_pairs() {
if key == ENGINE {
engine = Some(value.to_string());
} else {
// to owned so we can modify the url
filtered_querys.push((key.to_string(), value.to_string()));
}
}
// Filter out the commit store query param -- it's a lancedb param
url.query_pairs_mut().clear();
url.query_pairs_mut().extend_pairs(filtered_querys);
// Take a copy of the query string so we can propagate it to lance
let query_string = url.query().map(|s| s.to_string());
// clear the query string so we can use the url as the base uri
// use .set_query(None) instead of .set_query("") because the latter
// will add a trailing '?' to the url
url.set_query(None);
let table_base_uri = if let Some(store) = engine {
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
WARN_ONCE.call_once(|| {
log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
});
let old_scheme = url.scheme().to_string();
let new_scheme = format!("{}+{}", old_scheme, store);
url.to_string().replacen(&old_scheme, &new_scheme, 1)
} else {
url.to_string()
};
let plain_uri = url.to_string();
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
if object_store.is_local() {
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
}
Ok(Database {
uri: table_base_uri,
query_string,
base_path,
object_store,
})
}
Err(_) => Self::open_path(uri).await,
} }
Ok(Database { }
uri: uri.to_string(),
async fn open_path(path: &str) -> Result<Database> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
}
Ok(Self {
uri: path.to_string(),
query_string: None,
base_path, base_path,
object_store, object_store,
}) })
@@ -149,17 +212,26 @@ impl Database {
let path = Path::new(&self.uri); let path = Path::new(&self.uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let uri = table_uri let mut uri = table_uri
.as_path() .as_path()
.to_str() .to_str()
.context(InvalidTableNameSnafu { name })?; .context(InvalidTableNameSnafu { name })?
Ok(uri.to_string()) .to_string();
// If there are query string set on the connection, propagate to lance
if let Some(query) = self.query_string.as_ref() {
uri.push('?');
uri.push_str(query.as_str());
}
Ok(uri)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::fs::create_dir_all; use std::fs::create_dir_all;
use tempfile::tempdir; use tempfile::tempdir;
use crate::database::Database; use crate::database::Database;
@@ -173,6 +245,28 @@ mod tests {
assert_eq!(db.uri, uri); assert_eq!(db.uri, uri);
} }
#[cfg(not(windows))]
#[tokio::test]
async fn test_connect_relative() {
let tmp_dir = tempdir().unwrap();
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
let mut relative_anacestors = vec![];
let current_dir = std::env::current_dir().unwrap();
let mut ancestors = current_dir.ancestors();
while let Some(_) = ancestors.next() {
relative_anacestors.push("..");
}
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
let relative_uri = relative_root.join(&uri);
let db = Database::connect(relative_uri.to_str().unwrap())
.await
.unwrap();
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
}
#[tokio::test] #[tokio::test]
async fn test_table_names() { async fn test_table_names() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use arrow_schema::ArrowError;
use snafu::Snafu; use snafu::Snafu;
#[derive(Debug, Snafu)] #[derive(Debug, Snafu)]
@@ -32,10 +33,20 @@ pub enum Error {
Store { message: String }, Store { message: String },
#[snafu(display("LanceDBError: {message}"))] #[snafu(display("LanceDBError: {message}"))]
Lance { message: String }, Lance { message: String },
#[snafu(display("LanceDB Schema Error: {message}"))]
Schema { message: String },
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
impl From<ArrowError> for Error {
fn from(e: ArrowError) -> Self {
Self::Lance {
message: e.to_string(),
}
}
}
impl From<lance::Error> for Error { impl From<lance::Error> for Error {
fn from(e: lance::Error) -> Self { fn from(e: lance::Error) -> Self {
Self::Lance { Self::Lance {

View File

@@ -14,7 +14,8 @@
use lance::index::vector::ivf::IvfBuildParams; use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams; use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::{MetricType, VectorIndexParams}; use lance::index::vector::VectorIndexParams;
use lance_linalg::distance::MetricType;
pub trait VectorIndexBuilder { pub trait VectorIndexBuilder {
fn get_column(&self) -> Option<String>; fn get_column(&self) -> Option<String>;
@@ -107,9 +108,11 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use lance::index::vector::ivf::IvfBuildParams; use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams; use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::{MetricType, StageParams}; use lance::index::vector::StageParams;
use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder}; use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod data;
pub mod database; pub mod database;
pub mod error; pub mod error;
pub mod index; pub mod index;

View File

@@ -17,7 +17,7 @@ use std::sync::Arc;
use arrow_array::Float32Array; use arrow_array::Float32Array;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::Dataset; use lance::dataset::Dataset;
use lance::index::vector::MetricType; use lance_linalg::distance::MetricType;
use crate::error::Result; use crate::error::Result;
@@ -164,10 +164,10 @@ impl Query {
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use super::*;
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use lance::dataset::Dataset; use lance::dataset::Dataset;
use lance::index::vector::MetricType;
use crate::query::Query; use crate::query::Query;