Compare commits

..

24 Commits

Author SHA1 Message Date
Lance Release
661fcecf38 [python] Bump version: 0.2.2 → 0.2.3 2023-09-14 17:48:32 +00:00
Lance Release
07fe284810 Updating package-lock.json 2023-09-10 23:58:06 +00:00
Lance Release
800bb691c3 Updating package-lock.json 2023-09-09 19:45:58 +00:00
Lance Release
ec24e09add Bump version: 0.2.4 → 0.2.5 2023-09-09 19:45:43 +00:00
Rob Meng
0554db03b3 propagate uri query string to lance; add aws integration tests (#486)
# WARNING: specifying engine is NOT a publicly supported feature in
lancedb yet. THE API WILL CHANGE.

This PR exposes the DynamoDB-based commit store to `vectordb` and the JS SDK
(Python will follow in another PR since it's on a different release track).

This PR also adds an AWS integration test using `localstack`.

## What?
This PR adds URI parameters to the DB connection string. Users may specify
`engine` in the connection string to tell LanceDB to use an external commit
store when reading and writing a table. Users may also pass any parameters
required by the commitStore in the connection string; these parameters are
propagated to lance.

e.g.
```
vectordb.connect("s3://my-db-bucket?engine=ddb&ddbTableName=my-commit-table")
```
will automatically convert the table path to
```
s3+ddb://my-db-bucket/my_table.lance?&ddbTableName=my-commit-table
```
2023-09-09 13:33:16 -04:00
Lei Xu
b315ea3978 [Python] Pydantic vector field with default value (#474)
Rename `lancedb.pydantic.vector` to `Vector` and deprecate `vector(dim)`
2023-09-08 22:35:31 -07:00
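A minimal sketch of the renamed API (the `Item` model and the 128-dim size are illustrative, not from the PR):

```python
from lancedb.pydantic import LanceModel, Vector

class Item(LanceModel):
    name: str
    vector: Vector(128)  # replaces the deprecated lowercase vector(128)
```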
Ayush Chaurasia
aa7806cf0d [Python]Fix record_batch_generator (#483)
Should fix - https://github.com/lancedb/lancedb/issues/482
2023-09-08 21:18:50 +05:30
Lei Xu
6799613109 feat: upgrade lance to 0.7.3 (#481) 2023-09-07 17:01:45 -07:00
Lei Xu
0f26915d22 [Rust] schema coerce and vector column inference (#476)
Split the Rust core out of #466 for easier review and fewer merge conflicts.
2023-09-06 10:00:46 -07:00
Chang She
32163063dc Fix up docs (#477) 2023-09-05 22:29:50 -07:00
Chang She
9a9a73a65d [python] Use pydantic for embedding function persistence (#467)
1. Support persisting embedding functions so users can search using just a
query string (see the sketch after this entry)
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty queries (just apply select/where/limit)
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
2023-09-05 21:30:45 -07:00
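A hedged sketch of the registration flow, based on the code in this PR (the `MyEmbedding` class and its toy embedding are illustrative):

```python
from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunctionModel

registry = EmbeddingFunctionRegistry.get_instance()

@registry.register()
class MyEmbedding(TextEmbeddingFunctionModel):
    def generate_embeddings(self, texts):
        # toy 4-dim embeddings for illustration only
        return [[float(len(t))] * 4 for t in texts]

func = MyEmbedding(source_column="text", vector_column="vector")
# Registered functions are serialized into the table's Arrow metadata
# (see EmbeddingFunctionRegistry.get_table_metadata in the diff below), so a
# later session can search with a plain string and have the query embedded
# automatically.
```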
Ayush Chaurasia
52fa7f5577 [Docs] Small typo fixes (#460) 2023-09-02 22:17:19 +05:30
Chang She
0cba0f4f92 [python] Temporary update feature (#457)
Combine delete and append to make a temporary update feature that is
only enabled for the local python lancedb.

This is temporary because it first has to load the data that matches the
where clause into memory, which is technically unbounded (a sketch of the
pattern follows this entry).

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-08-30 00:25:26 -07:00
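A conceptual sketch of the delete-then-append pattern described above (the column names are illustrative, and the actual helper added in this PR may differ):

```python
# load the rows matching the predicate into memory (this is the unbounded part)
df = table.to_pandas()
matched = df[df["x"] == 2].copy()
matched["price"] = 42.0  # apply the "update"

table.delete("x = 2")    # remove the old rows
table.add(matched)       # append the modified rows
```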
Will Jones
8391ffee84 chore: make crate more discoverable (#443)
A few small changes to make the Rust crate more discoverable.
2023-08-25 08:59:14 -07:00
Lance Release
fe8848efb9 [python] Bump version: 0.2.1 → 0.2.2 2023-08-24 23:18:10 +00:00
Chang She
213c313b99 Revert "Updating package-lock.json" (#455)
This reverts commit ab97e5d632.

Co-authored-by: Chang She <chang@lancedb.com>
2023-08-24 15:54:57 -07:00
Chang She
157e995a43 Revert "Bump version: 0.2.4 → 0.2.5" (#454)
This reverts commit 87e9a0250f.

I triggered the nodejs release commit GHA by mistake. Reverting it.
The tag will be removed manually.

Co-authored-by: Chang She <chang@lancedb.com>
2023-08-24 15:44:37 -07:00
Lance Release
ab97e5d632 Updating package-lock.json 2023-08-24 21:54:35 +00:00
Lance Release
87e9a0250f Bump version: 0.2.4 → 0.2.5 2023-08-24 21:54:18 +00:00
Chang She
e587a17a64 [python] Support schema evolution in local LanceDB (#452)
Previously, adding a column to a table required rewriting the whole table.
Instead, we use the merge functionality from the Lance format to
incrementally add columns from another table or dataframe (see the sketch
after this entry).

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
2023-08-24 14:40:49 -07:00
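An illustrative sketch of the merge-based flow (assuming pylance's `LanceDataset.merge`; the exact lancedb-level API added in this PR may differ):

```python
import pyarrow as pa

# a key column plus the new column to graft onto the existing table
new_data = pa.table({"id": [1, 2], "genre": ["drama", "comedy"]})

dataset = table.to_lance()             # underlying Lance dataset
dataset.merge(new_data, left_on="id")  # joins on "id", adds "genre" without a rewrite
```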
Chang She
2f1f9f6338 [python] improve restore functionality (#451)
Previously the temporary restore feature required copying data. The new
feature in pylance does not.

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
2023-08-24 11:00:34 -07:00
Lance Release
a34fa4df26 Updating package-lock.json 2023-08-24 05:23:19 +00:00
Lance Release
e20979b335 Updating package-lock.json 2023-08-24 04:48:11 +00:00
Lance Release
08689c345d Bump version: 0.2.3 → 0.2.4 2023-08-24 04:47:57 +00:00
44 changed files with 1940 additions and 339 deletions

View File

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.2.3
+current_version = 0.2.5
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True

View File

@@ -107,3 +107,56 @@ jobs:
       - name: Test
         run: |
           npm run test
+  aws-integtest:
+    timeout-minutes: 45
+    runs-on: "ubuntu-22.04"
+    defaults:
+      run:
+        shell: bash
+        working-directory: node
+    env:
+      AWS_ACCESS_KEY_ID: ACCESSKEY
+      AWS_SECRET_ACCESS_KEY: SECRETKEY
+      AWS_DEFAULT_REGION: us-west-2
+      # this one is for s3
+      AWS_ENDPOINT: http://localhost:4566
+      # this one is for dynamodb
+      DYNAMODB_ENDPOINT: http://localhost:4566
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+          lfs: true
+      - uses: actions/setup-node@v3
+        with:
+          node-version: 18
+          cache: 'npm'
+          cache-dependency-path: node/package-lock.json
+      - name: start local stack
+        run: docker compose -f ../docker-compose.yml up -d
+      - name: create s3
+        run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
+      - name: create ddb
+        run: |
+          aws dynamodb create-table \
+            --table-name lancedb-integtest \
+            --attribute-definitions '[{"AttributeName": "base_uri", "AttributeType": "S"}, {"AttributeName": "version", "AttributeType": "N"}]' \
+            --key-schema '[{"AttributeName": "base_uri", "KeyType": "HASH"}, {"AttributeName": "version", "KeyType": "RANGE"}]' \
+            --provisioned-throughput '{"ReadCapacityUnits": 10, "WriteCapacityUnits": 10}' \
+            --endpoint-url $DYNAMODB_ENDPOINT
+      - uses: Swatinem/rust-cache@v2
+      - name: Install dependencies
+        run: |
+          sudo apt update
+          sudo apt install -y protobuf-compiler libssl-dev
+      - name: Build
+        run: |
+          npm ci
+          npm run tsc
+          npm run build
+          npm run pack-build
+          npm install --no-save ./dist/lancedb-vectordb-*.tgz
+          # Remove index.node to test with dependency installed
+          rm index.node
+      - name: Test
+        run: npm run integration-test

View File

@@ -1,16 +1,24 @@
 [workspace]
-members = [
-    "rust/vectordb",
-    "rust/ffi/node"
-]
+members = ["rust/ffi/node", "rust/vectordb"]
+# Python package needs to be built by maturin.
+exclude = ["python"]
 resolver = "2"

 [workspace.dependencies]
-lance = "=0.6.5"
+lance = { "version" = "=0.7.3", "features" = ["dynamodb"] }
+# Note that this one does not include pyarrow
+arrow = { version = "43.0.0", optional = false }
 arrow-array = "43.0"
 arrow-data = "43.0"
-arrow-schema = "43.0"
 arrow-ipc = "43.0"
-half = { "version" = "=2.2.1", default-features = false }
+arrow-ord = "43.0"
+arrow-schema = "43.0"
+arrow-arith = "43.0"
+arrow-cast = "43.0"
+half = { "version" = "=2.2.1", default-features = false, features = [
+    "num-traits"
+] }
+log = "0.4"
 object_store = "0.6.1"
 snafu = "0.7.4"
+url = "2"

15
docker-compose.yml Normal file
View File

@@ -0,0 +1,15 @@
version: "3.9"
services:
localstack:
image: localstack/localstack:0.14
ports:
- 4566:4566
environment:
- SERVICES=s3,dynamodb
- DEBUG=1
- LS_LOG=trace
- DOCKER_HOST=unix:///var/run/docker.sock
- AWS_ACCESS_KEY_ID=ACCESSKEY
- AWS_SECRET_ACCESS_KEY=SECRETKEY
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]

View File

@@ -67,6 +67,11 @@ nav:
   - Home:
     - 🏢 Home: index.md
     - 💡 Basics: basic.md
+    - 📚 Guides:
+      - Tables: guides/tables.md
+      - Vector Search: search.md
+      - SQL filters: sql.md
+      - Indexing: ann_indexes.md
     - 🧬 Embeddings: embedding.md
     - 🔍 Python full-text search: fts.md
     - 🔌 Integrations:
@@ -91,12 +96,12 @@ nav:
       - Serverless Website Chatbot: examples/serverless_website_chatbot.md
       - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
       - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
-  - 📚 Guides:
-    - Tables: guides/tables.md
-    - Vector Search: search.md
-    - SQL filters: sql.md
-    - Indexing: ann_indexes.md
   - Basics: basic.md
+  - Guides:
+    - Tables: guides/tables.md
+    - Vector Search: search.md
+    - SQL filters: sql.md
+    - Indexing: ann_indexes.md
   - Embeddings: embedding.md
   - Python full-text search: fts.md
   - Integrations:
@@ -121,12 +126,6 @@ nav:
       - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
       - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
       - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
-  - Guides:
-    - Tables: guides/tables.md
-    - Vector Search: search.md
-    - SQL filters: sql.md
-    - Indexing: ann_indexes.md
   - API references:
     - Python API: python/python.md
     - Javascript API: javascript/modules.md

View File

@@ -49,11 +49,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
db.create_table("table2", data) db.create_table("table2", data)
db["table2"].head() db["table2"].head()
``` ```
!!! info "Note" !!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly. Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
```python ```python
custom_schema = pa.schema([ custom_schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -66,7 +66,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
### From PyArrow Tables ### From PyArrow Tables
You can also create LanceDB tables directly from pyarrow tables You can also create LanceDB tables directly from pyarrow tables
```python ```python
table = pa.Table.from_arrays( table = pa.Table.from_arrays(
[ [
@@ -87,15 +87,15 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method. LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method.
```python ```python
from lancedb.pydantic import vector, LanceModel from lancedb.pydantic import Vector, LanceModel
class Content(LanceModel): class Content(LanceModel):
movie_id: int movie_id: int
vector: vector(128) vector: Vector(128)
genres: str genres: str
title: str title: str
imdb_id: int imdb_id: int
@property @property
def imdb_url(self) -> str: def imdb_url(self) -> str:
return f"https://www.imdb.com/title/tt{self.imdb_id}" return f"https://www.imdb.com/title/tt{self.imdb_id}"
@@ -103,7 +103,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
import pyarrow as pa import pyarrow as pa
db = lancedb.connect("~/.lancedb") db = lancedb.connect("~/.lancedb")
table_name = "movielens_small" table_name = "movielens_small"
table = db.create_table(table_name, schema=Content.to_arrow_schema()) table = db.create_table(table_name, schema=Content)
``` ```
### Using Iterators / Writing Large Datasets ### Using Iterators / Writing Large Datasets
@@ -113,7 +113,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types. LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types.
Here's an example using using `RecordBatch` iterator for creating tables. Here's an example using using `RecordBatch` iterator for creating tables.
```python ```python
import pyarrow as pa import pyarrow as pa
@@ -142,11 +142,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
## Creating Empty Table ## Creating Empty Table
You can also create empty tables in python. Initialize it with schema and later ingest data into it. You can also create empty tables in python. Initialize it with schema and later ingest data into it.
```python ```python
import lancedb import lancedb
import pyarrow as pa import pyarrow as pa
schema = pa.schema( schema = pa.schema(
[ [
pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -168,8 +168,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, vector
class Model(LanceModel): class Model(LanceModel):
vector: vector(2) vector: Vector(2)
tbl = db.create_table("table5", schema=Model.to_arrow_schema()) tbl = db.create_table("table5", schema=Model.to_arrow_schema())
``` ```
@@ -249,7 +249,7 @@ After a table has been created, you can always add more data to it using
You can also add a large dataset batch in one go using Iterator of any supported data types. You can also add a large dataset batch in one go using Iterator of any supported data types.
### Adding to table using Iterator ### Adding to table using Iterator
```python ```python
import pandas as pd import pandas as pd
@@ -261,10 +261,10 @@ After a table has been created, you can always add more data to it using
"item": ["foo", "bar"], "item": ["foo", "bar"],
"price": [10.0, 20.0], "price": [10.0, 20.0],
}) })
tbl.add(make_batches()) tbl.add(make_batches())
``` ```
The other arguments accepted: The other arguments accepted:
| Name | Type | Description | Default | | Name | Type | Description | Default |
@@ -274,7 +274,7 @@ After a table has been created, you can always add more data to it using
| on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop | | on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop |
| fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 | | fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 |
=== "Javascript/Typescript" === "Javascript/Typescript"
```javascript ```javascript
@@ -312,7 +312,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector # x vector
# 0 1 [1.0, 2.0] # 0 1 [1.0, 2.0]
# 1 3 [5.0, 6.0] # 1 3 [5.0, 6.0]
``` ```
### Delete from a list of values ### Delete from a list of values
@@ -325,7 +325,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector # x vector
# 0 3 [5.0, 6.0] # 0 3 [5.0, 6.0]
``` ```
=== "Javascript/Typescript" === "Javascript/Typescript"
```javascript ```javascript

View File

@@ -1,6 +1,6 @@
 # LanceDB

-LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings.
+LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.

 ![Illustration](/lancedb/assets/ecosystem-illustration.png)

View File

@@ -249,11 +249,11 @@
    }
   ],
   "source": [
-   "from lancedb.pydantic import vector, LanceModel\n",
+   "from lancedb.pydantic import Vector, LanceModel\n",
    "\n",
    "class Content(LanceModel):\n",
    "    movie_id: int\n",
-   "    vector: vector(128)\n",
+   "    vector: Vector(128)\n",
    "    genres: str\n",
    "    title: str\n",
    "    imdb_id: int\n",
@@ -359,7 +359,7 @@
    "import pandas as pd\n",
    "\n",
    "class PydanticSchema(LanceModel):\n",
-   "    vector: vector(2)\n",
+   "    vector: Vector(2)\n",
    "    item: str\n",
    "    price: float\n",
    "\n",
@@ -394,10 +394,10 @@
   "outputs": [],
   "source": [
    "import lancedb\n",
-   "from lancedb.pydantic import LanceModel, vector\n",
+   "from lancedb.pydantic import LanceModel, Vector\n",
    "\n",
    "class Model(LanceModel):\n",
-   "    vector: vector(2)\n",
+   "    vector: Vector(2)\n",
    "\n",
    "tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())"
  ]

View File

@@ -13,10 +13,10 @@ via [pydantic_to_schema()](python.md##lancedb.pydantic.pydantic_to_schema) metho
 ## Vector Field

-LanceDB provides a [`vector(dim)`](python.md#lancedb.pydantic.vector) method to define a
+LanceDB provides a [`Vector(dim)`](python.md#lancedb.pydantic.Vector) method to define a
 vector Field in a Pydantic Model.

-::: lancedb.pydantic.vector
+::: lancedb.pydantic.Vector

 ## Type Conversion
@@ -33,4 +33,4 @@ Current supported type conversions:
 | `str` | `pyarrow.utf8()` |
 | `list` | `pyarrow.List` |
 | `BaseModel` | `pyarrow.Struct` |
-| `vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
+| `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |

View File

@@ -28,7 +28,13 @@ pip install lancedb
 ::: lancedb.embeddings.with_embeddings

-::: lancedb.embeddings.EmbeddingFunction
+::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
+
+::: lancedb.embeddings.functions.EmbeddingFunctionModel
+
+::: lancedb.embeddings.functions.TextEmbeddingFunctionModel
+
+::: lancedb.embeddings.functions.SentenceTransformerEmbeddingFunction

 ## Context

View File

@@ -8,7 +8,8 @@ excluded_globs = [
"../src/embedding.md", "../src/embedding.md",
"../src/examples/*.md", "../src/examples/*.md",
"../src/integrations/voxel51.md", "../src/integrations/voxel51.md",
"../src/guides/tables.md" "../src/guides/tables.md",
"../src/python/duckdb.md",
] ]
python_prefix = "py" python_prefix = "py"

105
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.2.3",
+  "version": "0.2.5",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.2.3",
+      "version": "0.2.5",
       "cpu": [
         "x64",
         "arm64"
@@ -31,6 +31,7 @@
         "@types/node": "^18.16.2",
         "@types/sinon": "^10.0.15",
         "@types/temp": "^0.9.1",
+        "@types/uuid": "^9.0.3",
         "@typescript-eslint/eslint-plugin": "^5.59.1",
         "cargo-cp-artifact": "^0.1",
         "chai": "^4.3.7",
@@ -48,14 +49,15 @@
         "ts-node-dev": "^2.0.0",
         "typedoc": "^0.24.7",
         "typedoc-plugin-markdown": "^3.15.3",
-        "typescript": "*"
+        "typescript": "*",
+        "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.2.3",
-        "@lancedb/vectordb-darwin-x64": "0.2.3",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.2.3",
-        "@lancedb/vectordb-linux-x64-gnu": "0.2.3",
-        "@lancedb/vectordb-win32-x64-msvc": "0.2.3"
+        "@lancedb/vectordb-darwin-arm64": "0.2.5",
+        "@lancedb/vectordb-darwin-x64": "0.2.5",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.2.5",
+        "@lancedb/vectordb-linux-x64-gnu": "0.2.5",
+        "@lancedb/vectordb-win32-x64-msvc": "0.2.5"
       }
     },
     "node_modules/@apache-arrow/ts": {
@@ -315,9 +317,9 @@
       }
     },
     "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.3.tgz",
-      "integrity": "sha512-/9dRCXrV/UsZv3fqAC/Q+D2FPKXMRprcb+a77tt4I0Iy5iGT55UDRfpaXvmJeKquhTJkZ0AuyoK5BmOh7cY41w==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.5.tgz",
+      "integrity": "sha512-V4206SajkMN3o+bBFBAYJq5emlrjevitP0g8RFfVlmj/LS38i8k4uvSe1bICQ2amUrYkL/Jw4ktYn19NRfTU+g==",
       "cpu": [
         "arm64"
       ],
@@ -327,9 +329,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.3.tgz",
-      "integrity": "sha512-p06WkjmdVwDxkH8ghIWh59SCgUhjXBpy1gQISgktouymqfoFbBHz7vmeI6VO1oBA5ji6vSgGZxqjmeLRKM6blA==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.5.tgz",
+      "integrity": "sha512-orePizgXCbTJbDJ4bMMnYh/4OgmWDBbHShNxHKQobcX+NgWTexmR0lV1WNOG+DtczBiGH422e3gHJ+xhTO13vg==",
       "cpu": [
         "x64"
       ],
@@ -339,9 +341,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.3.tgz",
-      "integrity": "sha512-cSDcJgfbnRmCXZ3AoRWpCAa07PMdB/k8m1LjmxnhpOnP1ohg1eUl99jwPCgd+5GK+iZmezRqbyO+YXlgsCp7GQ==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.5.tgz",
+      "integrity": "sha512-xIMNwsFGOHeY9EUWCHhUAcA2sCHZ5Lim0sc42uuUOeWayyH+HeR6ZWReptDQRuAoJHqQeag9qcqteE0AZPDTEw==",
       "cpu": [
         "arm64"
       ],
@@ -351,9 +353,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.3.tgz",
-      "integrity": "sha512-AFA3J4hBYapGC37iXheiN6tGruitx5bmoWXkUcDv/qAaE4tizVZHB9cgx9ThTB0RDsvZEOZ5zCy7BOzPH+oCOg==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.5.tgz",
+      "integrity": "sha512-Qr8dbHavtE+Zfd45kEORJQe01kRWhMF703gk8zhtZhskDUBCfqm3ap22JIux58tASxVcBqY8EtUFojfYGnQVvA==",
       "cpu": [
         "x64"
       ],
@@ -363,9 +365,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.3.tgz",
-      "integrity": "sha512-LI1mz1HdcpNXTM7HbcLdXz0qvUU4LxSqRC7/kMU918VlOeWy/PnryRrjHnCjcgciGzu1rVlvCqRPh7fVwaG6Kg==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.5.tgz",
+      "integrity": "sha512-jTqkR9HRfbjxhUrlTfveNkJ78tlpVXeNn3BS4wBm4VIsPd75jminKBRYtrlQCWyHusqrUQedKny4hhG1CuNUkg==",
       "cpu": [
         "x64"
       ],
@@ -596,6 +598,12 @@
         "@types/node": "*"
       }
     },
+    "node_modules/@types/uuid": {
+      "version": "9.0.3",
+      "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
+      "integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
+      "dev": true
+    },
     "node_modules/@typescript-eslint/eslint-plugin": {
       "version": "5.59.1",
       "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
@@ -4451,6 +4459,15 @@
         "punycode": "^2.1.0"
       }
     },
+    "node_modules/uuid": {
+      "version": "9.0.0",
+      "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
+      "integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
+      "dev": true,
+      "bin": {
+        "uuid": "dist/bin/uuid"
+      }
+    },
     "node_modules/v8-compile-cache-lib": {
       "version": "3.0.1",
       "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
@@ -4852,33 +4869,33 @@
       }
     },
     "@lancedb/vectordb-darwin-arm64": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.3.tgz",
-      "integrity": "sha512-/9dRCXrV/UsZv3fqAC/Q+D2FPKXMRprcb+a77tt4I0Iy5iGT55UDRfpaXvmJeKquhTJkZ0AuyoK5BmOh7cY41w==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.5.tgz",
+      "integrity": "sha512-V4206SajkMN3o+bBFBAYJq5emlrjevitP0g8RFfVlmj/LS38i8k4uvSe1bICQ2amUrYkL/Jw4ktYn19NRfTU+g==",
       "optional": true
     },
     "@lancedb/vectordb-darwin-x64": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.3.tgz",
-      "integrity": "sha512-p06WkjmdVwDxkH8ghIWh59SCgUhjXBpy1gQISgktouymqfoFbBHz7vmeI6VO1oBA5ji6vSgGZxqjmeLRKM6blA==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.5.tgz",
+      "integrity": "sha512-orePizgXCbTJbDJ4bMMnYh/4OgmWDBbHShNxHKQobcX+NgWTexmR0lV1WNOG+DtczBiGH422e3gHJ+xhTO13vg==",
      "optional": true
     },
     "@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.3.tgz",
-      "integrity": "sha512-cSDcJgfbnRmCXZ3AoRWpCAa07PMdB/k8m1LjmxnhpOnP1ohg1eUl99jwPCgd+5GK+iZmezRqbyO+YXlgsCp7GQ==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.5.tgz",
+      "integrity": "sha512-xIMNwsFGOHeY9EUWCHhUAcA2sCHZ5Lim0sc42uuUOeWayyH+HeR6ZWReptDQRuAoJHqQeag9qcqteE0AZPDTEw==",
       "optional": true
     },
     "@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.3.tgz",
-      "integrity": "sha512-AFA3J4hBYapGC37iXheiN6tGruitx5bmoWXkUcDv/qAaE4tizVZHB9cgx9ThTB0RDsvZEOZ5zCy7BOzPH+oCOg==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.5.tgz",
+      "integrity": "sha512-Qr8dbHavtE+Zfd45kEORJQe01kRWhMF703gk8zhtZhskDUBCfqm3ap22JIux58tASxVcBqY8EtUFojfYGnQVvA==",
       "optional": true
     },
     "@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.2.3",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.3.tgz",
-      "integrity": "sha512-LI1mz1HdcpNXTM7HbcLdXz0qvUU4LxSqRC7/kMU918VlOeWy/PnryRrjHnCjcgciGzu1rVlvCqRPh7fVwaG6Kg==",
+      "version": "0.2.5",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.5.tgz",
+      "integrity": "sha512-jTqkR9HRfbjxhUrlTfveNkJ78tlpVXeNn3BS4wBm4VIsPd75jminKBRYtrlQCWyHusqrUQedKny4hhG1CuNUkg==",
       "optional": true
     },
     "@neon-rs/cli": {
@@ -5093,6 +5110,12 @@
         "@types/node": "*"
       }
     },
+    "@types/uuid": {
+      "version": "9.0.3",
+      "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
+      "integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
+      "dev": true
+    },
     "@typescript-eslint/eslint-plugin": {
       "version": "5.59.1",
       "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
@@ -7844,6 +7867,12 @@
         "punycode": "^2.1.0"
       }
     },
+    "uuid": {
+      "version": "9.0.0",
+      "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
+      "integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
+      "dev": true
+    },
     "v8-compile-cache-lib": {
       "version": "3.0.1",
       "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",

View File

@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.2.3",
+  "version": "0.2.5",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -9,6 +9,7 @@
     "build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
     "build-release": "npm run build -- --release",
     "test": "npm run tsc && mocha -recursive dist/test",
+    "integration-test": "npm run tsc && mocha -recursive dist/integration_test",
     "lint": "eslint native.js src --ext .js,.ts",
     "clean": "rm -rf node_modules *.node dist/",
     "pack-build": "neon pack-build",
@@ -34,6 +35,7 @@
     "@types/node": "^18.16.2",
     "@types/sinon": "^10.0.15",
     "@types/temp": "^0.9.1",
+    "@types/uuid": "^9.0.3",
     "@typescript-eslint/eslint-plugin": "^5.59.1",
     "cargo-cp-artifact": "^0.1",
     "chai": "^4.3.7",
@@ -51,7 +53,8 @@
     "ts-node-dev": "^2.0.0",
     "typedoc": "^0.24.7",
     "typedoc-plugin-markdown": "^3.15.3",
-    "typescript": "*"
+    "typescript": "*",
+    "uuid": "^9.0.0"
   },
   "dependencies": {
     "@apache-arrow/ts": "^12.0.0",
@@ -78,10 +81,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.2.3",
-    "@lancedb/vectordb-darwin-x64": "0.2.3",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.2.3",
-    "@lancedb/vectordb-linux-x64-gnu": "0.2.3",
-    "@lancedb/vectordb-win32-x64-msvc": "0.2.3"
+    "@lancedb/vectordb-darwin-arm64": "0.2.5",
+    "@lancedb/vectordb-darwin-x64": "0.2.5",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.2.5",
+    "@lancedb/vectordb-linux-x64-gnu": "0.2.5",
+    "@lancedb/vectordb-win32-x64-msvc": "0.2.5"
   }
 }

View File

@@ -0,0 +1,43 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import { v4 as uuidv4 } from 'uuid'

import * as lancedb from '../index'

const assert = chai.assert
chai.use(chaiAsPromised)

describe('LanceDB AWS Integration test', function () {
  it('s3+ddb schema is processed correctly', async function () {
    this.timeout(5000)

    // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
    // THE API WILL CHANGE
    const conn = await lancedb.connect('s3://lancedb-integtest?engine=ddb&ddbTableName=lancedb-integtest')
    const data = [{ vector: Array(128).fill(1.0) }]

    const tableName = uuidv4()
    let table = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })

    const futs = [table.add(data), table.add(data), table.add(data), table.add(data), table.add(data)]
    await Promise.allSettled(futs)

    table = await conn.openTable(tableName)
    assert.equal(await table.countRows(), 6)
  })
})

View File

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.2.1
+current_version = 0.2.3
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True

View File

@@ -31,9 +31,13 @@ def connect(
     ----------
     uri: str or Path
         The uri of the database.
-    api_token: str, optional
+    api_key: str, optional
         If presented, connect to LanceDB cloud.
         Otherwise, connect to a database on file system or cloud storage.
+    region: str, default "us-west-2"
+        The region to use for LanceDB Cloud.
+    host_override: str, optional
+        The override url for LanceDB Cloud.

     Examples
     --------

View File

@@ -1,7 +1,10 @@
 import os

+import pyarrow as pa
 import pytest

+from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
+
 # import lancedb so we don't have to in every example
@@ -14,3 +17,22 @@ def doctest_setup(monkeypatch, tmpdir):
     monkeypatch.setitem(os.environ, "COLUMNS", "80")
     # Work in a temporary directory
     monkeypatch.chdir(tmpdir)
+
+
+registry = EmbeddingFunctionRegistry.get_instance()
+
+
+@registry.register()
+class MockEmbeddingFunction(EmbeddingFunctionModel):
+    def __call__(self, data):
+        if isinstance(data, str):
+            data = [data]
+        elif isinstance(data, pa.ChunkedArray):
+            data = data.combine_chunks().to_pylist()
+        elif isinstance(data, pa.Array):
+            data = data.to_pylist()
+
+        return [self.embed(row) for row in data]
+
+    def embed(self, row):
+        return [float(hash(c)) for c in row[:10]]

View File

@@ -16,12 +16,13 @@ from __future__ import annotations
 import os
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional, Union

 import pyarrow as pa
 from pyarrow import fs

 from .common import DATA, URI
+from .embeddings import EmbeddingFunctionModel
 from .pydantic import LanceModel
 from .table import LanceTable, Table
 from .util import fs_from_uri, get_uri_location, get_uri_scheme
@@ -40,7 +41,7 @@ class DBConnection(ABC):
         self,
         name: str,
         data: Optional[DATA] = None,
-        schema: Optional[pa.Schema, LanceModel] = None,
+        schema: Optional[Union[pa.Schema, LanceModel]] = None,
         mode: str = "create",
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
@@ -285,10 +286,11 @@ class LanceDBConnection(DBConnection):
         self,
         name: str,
         data: Optional[DATA] = None,
-        schema: Optional[pa.Schema, LanceModel] = None,
+        schema: Optional[Union[pa.Schema, LanceModel]] = None,
         mode: str = "create",
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
+        embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
     ) -> LanceTable:
         """Create a table in the database.
@@ -307,6 +309,7 @@ class LanceDBConnection(DBConnection):
             mode=mode,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
+            embedding_functions=embedding_functions,
         )

         return tbl

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .functions import (
    REGISTRY,
    EmbeddingFunctionModel,
    EmbeddingFunctionRegistry,
    SentenceTransformerEmbeddingFunction,
    TextEmbeddingFunctionModel,
)
from .utils import with_embeddings

View File

@@ -0,0 +1,228 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel


class EmbeddingFunctionRegistry:
    """
    This is a singleton class used to register embedding functions
    and fetch them by name. It also handles serializing and deserializing
    """

    @classmethod
    def get_instance(cls):
        return REGISTRY

    def __init__(self):
        self._functions = {}

    def register(self):
        """
        This creates a decorator that can be used to register
        an EmbeddingFunctionModel.
        """

        # This is a decorator for a class that inherits from BaseModel
        # It adds the class to the registry
        def decorator(cls):
            if not issubclass(cls, EmbeddingFunctionModel):
                raise TypeError("Must be a subclass of EmbeddingFunctionModel")
            if cls.__name__ in self._functions:
                raise KeyError(f"{cls.__name__} was already registered")
            self._functions[cls.__name__] = cls
            return cls

        return decorator

    def reset(self):
        """
        Reset the registry to its initial state
        """
        self._functions = {}

    def load(self, name: str):
        """
        Fetch an embedding function class by name
        """
        return self._functions[name]

    def parse_functions(self, metadata: Optional[dict]) -> dict:
        """
        Parse the metadata from an arrow table and
        return a mapping of the vector column to the
        embedding function and source column

        Parameters
        ----------
        metadata : Optional[dict]
            The metadata from an arrow table. Note that
            the keys and values are bytes.

        Returns
        -------
        functions : dict
            A mapping of vector column name to embedding function.
            An empty dict is returned if input is None or does not
            contain b"embedding_functions".
        """
        if metadata is None or b"embedding_functions" not in metadata:
            return {}
        serialized = metadata[b"embedding_functions"]
        raw_list = json.loads(serialized.decode("utf-8"))
        functions = {}
        for obj in raw_list:
            model = self.load(obj["schema"]["title"])
            functions[obj["model"]["vector_column"]] = model(**obj["model"])
        return functions

    def function_to_metadata(self, func):
        """
        Convert the given embedding function and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata
        """
        schema = func.model_json_schema()
        json_data = func.model_dump()
        return {
            "schema": schema,
            "model": json_data,
        }

    def get_table_metadata(self, func_list):
        """
        Convert a list of embedding functions and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata
        """
        json_data = [self.function_to_metadata(func) for func in func_list]
        # Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}


REGISTRY = EmbeddingFunctionRegistry()


class EmbeddingFunctionModel(BaseModel, ABC):
    """
    A callable ABC for embedding functions
    """

    source_column: Optional[str]
    vector_column: str

    @abstractmethod
    def __call__(self, *args, **kwargs) -> List[np.array]:
        pass


TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]


class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
    """
    A callable ABC for embedding functions that take text as input
    """

    def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
        """
        Sanitize the input to the embedding function. This is called
        before generate_embeddings() and is useful for stripping
        whitespace, lowercasing, etc.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, pa.Array):
            texts = texts.to_pylist()
        elif isinstance(texts, pa.ChunkedArray):
            texts = texts.combine_chunks().to_pylist()
        return texts

    @abstractmethod
    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Generate the embeddings for the given texts
        """
        pass


@REGISTRY.register()
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
    """
    An embedding function that uses the sentence-transformers library
    """

    name: str = "all-MiniLM-L6-v2"
    device: str = "cpu"
    normalize: bool = False

    @property
    def embedding_model(self):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.
        """
        return self.__class__.get_embedding_model(self.name, self.device)

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        ).tolist()

    @classmethod
    @cached(cache={})
    def get_embedding_model(cls, name, device):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.

        Parameters
        ----------
        name : str
            The name of the model to load
        device : str
            The device to load the model on

        TODO: use lru_cache instead with a reasonable/configurable maxsize
        """
        try:
            from sentence_transformers import SentenceTransformer

            return SentenceTransformer(name, device=device)
        except ImportError:
            raise ValueError("Please install sentence_transformers")

View File

@@ -1,4 +1,4 @@
-# Copyright 2023 LanceDB Developers
+# Copyright (c) 2023. LanceDB Developers
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ import pyarrow as pa
 from lance.vector import vec_to_table
 from retry import retry

-from .util import safe_import_pandas
+from ..util import safe_import_pandas

 pd = safe_import_pandas()

 DATA = Union[pa.Table, "pd.DataFrame"]
@@ -58,7 +58,7 @@ def with_embeddings(
     pa.Table
         The input table with a new column called "vector" containing the embeddings.
     """
-    func = EmbeddingFunction(func)
+    func = FunctionWrapper(func)
     if wrap_api:
         func = func.retry().rate_limit()
     func = func.batch_size(batch_size)
@@ -71,7 +71,11 @@ def with_embeddings(
     return data.append_column("vector", table["vector"])


-class EmbeddingFunction:
+class FunctionWrapper:
+    """
+    A wrapper for embedding functions that adds rate limiting, retries, and batching.
+    """
+
     def __init__(self, func: Callable):
         self.func = func
         self.rate_limiter_kwargs = {}

View File

@@ -46,7 +46,19 @@ class FixedSizeListMixin(ABC):
         raise NotImplementedError


-def vector(
+def vector(dim: int, value_type: pa.DataType = pa.float32()):
+    # TODO: remove in future release
+    from warnings import warn
+
+    warn(
+        "lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
+        "This function will be removed in future release",
+        DeprecationWarning,
+    )
+    return Vector(dim, value_type)
+
+
+def Vector(
     dim: int, value_type: pa.DataType = pa.float32()
 ) -> Type[FixedSizeListMixin]:
     """Pydantic Vector Type.
@@ -65,12 +77,12 @@ def vector(
     --------

     >>> import pydantic
-    >>> from lancedb.pydantic import vector
+    >>> from lancedb.pydantic import Vector
     ...
     >>> class MyModel(pydantic.BaseModel):
     ...     id: int
     ...     url: str
-    ...     embeddings: vector(768)
+    ...     embeddings: Vector(768)
     >>> schema = pydantic_to_schema(MyModel)
     >>> assert schema == pa.schema([
     ...     pa.field("id", pa.int64(), False),
@@ -258,11 +270,11 @@ class LanceModel(pydantic.BaseModel):
     Examples
     --------
     >>> import lancedb
-    >>> from lancedb.pydantic import LanceModel, vector
+    >>> from lancedb.pydantic import LanceModel, Vector
     >>>
     >>> class TestModel(LanceModel):
     ...     name: str
-    ...     vector: vector(2)
+    ...     vector: Vector(2)
     ...
     >>> db = lancedb.connect("/tmp")
     >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())

View File

@@ -13,6 +13,7 @@
 from __future__ import annotations

+from abc import ABC, abstractmethod
 from typing import List, Literal, Optional, Type, Union

 import numpy as np
@@ -54,7 +55,164 @@ class Query(pydantic.BaseModel):
     refine_factor: Optional[int] = None


-class LanceQueryBuilder:
+class LanceQueryBuilder(ABC):
+    @classmethod
+    def create(
+        cls,
+        table: "lancedb.table.Table",
+        query: Optional[Union[np.ndarray, str]],
+        query_type: str,
+        vector_column_name: str,
+    ) -> LanceQueryBuilder:
+        if query is None:
+            return LanceEmptyQueryBuilder(table)
+
+        query, query_type = cls._resolve_query(
+            table, query, query_type, vector_column_name
+        )
+        if isinstance(query, str):
+            # fts
+            return LanceFtsQueryBuilder(table, query)
+
+        if isinstance(query, list):
+            query = np.array(query, dtype=np.float32)
+        elif isinstance(query, np.ndarray):
+            query = query.astype(np.float32)
+        else:
+            raise TypeError(f"Unsupported query type: {type(query)}")
+        return LanceVectorQueryBuilder(table, query, vector_column_name)
+
+    @classmethod
+    def _resolve_query(cls, table, query, query_type, vector_column_name):
+        # If query_type is fts, then query must be a string.
+        # otherwise raise TypeError
+        if query_type == "fts":
+            if not isinstance(query, str):
+                raise TypeError(
+                    f"Query type is 'fts' but query is not a string: {type(query)}"
+                )
+            return query, query_type
+        elif query_type == "vector":
+            # If query_type is vector, then query must be a list or np.ndarray.
+            # otherwise raise TypeError
+            if not isinstance(query, (list, np.ndarray)):
+                raise TypeError(
+                    f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
+                )
+            return query, query_type
+        elif query_type == "auto":
+            if isinstance(query, (list, np.ndarray)):
+                return query, "vector"
+            elif isinstance(query, str):
+                func = table.embedding_functions.get(vector_column_name, None)
+                if func is not None:
+                    query = func(query)[0]
+                    return query, "vector"
+                else:
+                    return query, "fts"
+            else:
+                raise TypeError("Query must be a list, np.ndarray, or str")
+        else:
+            raise ValueError(
+                f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
+            )
+
+    def __init__(self, table: "lancedb.table.Table"):
+        self._table = table
+        self._limit = 10
+        self._columns = None
+        self._where = None
+
+    def to_df(self) -> "pd.DataFrame":
+        """
+        Execute the query and return the results as a pandas DataFrame.
+        In addition to the selected columns, LanceDB also returns a vector
+        and also the "_distance" column which is the distance between the query
+        vector and the returned vector.
+        """
+        return self.to_arrow().to_pandas()
+
+    @abstractmethod
+    def to_arrow(self) -> pa.Table:
+        """
+        Execute the query and return the results as an
+        [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
+
+        In addition to the selected columns, LanceDB also returns a vector
+        and also the "_distance" column which is the distance between the query
+        vector and the returned vectors.
+        """
+        raise NotImplementedError
+
+    def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
+        """Return the table as a list of pydantic models.
+
+        Parameters
+        ----------
+        model: Type[LanceModel]
+            The pydantic model to use.
+
+        Returns
+        -------
+        List[LanceModel]
+        """
+        return [
+            model(**{k: v for k, v in row.items() if k in model.field_names()})
+            for row in self.to_arrow().to_pylist()
+        ]
+
+    def limit(self, limit: int) -> LanceVectorQueryBuilder:
+        """Set the maximum number of results to return.
+
+        Parameters
+        ----------
+        limit: int
+            The maximum number of results to return.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._limit = limit
+        return self
+
+    def select(self, columns: list) -> LanceVectorQueryBuilder:
+        """Set the columns to return.
+
+        Parameters
+        ----------
+        columns: list
+            The columns to return.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._columns = columns
+        return self
+
+    def where(self, where: str) -> LanceVectorQueryBuilder:
+        """Set the where clause.
+
+        Parameters
+        ----------
+        where: str
+            The where clause.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._where = where
+        return self
+
+
+class LanceVectorQueryBuilder(LanceQueryBuilder):
     """
     A builder for nearest neighbor queries for LanceDB.
@@ -80,68 +238,17 @@ class LanceQueryBuilder:
     def __init__(
         self,
         table: "lancedb.table.Table",
-        query: Union[np.ndarray, str],
+        query: Union[np.ndarray, list],
         vector_column: str = VECTOR_COLUMN_NAME,
     ):
+        super().__init__(table)
+        self._query = query
         self._metric = "L2"
         self._nprobes = 20
         self._refine_factor = None
-        self._table = table
-        self._query = query
-        self._limit = 10
-        self._columns = None
-        self._where = None
         self._vector_column = vector_column

-    def limit(self, limit: int) -> LanceQueryBuilder:
-        """Set the maximum number of results to return.
-
-        Parameters
-        ----------
-        limit: int
-            The maximum number of results to return.
-
-        Returns
-        -------
-        LanceQueryBuilder
-            The LanceQueryBuilder object.
-        """
-        self._limit = limit
-        return self
-
-    def select(self, columns: list) -> LanceQueryBuilder:
-        """Set the columns to return.
-
-        Parameters
-        ----------
-        columns: list
-            The columns to return.
-
-        Returns
-        -------
-        LanceQueryBuilder
-            The LanceQueryBuilder object.
-        """
-        self._columns = columns
-        return self
-
-    def where(self, where: str) -> LanceQueryBuilder:
-        """Set the where clause.
-
-        Parameters
-        ----------
-        where: str
-            The where clause.
-
-        Returns
-        -------
-        LanceQueryBuilder
-            The LanceQueryBuilder object.
-        """
-        self._where = where
-        return self
-
-    def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
+    def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
         """Set the distance metric to use.

         Parameters
@@ -151,13 +258,13 @@ class LanceQueryBuilder:
         Returns
         -------
-        LanceQueryBuilder
+        LanceVectorQueryBuilder
             The LanceQueryBuilder object.
         """
         self._metric = metric
         return self

-    def nprobes(self, nprobes: int) -> LanceQueryBuilder:
+    def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
         """Set the number of probes to use.

         Higher values will yield better recall (more likely to find vectors if
@@ -173,13 +280,13 @@ class LanceQueryBuilder:
         Returns
         -------
-        LanceQueryBuilder
+        LanceVectorQueryBuilder
             The LanceQueryBuilder object.
         """
         self._nprobes = nprobes
         return self

-    def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
+    def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
         """Set the refine factor to use, increasing the number of vectors sampled.

         As an example, a refine factor of 2 will sample 2x as many vectors as
@@ -195,22 +302,12 @@ class LanceQueryBuilder:
         Returns
         -------
-        LanceQueryBuilder
+        LanceVectorQueryBuilder
             The LanceQueryBuilder object.
         """
         self._refine_factor = refine_factor
         return self

-    def to_df(self) -> "pd.DataFrame":
-        """
-        Execute the query and return the results as a pandas DataFrame.
-        In addition to the selected columns, LanceDB also returns a vector
-        and also the "_distance" column which is the distance between the query
-        vector and the returned vector.
-        """
-        return self.to_arrow().to_pandas()
-
     def to_arrow(self) -> pa.Table:
         """
         Execute the query and return the results as an
@@ -233,25 +330,12 @@ class LanceQueryBuilder:
         )
         return self._table._execute_query(query)

-    def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
-        """Return the table as a list of pydantic models.
-
-        Parameters
-        ----------
-        model: Type[LanceModel]
-            The pydantic model to use.
-
-        Returns
-        -------
-        List[LanceModel]
-        """
-        return [
-            model(**{k: v for k, v in row.items() if k in model.field_names()})
-            for row in self.to_arrow().to_pylist()
-        ]
-

 class LanceFtsQueryBuilder(LanceQueryBuilder):
+    def __init__(self, table: "lancedb.table.Table", query: str):
+        super().__init__(table)
+        self._query = query
+
     def to_arrow(self) -> pa.Table:
         try:
             import tantivy
@@ -275,3 +359,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores) output_tbl = output_tbl.append_column("score", scores)
return output_tbl return output_tbl
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
limit=self._limit,
)
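Taken together, the vector, full-text, and empty builders share the same fluent surface and the same result conversions. A minimal sketch of how the conversions are typically used, assuming a fresh local `./.lancedb` directory; the table name and the `Doc` model are illustrative:

```
import lancedb
from lancedb.pydantic import LanceModel, Vector

class Doc(LanceModel):
    text: str
    vector: Vector(2)

db = lancedb.connect("./.lancedb")
table = db.create_table("docs_demo", [{"text": "foo", "vector": [1.0, 2.0]}])

builder = table.search([1.0, 2.0]).limit(1)
df = builder.to_df()               # pandas DataFrame, includes the "_distance" column
tbl = builder.to_arrow()           # the same results as a pyarrow Table
models = builder.to_pydantic(Doc)  # list of Doc instances; extra columns are dropped
```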

View File

@@ -20,7 +20,7 @@ from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE from .client import ARROW_STREAM_CONTENT_TYPE
@@ -73,7 +73,11 @@ class RemoteTable(Table):
fill_value: float = 0.0, fill_value: float = 0.0,
) -> int: ) -> int:
data = _sanitize_data( data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
payload = to_ipc_binary(data) payload = to_ipc_binary(data)
@@ -89,9 +93,9 @@ class RemoteTable(Table):
) )
def search( def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder: ) -> LanceVectorQueryBuilder:
return LanceQueryBuilder(self, query, vector_column) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)

View File

@@ -17,64 +17,98 @@ import inspect
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import Iterable, List, Union from typing import Any, Iterable, List, Optional, Union
import lance import lance
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
from lance import LanceDataset from lance import LanceDataset
from lance.dataset import ReaderLike
from lance.vector import vec_to_table from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
from .pydantic import LanceModel from .pydantic import LanceModel
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas from .util import fs_from_uri, safe_import_pandas
pd = safe_import_pandas() pd = safe_import_pandas()
def _sanitize_data(data, schema, on_bad_vectors, fill_value): def _sanitize_data(
data,
schema: Optional[pa.Schema],
metadata: Optional[dict],
on_bad_vectors: str,
fill_value: Any,
):
if isinstance(data, list): if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels # convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel): if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema() schema = data[0].__class__.to_arrow_schema()
data = [dict(d) for d in data] data = [dict(d) for d in data]
data = pa.Table.from_pylist(data) data = pa.Table.from_pylist(data)
data = _sanitize_schema( elif isinstance(data, dict):
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if isinstance(data, dict):
data = vec_to_table(data) data = vec_to_table(data)
if pd is not None and isinstance(data, pd.DataFrame): elif pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False) data = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
if isinstance(data, pa.Table):
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
data = _sanitize_schema( data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
) )
# Do not serialize Pandas metadata elif isinstance(data, Iterable):
metadata = data.schema.metadata if data.schema.metadata is not None else {} data = _to_record_batch_generator(
metadata = {k: v for k, v in metadata.items() if k != b"pandas"} data, schema, metadata, on_bad_vectors, fill_value
schema = data.schema.with_metadata(metadata) )
data = pa.Table.from_arrays(data.columns, schema=schema) else:
if isinstance(data, Iterable):
data = _to_record_batch_generator(data, schema, on_bad_vectors, fill_value)
if not isinstance(data, (pa.Table, Iterable)):
raise TypeError(f"Unsupported data type: {type(data)}") raise TypeError(f"Unsupported data type: {type(data)}")
return data return data
def _to_record_batch_generator(data: Iterable, schema, on_bad_vectors, fill_value): def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
"""
Use the embedding function to automatically embed the source column and add the
vector column to the table.
"""
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_col, func in functions.items():
if vector_col not in data.column_names:
col_data = func(data[func.source_column])
if schema is not None:
dtype = schema.field(vector_col).type
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
data = data.append_column(
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
)
return data
def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value
):
for batch in data: for batch in data:
if not isinstance(batch, pa.RecordBatch): if not isinstance(batch, pa.RecordBatch):
table = _sanitize_data(batch, schema, on_bad_vectors, fill_value) table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches(): for batch in table.to_batches():
yield batch yield batch
yield batch else:
yield batch
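A sketch of the input shapes `_sanitize_data` now normalizes when adding data; the table name is illustrative, and the generator path shows batches being sanitized lazily:

```
import lancedb
import pandas as pd
import pyarrow as pa

db = lancedb.connect("./.lancedb")
table = db.create_table("sanitize_demo", [{"vector": [1.0, 2.0], "text": "foo"}])

# A pandas DataFrame is converted via pa.Table.from_pandas and its
# pandas-specific schema metadata is stripped before writing.
table.add(pd.DataFrame({"vector": [[3.0, 4.0]], "text": ["bar"]}))

# An iterable of RecordBatches is sanitized one batch at a time.
def batches():
    for b in pa.Table.from_pylist([{"vector": [5.0, 6.0], "text": "baz"}]).to_batches():
        yield b

table.add(batches())
```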
class Table(ABC): class Table(ABC):
""" """
A [Table](Table) is a collection of Records in a LanceDB [Database](Database). A Table is a collection of Records in a LanceDB Database.
Examples Examples
-------- --------
@@ -195,17 +229,28 @@ class Table(ABC):
@abstractmethod @abstractmethod
def search( def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME self,
query: Optional[Union[VEC, str]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. of the given query vector.
Parameters Parameters
---------- ----------
query: list, np.ndarray query: str, list, np.ndarray, default None
The query vector. The query to search for. If None then
vector_column: str, default "vector" the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector"
The name of the vector column to search. The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns Returns
------- -------
@@ -311,7 +356,7 @@ class LanceTable(Table):
This allows viewing previous versions of the table. If you wish to This allows viewing previous versions of the table. If you wish to
keep writing to the dataset starting from an old version, then use keep writing to the dataset starting from an old version, then use
the `restore` function instead. the `restore` function.
Parameters Parameters
---------- ----------
@@ -324,14 +369,14 @@ class LanceTable(Table):
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}]) >>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version >>> table.version
1 2
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version >>> table.version
2 3
>>> table.checkout(1) >>> table.checkout(2)
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
@@ -341,16 +386,18 @@ class LanceTable(Table):
raise ValueError(f"Invalid version {version}") raise ValueError(f"Invalid version {version}")
self._reset_dataset(version=version) self._reset_dataset(version=version)
def restore(self, version: int): def restore(self, version: int = None):
"""Restore a version of the table. This is an in-place operation. """Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the This creates a new version where the data is equivalent to the
specified previous version. Note that this creates a new snapshot. specified previous version. Data is not copied (as of python-v0.2.1).
Parameters Parameters
---------- ----------
version : int version : int, default None
The version to restore. The version to restore. If unspecified then restores the currently
checked out version. If the currently checked out version is the
latest version then this is a no-op.
Examples Examples
-------- --------
@@ -358,30 +405,33 @@ class LanceTable(Table):
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}]) >>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
>>> table.version >>> table.version
1 2
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
>>> table.version >>> table.version
2 3
>>> table.restore(1) >>> table.restore(2)
>>> table.to_pandas() >>> table.to_pandas()
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
>>> len(table.list_versions()) >>> len(table.list_versions())
3 4
""" """
max_ver = max([v["version"] for v in self._dataset.versions()]) max_ver = max([v["version"] for v in self._dataset.versions()])
if version < 1 or version >= max_ver: if version is None:
version = self.version
elif version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}") raise ValueError(f"Invalid version {version}")
else:
self.checkout(version)
if version == max_ver: if version == max_ver:
self._reset_dataset() # no-op if restoring the latest version
return return
self.checkout(version)
data = self.to_arrow() self._dataset.restore()
self.checkout(max_ver)
self.add(data, mode="overwrite")
self._reset_dataset() self._reset_dataset()
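A short sketch of the resulting checkout/restore flow (the table name is illustrative; version numbers follow the docstring above, where creating a table already produces two versions):

```
import lancedb

db = lancedb.connect("./.lancedb")
table = db.create_table("restore_demo", [{"vector": [1.1, 0.9], "type": "vector"}])
table.add([{"vector": [0.5, 0.2], "type": "vector"}])

v = table.version        # latest version, after the add
table.checkout(v - 1)    # read-only view of the previous version
table.restore()          # re-publish the checked-out version as a new version
assert len(table) == 1
```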
def __len__(self): def __len__(self):
@@ -495,23 +545,122 @@ class LanceTable(Table):
""" """
# TODO: manage table listing and metadata separately # TODO: manage table listing and metadata separately
data = _sanitize_data( data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode) lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset() self._reset_dataset()
def merge(
self,
other_table: Union[LanceTable, ReaderLike],
left_on: str,
right_on: Optional[str] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
):
"""Merge another table into this table.
Performs a left join, where the dataset is the left side and other_table
is the right side. Rows in the dataset that have no match in other_table
will be filled with null values, unless Lance doesn't support null values
for some types, in which case an error will be raised. The only overlapping
column allowed is the join column. If other overlapping columns exist,
an error will be raised.
Parameters
----------
other_table: LanceTable or Reader-like
The data to be merged. Acceptable types are:
- Pandas DataFrame, Pyarrow Table, Dataset, Scanner,
Iterator[RecordBatch], or RecordBatchReader
- LanceTable
left_on: str
The name of the column in the dataset to join on.
right_on: str or None
The name of the column in other_table to join on. If None, defaults to
left_on.
schema: pa.Schema or LanceModel, optional
The schema of the other_table.
If not provided, the schema is inferred from the data.
Examples
--------
>>> import lancedb
>>> import pyarrow as pa
>>> df = pa.table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("dataset", df)
>>> table.to_pandas()
x y
0 1 a
1 2 b
2 3 c
>>> new_df = pa.table({'x': [1, 2, 3], 'z': ['d', 'e', 'f']})
>>> table.merge(new_df, 'x')
>>> table.to_pandas()
x y z
0 1 a d
1 2 b e
2 3 c f
"""
if isinstance(schema, LanceModel):
schema = schema.to_arrow_schema()
if isinstance(other_table, LanceTable):
other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset):
other_table = other_table.to_table()
self._dataset.merge(
other_table, left_on=left_on, right_on=right_on, schema=schema
)
self._reset_dataset()
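The docstring above covers the plain Arrow-table case; as a further illustrative sketch, the right-hand side can also use a differently named join key via `right_on` (the table and column names here are made up):

```
import lancedb
import pyarrow as pa

db = lancedb.connect("./.lancedb")
left = db.create_table(
    "merge_left", pa.table({"id": [1, 2], "vector": [[0.0, 1.0], [1.0, 0.0]]})
)
labels = pa.table({"key": [1, 2], "label": ["a", "b"]})

# Left join on id == key; LanceTable and LanceDataset inputs are converted
# to Arrow tables internally before being handed to lance's merge.
left.merge(labels, left_on="id", right_on="key")
```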
def _get_embedding_function_for_source_col(self, column_name: str):
for k, v in self.embedding_functions.items():
if v.source_column == column_name:
return v
return None
@cached_property
def embedding_functions(self) -> dict:
"""
Get the embedding functions for the table
Returns
-------
funcs: dict
A mapping of the vector column to the embedding function
or empty dict if not configured.
"""
return EmbeddingFunctionRegistry.get_instance().parse_functions(
self.schema.metadata
)
def search( def search(
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME self,
query: Optional[Union[VEC, str]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. of the given query vector.
Parameters Parameters
---------- ----------
query: list, np.ndarray query: str, list, np.ndarray, or None
The query vector. The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
vector_column_name: str, default "vector" vector_column_name: str, default "vector"
The name of the vector column to search. The name of the vector column to search.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If the query is a list/np.ndarray then the query type is "vector";
If the query is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns Returns
------- -------
@@ -521,17 +670,9 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if isinstance(query, str): return LanceQueryBuilder.create(
# fts self, query, query_type, vector_column_name=vector_column_name
return LanceFtsQueryBuilder(self, query, vector_column_name) )
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query, vector_column_name)
@classmethod @classmethod
def create( def create(
@@ -543,6 +684,7 @@ class LanceTable(Table):
mode="create", mode="create",
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionModel] = None,
): ):
""" """
Create a new table. Create a new table.
@@ -580,20 +722,52 @@ class LanceTable(Table):
One of "error", "drop", "fill". One of "error", "drop", "fill".
fill_value: float, default 0. fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
embedding_functions: list of EmbeddingFunctionModel, default None
The embedding functions to use when creating the table.
""" """
tbl = LanceTable(db, name) tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel): if inspect.isclass(schema) and issubclass(schema, LanceModel):
schema = schema.to_arrow_schema() schema = schema.to_arrow_schema()
metadata = None
if embedding_functions is not None:
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
if data is not None: if data is not None:
data = _sanitize_data( data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value data,
schema,
metadata=metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
else:
if schema is None: if schema is None:
if data is None:
raise ValueError("Either data or schema must be provided") raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema) elif hasattr(data, "schema"):
lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode) schema = data.schema
return LanceTable(db, name) elif isinstance(data, Iterable):
if metadata:
raise TypeError(
(
"Persistent embedding functions not yet "
"supported for generator data input"
)
)
if metadata:
schema = schema.with_metadata(metadata)
empty = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
table = LanceTable(db, name)
if data is not None:
table.add(data)
return table
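A sketch of creating a table with a persisted embedding function, borrowing the `MockEmbeddingFunction` helper that the tests below use; a real application would supply its own `EmbeddingFunctionModel` with the appropriate `source_column`/`vector_column`:

```
import lancedb
from lancedb.conftest import MockEmbeddingFunction  # test helper, for illustration only
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable

class MyTable(LanceModel):
    text: str
    vector: Vector(10)

db = lancedb.connect("./.lancedb")
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(db, "embedded_demo", schema=MyTable, embedding_functions=[func])

# The source column is embedded automatically on add() ...
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
# ... and a plain string query is embedded with the same persisted function.
hits = table.search("hi there").limit(1).to_df()
```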
@classmethod @classmethod
def open(cls, db, name): def open(cls, db, name):
@@ -609,6 +783,56 @@ class LanceTable(Table):
def delete(self, where: str): def delete(self, where: str):
self._dataset.delete(where) self._dataset.delete(where)
def update(self, where: str, values: dict):
"""
EXPERIMENTAL: Update rows in the table (not threadsafe).
This can update anywhere from zero rows to every row, depending on how
many rows match the where clause.
Parameters
----------
where: str
The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict
The values to update. The keys are the column names and the values
are the values to set.
Examples
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 3 [5.0, 6.0]
2 2 [10.0, 10.0]
"""
orig_data = self._dataset.to_table(filter=where).combine_chunks()
if len(orig_data) == 0:
return
for col, val in values.items():
i = orig_data.column_names.index(col)
if i < 0:
raise ValueError(f"Column {col} does not exist")
orig_data = orig_data.set_column(
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
)
self.delete(where)
self.add(orig_data, mode="append")
self._reset_dataset()
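For example (a minimal sketch mirroring the docstring above; because update is implemented as delete-then-append, the matched rows move to the end and two new versions are created):

```
import lancedb
import pandas as pd

db = lancedb.connect("./.lancedb")
data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
table = db.create_table("update_demo", data)

# Both matched rows get the same replacement value for "vector".
table.update(where="x IN (1, 3)", values={"vector": [0.0, 0.0]})
```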
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance() ds = self.to_lance()
return ds.to_table( return ds.to_table(
@@ -651,22 +875,38 @@ def _sanitize_schema(
return data return data
# cast the columns to the expected types # cast the columns to the expected types
data = data.combine_chunks() data = data.combine_chunks()
data = _sanitize_vector_column( for field in schema:
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_float32(field.type.value_type)
and field.type.list_size >= 10
)
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
if field.name in data.column_names and (
likely_vector_col or is_default_vector_col
):
data = _sanitize_vector_column(
data,
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
# just check the vector column
if VECTOR_COLUMN_NAME in data.column_names:
return _sanitize_vector_column(
data, data,
vector_column_name=VECTOR_COLUMN_NAME, vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema return data
)
# just check the vector column
return _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
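The heuristic above amounts to a small predicate over each schema field, roughly as sketched below (a paraphrase for illustration, not an actual helper in the module):

```
import pyarrow as pa

VECTOR_COLUMN_NAME = "vector"

def looks_like_vector_col(field: pa.Field) -> bool:
    # Fixed-size lists of float32 with at least 10 elements are treated as
    # vector columns; the default "vector" column is always sanitized.
    likely = (
        pa.types.is_fixed_size_list(field.type)
        and pa.types.is_float32(field.type.value_type)
        and field.type.list_size >= 10
    )
    return likely or field.name == VECTOR_COLUMN_NAME
```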
def _sanitize_vector_column( def _sanitize_vector_column(
@@ -690,8 +930,6 @@ def _sanitize_vector_column(
fill_value: float, default 0.0 fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
if vector_column_name not in data.column_names:
raise ValueError(f"Missing vector column: {vector_column_name}")
# ChunkedArray is annoying to work with, so we combine chunks here # ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks() vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_list(data[vector_column_name].type): if pa.types.is_list(data[vector_column_name].type):

View File

@@ -1,15 +1,16 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.2.1" version = "0.2.3"
dependencies = [ dependencies = [
"pylance==0.6.5", "pylance==0.7.3",
"ratelimiter", "ratelimiter",
"retry", "retry",
"tqdm", "tqdm",
"aiohttp", "aiohttp",
"pydantic", "pydantic",
"attr", "attr",
"semver>=3.0" "semver>=3.0",
"cachetools"
] ]
description = "lancedb" description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]

View File

@@ -17,7 +17,7 @@ import pyarrow as pa
import pytest import pytest
import lancedb import lancedb
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, Vector
def test_basic(tmp_path): def test_basic(tmp_path):
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
def test_ingest_iterator(tmp_path): def test_ingest_iterator(tmp_path):
class PydanticSchema(LanceModel): class PydanticSchema(LanceModel):
vector: vector(2) vector: Vector(2)
item: str item: str
price: float price: float
@@ -143,8 +143,9 @@ def test_ingest_iterator(tmp_path):
tbl_len = len(tbl) tbl_len = len(tbl)
tbl.add(make_batches()) tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2 assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 2 assert len(tbl.list_versions()) == 3
db.drop_database() db.drop_database()
run_tests(arrow_schema) run_tests(arrow_schema)

View File

@@ -12,10 +12,12 @@
# limitations under the License. # limitations under the License.
import sys import sys
import lance
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
from lancedb.embeddings import with_embeddings from lancedb.conftest import MockEmbeddingFunction
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
def mock_embed_func(input_data): def mock_embed_func(input_data):
@@ -40,3 +42,37 @@ def test_with_embeddings():
assert data.column_names == ["text", "price", "vector"] assert data.column_names == ["text", "price", "vector"]
assert data.column("text").to_pylist() == ["foo", "bar"] assert data.column("text").to_pylist() == ["foo", "bar"]
assert data.column("price").to_pylist() == [10.0, 20.0] assert data.column("price").to_pylist() == [10.0, 20.0]
def test_embedding_function(tmp_path):
registry = EmbeddingFunctionRegistry.get_instance()
# let's create a table
table = pa.table(
{
"text": pa.array(["hello world", "goodbye world"]),
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
metadata = registry.get_table_metadata([func])
table = table.replace_schema_metadata(metadata)
# Write it to disk
lance.write_dataset(table, tmp_path / "test.lance")
# Load this back
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
functions = registry.parse_functions(ds.schema.metadata)
func = functions["vector"]
actual = func("hello world")
# We create an instance
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
# And we make sure we can call it
expected = expected_func("hello world")
assert np.allclose(actual, expected)

View File

@@ -19,8 +19,9 @@ from typing import List, Optional
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import pytest import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
@pytest.mark.skipif( @pytest.mark.skipif(
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
def test_fixed_size_list_field(): def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel):
vec: vector(16) vec: Vector(16)
li: List[int] li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3]) data = TestModel(vec=list(range(16)), li=[1, 2, 3])
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
def test_fixed_size_list_validation(): def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel):
vec: vector(8) vec: Vector(8)
with pytest.raises(pydantic.ValidationError): with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9)) TestModel(vec=range(9))
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
def test_lance_model(): def test_lance_model():
class TestModel(LanceModel): class TestModel(LanceModel):
vec: vector(16) vector: Vector(16) = Field(default=[0.0] * 16)
li: List[int] li: List[int] = Field(default=[1, 2, 3])
schema = pydantic_to_schema(TestModel) schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema() assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vec", "li"] assert TestModel.field_names() == ["vector", "li"]
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])

View File

@@ -20,8 +20,8 @@ import pyarrow as pa
import pytest import pytest
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceQueryBuilder, Query from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable from lancedb.table import LanceTable
@@ -67,12 +67,12 @@ def table(tmp_path) -> MockTable:
def test_cast(table): def test_cast(table):
class TestModel(LanceModel): class TestModel(LanceModel):
vector: vector(2) vector: Vector(2)
id: int id: int
str_field: str str_field: str
float_field: float float_field: float
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1) q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel) results = q.to_pydantic(TestModel)
assert len(results) == 1 assert len(results) == 1
r0 = results[0] r0 = results[0]
@@ -84,13 +84,15 @@ def test_cast(table):
def test_query_builder(table): def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df() df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
)
assert df["id"].values[0] == 1 assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2]) assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table): def test_query_builder_with_filter(table):
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df() df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2 assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4]) assert all(df["vector"].values[0] == [3, 4])
@@ -98,12 +100,14 @@ def test_query_builder_with_filter(table):
def test_query_builder_with_metric(table): def test_query_builder_with_metric(table):
query = [4, 8] query = [4, 8]
vector_column_name = "vector" vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df() df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df() df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
)
tm.assert_frame_equal(df_default, df_l2) tm.assert_frame_equal(df_default, df_l2)
df_cosine = ( df_cosine = (
LanceQueryBuilder(table, query, vector_column_name) LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine") .metric("cosine")
.limit(1) .limit(1)
.to_df() .to_df()
@@ -120,7 +124,7 @@ def test_query_builder_with_different_vector_column():
query = [4, 8] query = [4, 8]
vector_column_name = "foo_vector" vector_column_name = "foo_vector"
builder = ( builder = (
LanceQueryBuilder(table, query, vector_column_name) LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine") .metric("cosine")
.where("b < 10") .where("b < 10")
.select(["b"]) .select(["b"])

View File

@@ -16,13 +16,15 @@ from pathlib import Path
from typing import List from typing import List
from unittest.mock import PropertyMock, patch from unittest.mock import PropertyMock, patch
import lance
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
import pytest import pytest
from lancedb.conftest import MockEmbeddingFunction
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable from lancedb.table import LanceTable
@@ -138,7 +140,7 @@ def test_add(db):
def test_add_pydantic_model(db): def test_add_pydantic_model(db):
class TestModel(LanceModel): class TestModel(LanceModel):
vector: vector(16) vector: Vector(16)
li: List[int] li: List[int]
data = TestModel(vector=list(range(16)), li=[1, 2, 3]) data = TestModel(vector=list(range(16)), li=[1, 2, 3])
@@ -177,16 +179,16 @@ def test_versioning(db):
], ],
) )
assert len(table.list_versions()) == 1
assert table.version == 1
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 2 assert len(table.list_versions()) == 2
assert table.version == 2 assert table.version == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 3 assert len(table) == 3
table.checkout(1) table.checkout(2)
assert table.version == 1 assert table.version == 2
assert len(table) == 2 assert len(table) == 2
@@ -277,6 +279,165 @@ def test_restore(db):
data=[{"vector": [1.1, 0.9], "type": "vector"}], data=[{"vector": [1.1, 0.9], "type": "vector"}],
) )
table.add([{"vector": [0.5, 0.2], "type": "vector"}]) table.add([{"vector": [0.5, 0.2], "type": "vector"}])
table.restore(1) table.restore(2)
assert len(table.list_versions()) == 3 assert len(table.list_versions()) == 4
assert len(table) == 1 assert len(table) == 1
expected = table.to_arrow()
table.checkout(2)
table.restore()
assert len(table.list_versions()) == 5
assert table.to_arrow() == expected
table.restore(5) # latest version should be no-op
assert len(table.list_versions()) == 5
with pytest.raises(ValueError):
table.restore(6)
with pytest.raises(ValueError):
table.restore(0)
def test_merge(db, tmp_path):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
table.merge(other_table, left_on="id")
assert len(table.list_versions()) == 3
expected = pa.table(
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
schema=table.schema,
)
assert table.to_arrow() == expected
other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
table.restore(1)
table.merge(other_dataset, left_on="id")
def test_delete(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 2
table.delete("id=0")
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
def test_update(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 4
assert table.version == 4
assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func(texts)})
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
def test_empty_query(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
val = df.id.iloc[0]
assert val == 1

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "vectordb-node" name = "vectordb-node"
version = "0.2.3" version = "0.2.5"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"
edition = "2018" edition = "2018"

View File

@@ -28,7 +28,9 @@ fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
record_batch record_batch
.column_by_name(VECTOR_COLUMN_NAME) .column_by_name(VECTOR_COLUMN_NAME)
.map(|_| ()) .map(|_| ())
.context(MissingColumnSnafu { name: VECTOR_COLUMN_NAME }) .context(MissingColumnSnafu {
name: VECTOR_COLUMN_NAME,
})
} }
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> { pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {

View File

@@ -183,11 +183,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let aws_region = get_aws_region(&mut cx, 4)?; let aws_region = get_aws_region(&mut cx, 4)?;
let params = ReadParams { let params = ReadParams {
store_options: Some(ObjectStoreParams { store_options: Some(ObjectStoreParams::with_aws_credentials(
aws_credentials: aws_creds, aws_creds, aws_region,
aws_region, )),
..ObjectStoreParams::default()
}),
..ReadParams::default() ..ReadParams::default()
}; };

View File

@@ -43,7 +43,8 @@ impl JsTable {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?; .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx); let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let buffer = cx.argument::<JsBuffer>(1)?; let buffer = cx.argument::<JsBuffer>(1)?;
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
// Write mode // Write mode
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() { let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
@@ -65,11 +66,9 @@ impl JsTable {
let aws_region = get_aws_region(&mut cx, 6)?; let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams { let params = WriteParams {
store_params: Some(ObjectStoreParams { store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_credentials: aws_creds, aws_creds, aws_region,
aws_region, )),
..ObjectStoreParams::default()
}),
mode: mode, mode: mode,
..WriteParams::default() ..WriteParams::default()
}; };
@@ -92,7 +91,8 @@ impl JsTable {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?; let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx); let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let channel = cx.channel(); let channel = cx.channel();
let mut table = js_table.table.clone(); let mut table = js_table.table.clone();
@@ -108,11 +108,9 @@ impl JsTable {
let aws_region = get_aws_region(&mut cx, 5)?; let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams { let params = WriteParams {
store_params: Some(ObjectStoreParams { store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_credentials: aws_creds, aws_creds, aws_region,
aws_region, )),
..ObjectStoreParams::default()
}),
mode: write_mode, mode: write_mode,
..WriteParams::default() ..WriteParams::default()
}; };

View File

@@ -1,21 +1,29 @@
[package] [package]
name = "vectordb" name = "vectordb"
version = "0.2.3" version = "0.2.5"
edition = "2021" edition = "2021"
description = "Serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
keywords = ["lancedb", "lance", "database", "search"]
categories = ["database-implementations"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
arrow = { workspace = true }
arrow-array = { workspace = true } arrow-array = { workspace = true }
arrow-data = { workspace = true } arrow-data = { workspace = true }
arrow-schema = { workspace = true } arrow-schema = { workspace = true }
arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
object_store = { workspace = true } object_store = { workspace = true }
snafu = { workspace = true } snafu = { workspace = true }
half = { workspace = true } half = { workspace = true }
lance = { workspace = true } lance = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] } tokio = { version = "1.23", features = ["rt-multi-thread"] }
log = { workspace = true }
num-traits = "0"
url = { workspace = true }
[dev-dependencies] [dev-dependencies]
tempfile = "3.5.0" tempfile = "3.5.0"

3
rust/vectordb/README.md Normal file
View File

@@ -0,0 +1,3 @@
# LanceDB Rust
Rust client for LanceDB, a serverless vector database. Read more at: https://lancedb.com/

View File

@@ -0,0 +1,15 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use lance::arrow::*;

18
rust/vectordb/src/data.rs Normal file
View File

@@ -0,0 +1,18 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Data types, schema coercion, data cleaning, etc.
pub mod inspect;
pub mod sanitize;

View File

@@ -0,0 +1,180 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use arrow::compute::kernels::{aggregate::bool_and, length::length};
use arrow_array::{
cast::AsArray,
types::{ArrowPrimitiveType, Int32Type, Int64Type},
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
};
use arrow_ord::comparison::eq_dyn_scalar;
use arrow_schema::DataType;
use num_traits::{ToPrimitive, Zero};
use crate::error::{Error, Result};
pub(crate) fn infer_dimension<T: ArrowPrimitiveType>(
list_arr: &GenericListArray<T::Native>,
) -> Result<Option<T::Native>>
where
T::Native: OffsetSizeTrait + ToPrimitive,
{
let len_arr = length(list_arr)?;
if len_arr.is_empty() {
return Ok(Some(Zero::zero()));
}
let dim = len_arr.as_primitive::<T>().value(0);
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
Ok(None)
} else {
Ok(Some(dim))
}
}
/// Infer the vector columns from a dataset.
///
/// Parameters
/// ----------
/// - reader: RecordBatchReader
/// - strict: if set to true, only fixed_size_list<float> columns are considered vector columns. If set to false,
///   a list<float> column whose lists all have the same length is also considered a vector column.
pub fn infer_vector_columns(
reader: impl RecordBatchReader + Send,
strict: bool,
) -> Result<Vec<String>> {
let mut columns = vec![];
let mut columns_to_infer: HashMap<String, Option<i64>> = HashMap::new();
for field in reader.schema().fields() {
match field.data_type() {
DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => {
columns.push(field.name().to_string());
}
DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().to_string(), None);
}
DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().to_string(), None);
}
_ => {}
}
}
for batch in reader {
let batch = batch?;
let col_names = columns_to_infer.keys().cloned().collect::<Vec<_>>();
for col_name in col_names {
let col = batch.column_by_name(&col_name).ok_or(Error::Schema {
message: format!("Column {} not found", col_name),
})?;
if let Some(dim) = match *col.data_type() {
DataType::List(_) => {
infer_dimension::<Int32Type>(col.as_list::<i32>())?.map(|d| d as i64)
}
DataType::LargeList(_) => infer_dimension::<Int64Type>(col.as_list::<i64>())?,
_ => {
return Err(Error::Schema {
message: format!("Column {} is not a list", col_name),
})
}
} {
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
if prev_dim != &dim {
columns_to_infer.remove(&col_name);
}
} else {
columns_to_infer.insert(col_name, Some(dim));
}
} else {
columns_to_infer.remove(&col_name);
}
}
}
columns.extend(columns_to_infer.keys().cloned());
Ok(columns)
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{
types::{Float32Type, Float64Type},
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use std::{sync::Arc, vec};
#[test]
fn test_infer_vector_columns() {
let schema = Arc::new(Schema::new(vec![
Field::new("f", DataType::Float32, false),
Field::new("s", DataType::Utf8, false),
Field::new(
"l1",
DataType::List(Arc::new(Field::new("item", DataType::Float32, true))),
false,
),
Field::new(
"l2",
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
),
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 32),
true,
),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
Arc::new(ListArray::from_iter_primitive::<Float32Type, _, _>(
(0..3).map(|_| Some(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)])),
)),
// Var-length list
Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(vec![Some(1.0_f64)]),
Some(vec![Some(2.0_f64), Some(3.0_f64)]),
Some(vec![Some(4.0_f64), Some(5.0_f64), Some(6.0_f64)]),
])),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
vec![
Some(vec![Some(1.0); 32]),
Some(vec![Some(2.0); 32]),
Some(vec![Some(3.0); 32]),
],
32,
),
),
],
)
.unwrap();
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
let cols = infer_vector_columns(reader, false).unwrap();
assert_eq!(cols, vec!["fl", "l1"]);
let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
let cols = infer_vector_columns(reader, true).unwrap();
assert_eq!(cols, vec!["fl"]);
}
}

View File

@@ -0,0 +1,284 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{iter::repeat_with, sync::Arc};
use arrow_array::{
cast::AsArray,
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
RecordBatchReader,
};
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use half::f16;
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
use log::warn;
use num_traits::cast::AsPrimitive;
use super::inspect::infer_dimension;
use crate::error::Result;
fn cast_array<I: ArrowNumericType, O: ArrowNumericType>(
arr: &PrimitiveArray<I>,
) -> Arc<PrimitiveArray<O>>
where
I::Native: AsPrimitive<O::Native>,
{
Arc::new(PrimitiveArray::<O>::from_iter_values(
arr.values().iter().map(|v| (*v).as_()),
))
}
fn cast_float_array<I: ArrowNumericType>(
arr: &PrimitiveArray<I>,
dt: &DataType,
) -> std::result::Result<Arc<dyn Array>, ArrowError>
where
I::Native: AsPrimitive<f64> + AsPrimitive<f32> + AsPrimitive<f16>,
{
match dt {
DataType::Float16 => Ok(cast_array::<I, Float16Type>(arr)),
DataType::Float32 => Ok(cast_array::<I, Float32Type>(arr)),
DataType::Float64 => Ok(cast_array::<I, Float64Type>(arr)),
_ => Err(ArrowError::SchemaError(format!(
"Incompatible change field: unable to coerce {:?} to {:?}",
arr.data_type(),
dt
))),
}
}
fn coerce_array(
array: &Arc<dyn Array>,
field: &Field,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
if array.data_type() == field.data_type() {
return Ok(array.clone());
}
match (array.data_type(), field.data_type()) {
// Normal cast-able types.
(adt, dt) if can_cast_types(adt, dt) => cast(&array, dt),
// Casting between f16/f32/f64 can be lossy.
(adt, dt) if (adt.is_floating() || dt.is_floating()) => {
if adt.byte_width() > dt.byte_width() {
warn!(
"Coercing field {} {:?} to {:?} might lose precision",
field.name(),
adt,
dt
);
}
match adt {
DataType::Float16 => cast_float_array(array.as_primitive::<Float16Type>(), dt),
DataType::Float32 => cast_float_array(array.as_primitive::<Float32Type>(), dt),
DataType::Float64 => cast_float_array(array.as_primitive::<Float64Type>(), dt),
_ => unreachable!(),
}
}
(adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt {
// Cast a float fixed size array with same dimension to the expected type.
DataType::FixedSizeList(_, dim) if dim == exp_dim => {
let actual_sub = array.as_fixed_size_list();
let values = coerce_array(actual_sub.values(), exp_field)?;
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
values.clone(),
*dim,
)?) as Arc<dyn Array>)
}
DataType::List(_) | DataType::LargeList(_) => {
let Some(dim) = (match adt {
DataType::List(_) => infer_dimension::<Int32Type>(array.as_list::<i32>())
.map_err(|e| {
ArrowError::SchemaError(format!(
"failed to infer dimension from list: {}",
e
))
})?
.map(|d| d as i64),
DataType::LargeList(_) => infer_dimension::<Int64Type>(array.as_list::<i64>())
.map_err(|e| {
ArrowError::SchemaError(format!(
"failed to infer dimension from large list: {}",
e
))
})?,
_ => unreachable!(),
}) else {
return Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
field,
array.data_type()
)));
};
if dim != *exp_dim as i64 {
return Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: expected dimension {} but got {}",
exp_dim, dim
)));
}
let values = coerce_array(array, exp_field)?;
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
values.clone(),
*exp_dim,
)?) as Arc<dyn Array>)
}
_ => Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
field,
array.data_type()
)))?,
},
_ => Err(ArrowError::SchemaError(format!(
"Incompatible change field {}: unable to coerce {:?} to {:?}",
field.name(),
array.data_type(),
field.data_type()
)))?,
}
}
fn coerce_schema_batch(
batch: RecordBatch,
schema: Arc<Schema>,
) -> std::result::Result<RecordBatch, ArrowError> {
if batch.schema() == schema {
return Ok(batch);
}
let columns = schema
.fields()
.iter()
.map(|field| {
batch
.column_by_name(field.name())
.ok_or_else(|| {
ArrowError::SchemaError(format!("Column {} not found", field.name()))
})
.and_then(|c| coerce_array(c, field))
})
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
RecordBatch::try_new(schema, columns)
}
/// Coerce the reader (input data) to match the given [Schema].
///
pub fn coerce_schema(
reader: impl RecordBatchReader + Send + 'static,
schema: Arc<Schema>,
) -> Result<Box<dyn RecordBatchReader + Send>> {
if reader.schema() == schema {
return Ok(Box::new(RecordBatchIterator::new(reader, schema)));
}
let s = schema.clone();
let batches = reader
.zip(repeat_with(move || s.clone()))
.map(|(batch, s)| coerce_schema_batch(batch?, s));
Ok(Box::new(RecordBatchIterator::new(batches, schema)))
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use arrow_array::{
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::Field;
use half::f16;
use lance::arrow::FixedSizeListArrayExt;
#[test]
fn test_coerce_list_to_fixed_size_list() {
let schema = Arc::new(Schema::new(vec![
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 64),
true,
),
Field::new("s", DataType::Utf8, true),
Field::new("f", DataType::Float16, true),
Field::new("i", DataType::Int32, true),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(
FixedSizeListArray::try_new_from_values(
Float32Array::from_iter_values((0..256).map(|v| v as f32)),
64,
)
.unwrap(),
),
Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some("from"),
Some("lance"),
])),
Arc::new(Float16Array::from_iter_values(
(0..4).map(|v| f16::from_f32(v as f32)),
)),
Arc::new(Int32Array::from_iter_values(0..4)),
],
)
.unwrap();
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
let expected_schema = Arc::new(Schema::new(vec![
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 64),
true,
),
Field::new("s", DataType::Utf8, true),
Field::new("f", DataType::Float64, true),
Field::new("i", DataType::Int8, true),
]));
let stream = coerce_schema(reader, expected_schema.clone()).unwrap();
let batches = stream.collect::<Vec<_>>();
assert_eq!(batches.len(), 1);
let batch = batches[0].as_ref().unwrap();
assert_eq!(batch.schema(), expected_schema);
let expected = RecordBatch::try_new(
expected_schema,
vec![
Arc::new(
FixedSizeListArray::try_new_from_values(
Float16Array::from_iter_values((0..256).map(|v| f16::from_f32(v as f32))),
64,
)
.unwrap(),
),
Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some("from"),
Some("lance"),
])),
Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))),
Arc::new(Int8Array::from_iter_values(0..4)),
],
)
.unwrap();
assert_eq!(batch, &expected);
}
}

View File

@@ -20,19 +20,32 @@ use lance::dataset::WriteParams;
 use lance::io::object_store::ObjectStore;
 use snafu::prelude::*;
 
-use crate::error::{CreateDirSnafu, InvalidTableNameSnafu, Result};
+use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
 use crate::table::{ReadParams, Table};
 
 pub const LANCE_FILE_EXTENSION: &str = "lance";
 
 pub struct Database {
     object_store: ObjectStore,
+    query_string: Option<String>,
 
     pub(crate) uri: String,
     pub(crate) base_path: object_store::path::Path,
 }
 
 const LANCE_EXTENSION: &str = "lance";
+const ENGINE: &str = "engine";
+
+/// Parse a url, if it's not a valid url, assume it's a local file
+/// and try to parse with file:// appended
+fn parse_url(url: &str) -> Result<url::Url> {
+    match url::Url::parse(url) {
+        Ok(url) => Ok(url),
+        Err(_) => url::Url::parse(format!("file://{}", url).as_str()).map_err(|e| Error::Lance {
+            message: format!("Failed to parse uri: {}", e),
+        }),
+    }
+}
 
 /// A connection to LanceDB
 impl Database {
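The `parse_url` helper added above falls back to a `file://` scheme when the input is a bare filesystem path rather than a valid URL. A minimal standalone sketch of that behaviour (it uses the `url` crate directly and a plain `url::ParseError` instead of the crate's own error type; the paths are placeholders):

```
use url::Url;

// Try to parse the string as a URL; if that fails, assume it is a local path
// and retry with an explicit file:// scheme, mirroring the fallback above.
fn parse_url(url: &str) -> Result<Url, url::ParseError> {
    Url::parse(url).or_else(|_| Url::parse(&format!("file://{}", url)))
}

fn main() {
    assert_eq!(parse_url("/tmp/db").unwrap().as_str(), "file:///tmp/db");
    assert_eq!(parse_url("s3://bucket/db").unwrap().scheme(), "s3");
}
```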
@@ -46,12 +59,71 @@ impl Database {
     ///
     /// * A [Database] object.
     pub async fn connect(uri: &str) -> Result<Database> {
-        let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
-        if object_store.is_local() {
-            Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
-        }
+        // For a native (using lance directly) connection
+        // the DB doesn't use any uri parameters, but lance does
+        // so we need to parse the uri, extract the query string, and propagate it to lance
+        let mut url = parse_url(uri)?;
+
+        // special handling for windows
+        if url.scheme().len() == 1 && cfg!(windows) {
+            let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
+            if object_store.is_local() {
+                Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
+            }
+            return Ok(Database {
+                uri: uri.to_string(),
+                query_string: None,
+                base_path,
+                object_store,
+            });
+        }
+
+        // iter thru the query params and extract the commit store param
+        let mut engine = None;
+        let mut filtered_querys = vec![];
+        // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
+        // THE API WILL CHANGE
+        for (key, value) in url.query_pairs() {
+            if key == ENGINE {
+                engine = Some(value.to_string());
+            } else {
+                // to owned so we can modify the url
+                filtered_querys.push((key.to_string(), value.to_string()));
+            }
+        }
+
+        // Filter out the commit store query param -- it's a lancedb param
+        url.query_pairs_mut().clear();
+        url.query_pairs_mut().extend_pairs(filtered_querys);
+        // Take a copy of the query string so we can propagate it to lance
+        let query_string = url.query().map(|s| s.to_string());
+        // clear the query string so we can use the url as the base uri
+        // use .set_query(None) instead of .set_query("") because the latter
+        // will add a trailing '?' to the url
+        url.set_query(None);
+
+        let table_base_uri = if let Some(store) = engine {
+            static WARN_ONCE: std::sync::Once = std::sync::Once::new();
+            WARN_ONCE.call_once(|| {
+                log::warn!("Specifying engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
+            });
+            let old_scheme = url.scheme().to_string();
+            let new_scheme = format!("{}+{}", old_scheme, store);
+            url.to_string().replacen(&old_scheme, &new_scheme, 1)
+        } else {
+            url.to_string()
+        };
+
+        let plain_uri = url.to_string();
+        let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
+        if object_store.is_local() {
+            Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
+        }
+
         Ok(Database {
-            uri: uri.to_string(),
+            uri: table_base_uri,
+            query_string,
             base_path,
             object_store,
         })
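The new `connect` body above does three things: it strips lancedb's own `engine` parameter out of the query string, keeps the remaining parameters so they can be forwarded to lance, and rewrites the scheme to `<scheme>+<engine>` when an engine is given. A standalone sketch of that URI rewriting, with a hypothetical bucket and parameter name (only the `url` crate is assumed):

```
use url::Url;

fn main() {
    // Hypothetical connection string; the bucket and commit-table names are placeholders.
    let mut url = Url::parse("s3://example-bucket?engine=ddb&commitTable=example-commits").unwrap();

    // Split lancedb's `engine` parameter from the parameters forwarded to lance.
    let mut engine = None;
    let mut forwarded = vec![];
    for (key, value) in url.query_pairs() {
        if key == "engine" {
            engine = Some(value.to_string());
        } else {
            forwarded.push((key.to_string(), value.to_string()));
        }
    }
    url.query_pairs_mut().clear().extend_pairs(forwarded);

    // Keep the remaining query string for later, then drop it from the base URI.
    let query_string = url.query().map(|s| s.to_string());
    url.set_query(None);

    // Rewrite the scheme so an external commit store is selected.
    let base_uri = match engine {
        Some(engine) => {
            let old_scheme = url.scheme().to_string();
            url.to_string()
                .replacen(&old_scheme, &format!("{}+{}", old_scheme, engine), 1)
        }
        None => url.to_string(),
    };

    assert_eq!(base_uri, "s3+ddb://example-bucket");
    assert_eq!(query_string.as_deref(), Some("commitTable=example-commits"));
}
```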
@@ -149,11 +221,19 @@ impl Database {
         let path = Path::new(&self.uri);
         let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
 
-        let uri = table_uri
+        let mut uri = table_uri
             .as_path()
             .to_str()
-            .context(InvalidTableNameSnafu { name })?;
-
-        Ok(uri.to_string())
+            .context(InvalidTableNameSnafu { name })?
+            .to_string();
+
+        // If there is a query string set on the connection, propagate it to lance
+        if let Some(query) = self.query_string.as_ref() {
+            uri.push('?');
+            uri.push_str(query.as_str());
+        }
+
+        Ok(uri)
     }
 }
@@ -170,7 +250,15 @@ mod tests {
         let uri = tmp_dir.path().to_str().unwrap();
         let db = Database::connect(uri).await.unwrap();
-        assert_eq!(db.uri, uri);
+        // file:// scheme should be automatically appended if not specified;
+        // windows paths come with a drive letter, so file:// won't be appended
+        let expected = if cfg!(windows) {
+            uri.to_string()
+        } else {
+            format!("file://{}", uri)
+        };
+        assert_eq!(db.uri, expected);
     }
 
     #[tokio::test]

View File

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use arrow_schema::ArrowError;
 use snafu::Snafu;
 
 #[derive(Debug, Snafu)]
@@ -32,10 +33,20 @@ pub enum Error {
     Store { message: String },
     #[snafu(display("LanceDBError: {message}"))]
     Lance { message: String },
+    #[snafu(display("LanceDB Schema Error: {message}"))]
+    Schema { message: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
 
+impl From<ArrowError> for Error {
+    fn from(e: ArrowError) -> Self {
+        Self::Lance {
+            message: e.to_string(),
+        }
+    }
+}
+
 impl From<lance::Error> for Error {
     fn from(e: lance::Error) -> Self {
         Self::Lance {
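With the `From<ArrowError>` conversion added above, Arrow's fallible calls can use `?` directly inside functions that return the crate's `Result`. A self-contained sketch of that pattern (the local `Error`/`Result` stand-ins exist only so the snippet compiles on its own and merely mirror the shape of the real types):

```
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{ArrowError, DataType, Field, Schema};

// Stand-ins mirroring the crate's error type, so the sketch is self-contained.
#[derive(Debug)]
enum Error {
    Lance { message: String },
}

impl From<ArrowError> for Error {
    fn from(e: ArrowError) -> Self {
        Self::Lance {
            message: e.to_string(),
        }
    }
}

type Result<T> = std::result::Result<T, Error>;

// RecordBatch::try_new returns Result<_, ArrowError>; `?` converts the error
// into Error::Lance through the From impl above.
fn single_column_batch(values: Vec<i32>) -> Result<RecordBatch> {
    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
    let column: ArrayRef = Arc::new(Int32Array::from(values));
    Ok(RecordBatch::try_new(schema, vec![column])?)
}

fn main() {
    assert!(single_column_batch(vec![1, 2, 3]).is_ok());
}
```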

View File

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+pub mod data;
 pub mod database;
 pub mod error;
 pub mod index;