mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
33 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d326146a40 | ||
|
|
693bca1eba | ||
|
|
343e274ea5 | ||
|
|
a695fb8030 | ||
|
|
bc8670d7af | ||
|
|
74004161ff | ||
|
|
34ddb1de6d | ||
|
|
1029fc9cb0 | ||
|
|
31c5df6d99 | ||
|
|
dbf37a0434 | ||
|
|
f20f19b804 | ||
|
|
55207ce844 | ||
|
|
c21f9cdda0 | ||
|
|
bc38abb781 | ||
|
|
731f86e44c | ||
|
|
31dad71c94 | ||
|
|
9585f550b3 | ||
|
|
8dc2315479 | ||
|
|
f6bfb5da11 | ||
|
|
661fcecf38 | ||
|
|
07fe284810 | ||
|
|
800bb691c3 | ||
|
|
ec24e09add | ||
|
|
0554db03b3 | ||
|
|
b315ea3978 | ||
|
|
aa7806cf0d | ||
|
|
6799613109 | ||
|
|
0f26915d22 | ||
|
|
32163063dc | ||
|
|
9a9a73a65d | ||
|
|
52fa7f5577 | ||
|
|
0cba0f4f92 | ||
|
|
8391ffee84 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.4
|
||||
current_version = 0.2.6
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
54
.github/workflows/node.yml
vendored
54
.github/workflows/node.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
- node/**
|
||||
- rust/ffi/node/**
|
||||
- .github/workflows/node.yml
|
||||
- docker-compose.yml
|
||||
|
||||
env:
|
||||
# Disable full debug symbol generation to speed up CI build and keep memory down
|
||||
@@ -107,3 +108,56 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
npm run test
|
||||
aws-integtest:
|
||||
timeout-minutes: 45
|
||||
runs-on: "ubuntu-22.04"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: node
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ACCESSKEY
|
||||
AWS_SECRET_ACCESS_KEY: SECRETKEY
|
||||
AWS_DEFAULT_REGION: us-west-2
|
||||
# this one is for s3
|
||||
AWS_ENDPOINT: http://localhost:4566
|
||||
# this one is for dynamodb
|
||||
DYNAMODB_ENDPOINT: http://localhost:4566
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 18
|
||||
cache: 'npm'
|
||||
cache-dependency-path: node/package-lock.json
|
||||
- name: start local stack
|
||||
run: docker compose -f ../docker-compose.yml up -d --wait
|
||||
- name: create s3
|
||||
run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
|
||||
- name: create ddb
|
||||
run: |
|
||||
aws dynamodb create-table \
|
||||
--table-name lancedb-integtest \
|
||||
--attribute-definitions '[{"AttributeName": "base_uri", "AttributeType": "S"}, {"AttributeName": "version", "AttributeType": "N"}]' \
|
||||
--key-schema '[{"AttributeName": "base_uri", "KeyType": "HASH"}, {"AttributeName": "version", "KeyType": "RANGE"}]' \
|
||||
--provisioned-throughput '{"ReadCapacityUnits": 10, "WriteCapacityUnits": 10}' \
|
||||
--endpoint-url $DYNAMODB_ENDPOINT
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- name: Build
|
||||
run: |
|
||||
npm ci
|
||||
npm run tsc
|
||||
npm run build
|
||||
npm run pack-build
|
||||
npm install --no-save ./dist/lancedb-vectordb-*.tgz
|
||||
# Remove index.node to test with dependency installed
|
||||
rm index.node
|
||||
- name: Test
|
||||
run: npm run integration-test
|
||||
|
||||
34
.github/workflows/python.yml
vendored
34
.github/workflows/python.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: isort
|
||||
run: isort --check --diff --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
run: pytest --doctest-modules lancedb
|
||||
mac:
|
||||
@@ -65,4 +65,34 @@ jobs:
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
pydantic1x:
|
||||
timeout-minutes: 30
|
||||
runs-on: "ubuntu-22.04"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: python
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Install lancedb
|
||||
run: |
|
||||
pip install "pydantic<2"
|
||||
pip install -e .[tests]
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock black isort
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: isort
|
||||
run: isort --check --diff --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
run: pytest --doctest-modules lancedb
|
||||
23
Cargo.toml
23
Cargo.toml
@@ -1,16 +1,25 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"rust/vectordb",
|
||||
"rust/ffi/node"
|
||||
]
|
||||
members = ["rust/ffi/node", "rust/vectordb"]
|
||||
# Python package needs to be built by maturin.
|
||||
exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = "=0.6.5"
|
||||
lance = { "version" = "=0.7.5", "features" = ["dynamodb"] }
|
||||
lance-linalg = { "version" = "=0.7.5" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "43.0.0", optional = false }
|
||||
arrow-array = "43.0"
|
||||
arrow-data = "43.0"
|
||||
arrow-schema = "43.0"
|
||||
arrow-ipc = "43.0"
|
||||
half = { "version" = "=2.2.1", default-features = false }
|
||||
arrow-ord = "43.0"
|
||||
arrow-schema = "43.0"
|
||||
arrow-arith = "43.0"
|
||||
arrow-cast = "43.0"
|
||||
half = { "version" = "=2.2.1", default-features = false, features = [
|
||||
"num-traits"
|
||||
] }
|
||||
log = "0.4"
|
||||
object_store = "0.6.1"
|
||||
snafu = "0.7.4"
|
||||
url = "2"
|
||||
|
||||
18
docker-compose.yml
Normal file
18
docker-compose.yml
Normal file
@@ -0,0 +1,18 @@
|
||||
version: "3.9"
|
||||
services:
|
||||
localstack:
|
||||
image: localstack/localstack:0.14
|
||||
ports:
|
||||
- 4566:4566
|
||||
environment:
|
||||
- SERVICES=s3,dynamodb
|
||||
- DEBUG=1
|
||||
- LS_LOG=trace
|
||||
- DOCKER_HOST=unix:///var/run/docker.sock
|
||||
- AWS_ACCESS_KEY_ID=ACCESSKEY
|
||||
- AWS_SECRET_ACCESS_KEY=SECRETKEY
|
||||
healthcheck:
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
|
||||
interval: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
@@ -67,6 +67,11 @@ nav:
|
||||
- Home:
|
||||
- 🏢 Home: index.md
|
||||
- 💡 Basics: basic.md
|
||||
- 📚 Guides:
|
||||
- Tables: guides/tables.md
|
||||
- Vector Search: search.md
|
||||
- SQL filters: sql.md
|
||||
- Indexing: ann_indexes.md
|
||||
- 🧬 Embeddings: embedding.md
|
||||
- 🔍 Python full-text search: fts.md
|
||||
- 🔌 Integrations:
|
||||
@@ -91,12 +96,12 @@ nav:
|
||||
- Serverless Website Chatbot: examples/serverless_website_chatbot.md
|
||||
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
|
||||
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
|
||||
- 📚 Guides:
|
||||
- Tables: guides/tables.md
|
||||
- Vector Search: search.md
|
||||
- SQL filters: sql.md
|
||||
- Indexing: ann_indexes.md
|
||||
- Basics: basic.md
|
||||
- Guides:
|
||||
- Tables: guides/tables.md
|
||||
- Vector Search: search.md
|
||||
- SQL filters: sql.md
|
||||
- Indexing: ann_indexes.md
|
||||
- Embeddings: embedding.md
|
||||
- Python full-text search: fts.md
|
||||
- Integrations:
|
||||
@@ -121,12 +126,6 @@ nav:
|
||||
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
|
||||
- Serverless Chatbot from any website: examples/serverless_website_chatbot.md
|
||||
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
|
||||
|
||||
- Guides:
|
||||
- Tables: guides/tables.md
|
||||
- Vector Search: search.md
|
||||
- SQL filters: sql.md
|
||||
- Indexing: ann_indexes.md
|
||||
- API references:
|
||||
- Python API: python/python.md
|
||||
- Javascript API: javascript/modules.md
|
||||
|
||||
@@ -49,11 +49,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
|
||||
db.create_table("table2", data)
|
||||
|
||||
db["table2"].head()
|
||||
db["table2"].head()
|
||||
```
|
||||
!!! info "Note"
|
||||
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
|
||||
|
||||
|
||||
```python
|
||||
custom_schema = pa.schema([
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
@@ -66,7 +66,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
|
||||
### From PyArrow Tables
|
||||
You can also create LanceDB tables directly from pyarrow tables
|
||||
|
||||
|
||||
```python
|
||||
table = pa.Table.from_arrays(
|
||||
[
|
||||
@@ -84,18 +84,28 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
```
|
||||
|
||||
### From Pydantic Models
|
||||
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method.
|
||||
When you create an empty table without data, you must specify the table schema.
|
||||
LanceDB supports creating tables by specifying a pyarrow schema or a specialized
|
||||
pydantic model called `LanceModel`.
|
||||
|
||||
For example, the following Content model specifies a table with 5 columns:
|
||||
movie_id, vector, genres, title, and imdb_id. When you create a table, you can
|
||||
pass the class as the value of the `schema` parameter to `create_table`.
|
||||
The `vector` column is a `Vector` type, which is a specialized pydantic type that
|
||||
can be configured with the vector dimensions. It is also important to note that
|
||||
LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
|
||||
(which itself derives from `pydantic.BaseModel`).
|
||||
|
||||
```python
|
||||
from lancedb.pydantic import vector, LanceModel
|
||||
from lancedb.pydantic import Vector, LanceModel
|
||||
|
||||
class Content(LanceModel):
|
||||
movie_id: int
|
||||
vector: vector(128)
|
||||
vector: Vector(128)
|
||||
genres: str
|
||||
title: str
|
||||
imdb_id: int
|
||||
|
||||
|
||||
@property
|
||||
def imdb_url(self) -> str:
|
||||
return f"https://www.imdb.com/title/tt{self.imdb_id}"
|
||||
@@ -103,7 +113,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
import pyarrow as pa
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
table_name = "movielens_small"
|
||||
table = db.create_table(table_name, schema=Content.to_arrow_schema())
|
||||
table = db.create_table(table_name, schema=Content)
|
||||
```
|
||||
|
||||
### Using Iterators / Writing Large Datasets
|
||||
@@ -113,7 +123,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types.
|
||||
|
||||
Here's an example using using `RecordBatch` iterator for creating tables.
|
||||
|
||||
|
||||
```python
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -142,11 +152,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
|
||||
## Creating Empty Table
|
||||
You can also create empty tables in python. Initialize it with schema and later ingest data into it.
|
||||
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
@@ -168,8 +178,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
|
||||
class Model(LanceModel):
|
||||
vector: vector(2)
|
||||
|
||||
vector: Vector(2)
|
||||
|
||||
tbl = db.create_table("table5", schema=Model.to_arrow_schema())
|
||||
```
|
||||
|
||||
@@ -249,7 +259,7 @@ After a table has been created, you can always add more data to it using
|
||||
You can also add a large dataset batch in one go using Iterator of any supported data types.
|
||||
|
||||
### Adding to table using Iterator
|
||||
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
@@ -261,10 +271,10 @@ After a table has been created, you can always add more data to it using
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
})
|
||||
|
||||
|
||||
tbl.add(make_batches())
|
||||
```
|
||||
|
||||
|
||||
The other arguments accepted:
|
||||
|
||||
| Name | Type | Description | Default |
|
||||
@@ -274,7 +284,7 @@ After a table has been created, you can always add more data to it using
|
||||
| on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop |
|
||||
| fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 |
|
||||
|
||||
|
||||
|
||||
=== "Javascript/Typescript"
|
||||
|
||||
```javascript
|
||||
@@ -312,7 +322,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
|
||||
# x vector
|
||||
# 0 1 [1.0, 2.0]
|
||||
# 1 3 [5.0, 6.0]
|
||||
```
|
||||
```
|
||||
|
||||
### Delete from a list of values
|
||||
|
||||
@@ -325,7 +335,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
|
||||
# x vector
|
||||
# 0 3 [5.0, 6.0]
|
||||
```
|
||||
|
||||
|
||||
=== "Javascript/Typescript"
|
||||
|
||||
```javascript
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# LanceDB
|
||||
|
||||
LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings.
|
||||
LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -249,11 +249,11 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from lancedb.pydantic import vector, LanceModel\n",
|
||||
"from lancedb.pydantic import Vector, LanceModel\n",
|
||||
"\n",
|
||||
"class Content(LanceModel):\n",
|
||||
" movie_id: int\n",
|
||||
" vector: vector(128)\n",
|
||||
" vector: Vector(128)\n",
|
||||
" genres: str\n",
|
||||
" title: str\n",
|
||||
" imdb_id: int\n",
|
||||
@@ -359,7 +359,7 @@
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"class PydanticSchema(LanceModel):\n",
|
||||
" vector: vector(2)\n",
|
||||
" vector: Vector(2)\n",
|
||||
" item: str\n",
|
||||
" price: float\n",
|
||||
"\n",
|
||||
@@ -394,10 +394,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import lancedb\n",
|
||||
"from lancedb.pydantic import LanceModel, vector\n",
|
||||
"from lancedb.pydantic import LanceModel, Vector\n",
|
||||
"\n",
|
||||
"class Model(LanceModel):\n",
|
||||
" vector: vector(2)\n",
|
||||
" vector: Vector(2)\n",
|
||||
"\n",
|
||||
"tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())"
|
||||
]
|
||||
|
||||
@@ -13,10 +13,10 @@ via [pydantic_to_schema()](python.md##lancedb.pydantic.pydantic_to_schema) metho
|
||||
|
||||
## Vector Field
|
||||
|
||||
LanceDB provides a [`vector(dim)`](python.md#lancedb.pydantic.vector) method to define a
|
||||
LanceDB provides a [`Vector(dim)`](python.md#lancedb.pydantic.Vector) method to define a
|
||||
vector Field in a Pydantic Model.
|
||||
|
||||
::: lancedb.pydantic.vector
|
||||
::: lancedb.pydantic.Vector
|
||||
|
||||
## Type Conversion
|
||||
|
||||
@@ -33,4 +33,4 @@ Current supported type conversions:
|
||||
| `str` | `pyarrow.utf8()` |
|
||||
| `list` | `pyarrow.List` |
|
||||
| `BaseModel` | `pyarrow.Struct` |
|
||||
| `vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
|
||||
| `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
|
||||
|
||||
@@ -26,9 +26,19 @@ pip install lancedb
|
||||
|
||||
## Embeddings
|
||||
|
||||
::: lancedb.embeddings.with_embeddings
|
||||
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
||||
|
||||
::: lancedb.embeddings.EmbeddingFunction
|
||||
::: lancedb.embeddings.functions.EmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.functions.TextEmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
|
||||
|
||||
::: lancedb.embeddings.functions.OpenAIEmbeddings
|
||||
|
||||
::: lancedb.embeddings.functions.OpenClipEmbeddings
|
||||
|
||||
::: lancedb.embeddings.with_embeddings
|
||||
|
||||
## Context
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ excluded_globs = [
|
||||
"../src/embedding.md",
|
||||
"../src/examples/*.md",
|
||||
"../src/integrations/voxel51.md",
|
||||
"../src/guides/tables.md"
|
||||
"../src/guides/tables.md",
|
||||
"../src/python/duckdb.md",
|
||||
]
|
||||
|
||||
python_prefix = "py"
|
||||
|
||||
105
node/package-lock.json
generated
105
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.2.4",
|
||||
"version": "0.2.6",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.2.4",
|
||||
"version": "0.2.6",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -31,6 +31,7 @@
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/sinon": "^10.0.15",
|
||||
"@types/temp": "^0.9.1",
|
||||
"@types/uuid": "^9.0.3",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
"cargo-cp-artifact": "^0.1",
|
||||
"chai": "^4.3.7",
|
||||
@@ -48,14 +49,15 @@
|
||||
"ts-node-dev": "^2.0.0",
|
||||
"typedoc": "^0.24.7",
|
||||
"typedoc-plugin-markdown": "^3.15.3",
|
||||
"typescript": "*"
|
||||
"typescript": "*",
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.4",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.4",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.4"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
||||
}
|
||||
},
|
||||
"node_modules/@apache-arrow/ts": {
|
||||
@@ -315,9 +317,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz",
|
||||
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
|
||||
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -327,9 +329,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz",
|
||||
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
|
||||
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -339,9 +341,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -351,9 +353,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -363,9 +365,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz",
|
||||
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
|
||||
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -596,6 +598,12 @@
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/uuid": {
|
||||
"version": "9.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
|
||||
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@typescript-eslint/eslint-plugin": {
|
||||
"version": "5.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
|
||||
@@ -4451,6 +4459,15 @@
|
||||
"punycode": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/uuid": {
|
||||
"version": "9.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
|
||||
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
|
||||
"dev": true,
|
||||
"bin": {
|
||||
"uuid": "dist/bin/uuid"
|
||||
}
|
||||
},
|
||||
"node_modules/v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
@@ -4852,33 +4869,33 @@
|
||||
}
|
||||
},
|
||||
"@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz",
|
||||
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
|
||||
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz",
|
||||
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
|
||||
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz",
|
||||
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
|
||||
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@neon-rs/cli": {
|
||||
@@ -5093,6 +5110,12 @@
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"@types/uuid": {
|
||||
"version": "9.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
|
||||
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
|
||||
"dev": true
|
||||
},
|
||||
"@typescript-eslint/eslint-plugin": {
|
||||
"version": "5.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
|
||||
@@ -7844,6 +7867,12 @@
|
||||
"punycode": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"uuid": {
|
||||
"version": "9.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
|
||||
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
|
||||
"dev": true
|
||||
},
|
||||
"v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.2.4",
|
||||
"version": "0.2.6",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -9,6 +9,7 @@
|
||||
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
|
||||
"build-release": "npm run build -- --release",
|
||||
"test": "npm run tsc && mocha -recursive dist/test",
|
||||
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
||||
"lint": "eslint native.js src --ext .js,.ts",
|
||||
"clean": "rm -rf node_modules *.node dist/",
|
||||
"pack-build": "neon pack-build",
|
||||
@@ -34,6 +35,7 @@
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/sinon": "^10.0.15",
|
||||
"@types/temp": "^0.9.1",
|
||||
"@types/uuid": "^9.0.3",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
"cargo-cp-artifact": "^0.1",
|
||||
"chai": "^4.3.7",
|
||||
@@ -51,7 +53,8 @@
|
||||
"ts-node-dev": "^2.0.0",
|
||||
"typedoc": "^0.24.7",
|
||||
"typedoc-plugin-markdown": "^3.15.3",
|
||||
"typescript": "*"
|
||||
"typescript": "*",
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@apache-arrow/ts": "^12.0.0",
|
||||
@@ -78,10 +81,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.4",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.4",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.4"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
||||
}
|
||||
}
|
||||
|
||||
43
node/src/integration_test/test.ts
Normal file
43
node/src/integration_test/test.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import * as chai from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
|
||||
const assert = chai.assert
|
||||
chai.use(chaiAsPromised)
|
||||
|
||||
describe('LanceDB AWS Integration test', function () {
|
||||
it('s3+ddb schema is processed correctly', async function () {
|
||||
this.timeout(15000)
|
||||
|
||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||
// THE API WILL CHANGE
|
||||
const conn = await lancedb.connect('s3://lancedb-integtest?engine=ddb&ddbTableName=lancedb-integtest')
|
||||
const data = [{ vector: Array(128).fill(1.0) }]
|
||||
|
||||
const tableName = uuidv4()
|
||||
let table = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
|
||||
|
||||
const futs = [table.add(data), table.add(data), table.add(data), table.add(data), table.add(data)]
|
||||
await Promise.allSettled(futs)
|
||||
|
||||
table = await conn.openTable(tableName)
|
||||
assert.equal(await table.countRows(), 6)
|
||||
})
|
||||
})
|
||||
@@ -19,7 +19,7 @@ import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
|
||||
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow'
|
||||
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
|
||||
|
||||
const expect = chai.expect
|
||||
const assert = chai.assert
|
||||
@@ -258,6 +258,36 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
})
|
||||
|
||||
describe('when searching an empty dataset', function () {
|
||||
it('should not fail', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||
assert.isEmpty(result)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when searching an empty-after-delete dataset', function () {
|
||||
it('should not fail', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
await table.add([{ vector: Array(128).fill(0.1) }])
|
||||
await table.delete('vector IS NOT NULL')
|
||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||
assert.isEmpty(result)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when creating a vector index', function () {
|
||||
it('overwrite all records in a table', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.2
|
||||
current_version = 0.2.6
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -11,12 +11,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib.metadata
|
||||
from typing import Optional
|
||||
|
||||
from .db import URI, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
|
||||
def connect(
|
||||
uri: URI,
|
||||
@@ -31,9 +34,13 @@ def connect(
|
||||
----------
|
||||
uri: str or Path
|
||||
The uri of the database.
|
||||
api_token: str, optional
|
||||
api_key: str, optional
|
||||
If presented, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
region: str, default "us-west-2"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
|
||||
|
||||
# import lancedb so we don't have to in every example
|
||||
|
||||
|
||||
@@ -14,3 +17,24 @@ def doctest_setup(monkeypatch, tmpdir):
|
||||
monkeypatch.setitem(os.environ, "COLUMNS", "80")
|
||||
# Work in a temporary directory
|
||||
monkeypatch.chdir(tmpdir)
|
||||
|
||||
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
|
||||
@registry.register("test")
|
||||
class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
Return the hash of the first 10 characters
|
||||
"""
|
||||
|
||||
def generate_embeddings(self, texts):
|
||||
return [self._compute_one_embedding(row) for row in texts]
|
||||
|
||||
def _compute_one_embedding(self, row):
|
||||
emb = np.array([float(hash(c)) for c in row[:10]])
|
||||
emb /= np.linalg.norm(emb)
|
||||
return emb
|
||||
|
||||
def ndims(self):
|
||||
return 10
|
||||
|
||||
@@ -16,12 +16,13 @@ from __future__ import annotations
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||
@@ -40,7 +41,7 @@ class DBConnection(ABC):
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[pa.Schema, LanceModel] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
@@ -285,10 +286,11 @@ class LanceDBConnection(DBConnection):
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[pa.Schema, LanceModel] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
@@ -307,6 +309,7 @@ class LanceDBConnection(DBConnection):
|
||||
mode=mode,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
)
|
||||
return tbl
|
||||
|
||||
|
||||
24
python/lancedb/embeddings/__init__.py
Normal file
24
python/lancedb/embeddings/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from .functions import (
|
||||
EmbeddingFunction,
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
OpenAIEmbeddings,
|
||||
OpenClipEmbeddings,
|
||||
SentenceTransformerEmbeddings,
|
||||
TextEmbeddingFunction,
|
||||
)
|
||||
from .utils import with_embeddings
|
||||
578
python/lancedb/embeddings/functions.py
Normal file
578
python/lancedb/embeddings/functions.py
Normal file
@@ -0,0 +1,578 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import urllib.error
|
||||
import urllib.parse as urlparse
|
||||
import urllib.request
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
This is a singleton class used to register embedding functions
|
||||
and fetch them by name. It also handles serializing and deserializing.
|
||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||
or TextEmbeddingFunction and registering it with the registry.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||
>>> @registry.register("my-embedding-function")
|
||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||
... def ndims(self) -> int:
|
||||
... return 128
|
||||
...
|
||||
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
...
|
||||
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||
...
|
||||
>>> registry.get("my-embedding-function")
|
||||
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
return __REGISTRY__
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
|
||||
def register(self, alias: str = None):
|
||||
"""
|
||||
This creates a decorator that can be used to register
|
||||
an EmbeddingFunction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alias : Optional[str]
|
||||
a human friendly name for the embedding function. If not
|
||||
provided, the class name will be used.
|
||||
"""
|
||||
|
||||
# This is a decorator for a class that inherits from BaseModel
|
||||
# It adds the class to the registry
|
||||
def decorator(cls):
|
||||
if not issubclass(cls, EmbeddingFunction):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||
if cls.__name__ in self._functions:
|
||||
raise KeyError(f"{cls.__name__} was already registered")
|
||||
key = alias or cls.__name__
|
||||
self._functions[key] = cls
|
||||
cls.__embedding_function_registry_alias__ = alias
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the registry to its initial state
|
||||
"""
|
||||
self._functions = {}
|
||||
|
||||
def get(self, name: str):
|
||||
"""
|
||||
Fetch an embedding function class by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the embedding function to fetch
|
||||
Either the alias or the class name if no alias was provided
|
||||
during registration
|
||||
"""
|
||||
return self._functions[name]
|
||||
|
||||
def parse_functions(
|
||||
self, metadata: Optional[Dict[bytes, bytes]]
|
||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the metadata from an arrow table and
|
||||
return a mapping of the vector column to the
|
||||
embedding function and source column
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : Optional[Dict[bytes, bytes]]
|
||||
The metadata from an arrow table. Note that
|
||||
the keys and values are bytes (pyarrow api)
|
||||
|
||||
Returns
|
||||
-------
|
||||
functions : dict
|
||||
A mapping of vector column name to embedding function.
|
||||
An empty dict is returned if input is None or does not
|
||||
contain b"embedding_functions".
|
||||
"""
|
||||
if metadata is None or b"embedding_functions" not in metadata:
|
||||
return {}
|
||||
serialized = metadata[b"embedding_functions"]
|
||||
raw_list = json.loads(serialized.decode("utf-8"))
|
||||
return {
|
||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||
vector_column=obj["vector_column"],
|
||||
source_column=obj["source_column"],
|
||||
function=self.get(obj["name"])(**obj["model"]),
|
||||
)
|
||||
for obj in raw_list
|
||||
}
|
||||
|
||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||
"""
|
||||
Convert the given embedding function and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
func = conf.function
|
||||
name = getattr(
|
||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||
)
|
||||
json_data = func.safe_model_dump()
|
||||
return {
|
||||
"name": name,
|
||||
"model": json_data,
|
||||
"source_column": conf.source_column,
|
||||
"vector_column": conf.vector_column,
|
||||
}
|
||||
|
||||
def get_table_metadata(self, func_list):
|
||||
"""
|
||||
Convert a list of embedding functions and source / vector configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
if func_list is None or len(func_list) == 0:
|
||||
return None
|
||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||
# Note that metadata dictionary values must be bytes
|
||||
# so we need to json dump then utf8 encode
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
|
||||
# Global instance
|
||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
|
||||
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
An ABC for embedding functions.
|
||||
|
||||
All concrete embedding functions must implement the following:
|
||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||
For text data, the two will be the same. For multi-modal data, the source column
|
||||
might be images and the vector column might be text.
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Create an instance of the embedding function
|
||||
"""
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
"""
|
||||
pass
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
elif isinstance(texts, pa.Array):
|
||||
texts = texts.to_pylist()
|
||||
elif isinstance(texts, pa.ChunkedArray):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return dict(self)
|
||||
return self.model_dump()
|
||||
|
||||
@abstractmethod
|
||||
def ndims(self):
|
||||
"""
|
||||
Return the dimensions of the vector column
|
||||
"""
|
||||
pass
|
||||
|
||||
def SourceField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the source column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||
|
||||
def VectorField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the target vector column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
This model encapsulates the configuration for a embedding function
|
||||
in a lancedb table. It holds the embedding function, the source column,
|
||||
and the vector column
|
||||
"""
|
||||
|
||||
vector_column: str
|
||||
source_column: str
|
||||
function: EmbeddingFunction
|
||||
|
||||
|
||||
class TextEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
return self.embedding_model.encode(
|
||||
list(texts),
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the model to load
|
||||
device : str
|
||||
The device to load the model on
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = cls.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
|
||||
|
||||
@register("openai")
|
||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenAI API
|
||||
|
||||
https://platform.openai.com/docs/guides/embeddings
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
openai = self.safe_import("openai")
|
||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||
return [v["embedding"] for v in rs]
|
||||
|
||||
|
||||
@register("open-clip")
|
||||
class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenClip API
|
||||
For multi-modal text-to-image search
|
||||
|
||||
https://github.com/mlfoundations/open_clip
|
||||
"""
|
||||
|
||||
name: str = "ViT-B-32"
|
||||
pretrained: str = "laion2b_s34b_b79k"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
_model = PrivateAttr()
|
||||
_preprocess = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
self._ndims = None
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
with torch.no_grad():
|
||||
text_features = self._model.encode_text(text.to(self.device))
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given images
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in tqdm(futures)]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
URL to download from
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url) as conn:
|
||||
return conn.read()
|
||||
except (socket.gaierror, urllib.error.URLError) as err:
|
||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -20,7 +20,7 @@ import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
from retry import retry
|
||||
|
||||
from .util import safe_import_pandas
|
||||
from ..util import safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
@@ -58,7 +58,7 @@ def with_embeddings(
|
||||
pa.Table
|
||||
The input table with a new column called "vector" containing the embeddings.
|
||||
"""
|
||||
func = EmbeddingFunction(func)
|
||||
func = FunctionWrapper(func)
|
||||
if wrap_api:
|
||||
func = func.retry().rate_limit()
|
||||
func = func.batch_size(batch_size)
|
||||
@@ -71,7 +71,11 @@ def with_embeddings(
|
||||
return data.append_column("vector", table["vector"])
|
||||
|
||||
|
||||
class EmbeddingFunction:
|
||||
class FunctionWrapper:
|
||||
"""
|
||||
A wrapper for embedding functions that adds rate limiting, retries, and batching.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable):
|
||||
self.func = func
|
||||
self.rate_limiter_kwargs = {}
|
||||
@@ -26,6 +26,8 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
@@ -46,7 +48,19 @@ class FixedSizeListMixin(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def vector(
|
||||
def vector(dim: int, value_type: pa.DataType = pa.float32()):
|
||||
# TODO: remove in future release
|
||||
from warnings import warn
|
||||
|
||||
warn(
|
||||
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
|
||||
"This function will be removed in future release",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return Vector(dim, value_type)
|
||||
|
||||
|
||||
def Vector(
|
||||
dim: int, value_type: pa.DataType = pa.float32()
|
||||
) -> Type[FixedSizeListMixin]:
|
||||
"""Pydantic Vector Type.
|
||||
@@ -65,12 +79,12 @@ def vector(
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import vector
|
||||
>>> from lancedb.pydantic import Vector
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... id: int
|
||||
... url: str
|
||||
... embeddings: vector(768)
|
||||
... embeddings: Vector(768)
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
@@ -114,7 +128,7 @@ def vector(
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
|
||||
raise TypeError("A list of numbers or numpy.ndarray is needed")
|
||||
return v
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@@ -224,27 +238,18 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
|
||||
>>> from typing import List, Optional
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import pydantic_to_schema
|
||||
...
|
||||
>>> class InnerModel(pydantic.BaseModel):
|
||||
... a: str
|
||||
... b: Optional[float]
|
||||
>>>
|
||||
>>> class FooModel(pydantic.BaseModel):
|
||||
... id: int
|
||||
... s: Optional[str] = None
|
||||
... s: str
|
||||
... vec: List[float]
|
||||
... li: List[int]
|
||||
... inner: InnerModel
|
||||
...
|
||||
>>> schema = pydantic_to_schema(FooModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
... pa.field("s", pa.utf8(), True),
|
||||
... pa.field("s", pa.utf8(), False),
|
||||
... pa.field("vec", pa.list_(pa.float64()), False),
|
||||
... pa.field("li", pa.list_(pa.int64()), False),
|
||||
... pa.field("inner", pa.struct([
|
||||
... pa.field("a", pa.utf8(), False),
|
||||
... pa.field("b", pa.float64(), True),
|
||||
... ]), False),
|
||||
... ])
|
||||
"""
|
||||
fields = _pydantic_model_to_fields(model)
|
||||
@@ -258,11 +263,11 @@ class LanceModel(pydantic.BaseModel):
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> from lancedb.pydantic import LanceModel, vector
|
||||
>>> from lancedb.pydantic import LanceModel, Vector
|
||||
>>>
|
||||
>>> class TestModel(LanceModel):
|
||||
... name: str
|
||||
... vector: vector(2)
|
||||
... vector: Vector(2)
|
||||
...
|
||||
>>> db = lancedb.connect("/tmp")
|
||||
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
||||
@@ -278,13 +283,58 @@ class LanceModel(pydantic.BaseModel):
|
||||
"""
|
||||
Get the Arrow Schema for this model.
|
||||
"""
|
||||
return pydantic_to_schema(cls)
|
||||
schema = pydantic_to_schema(cls)
|
||||
functions = cls.parse_embedding_functions()
|
||||
if len(functions) > 0:
|
||||
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
|
||||
functions
|
||||
)
|
||||
schema = schema.with_metadata(metadata)
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def field_names(cls) -> List[str]:
|
||||
"""
|
||||
Get the field names of this model.
|
||||
"""
|
||||
return list(cls.safe_get_fields().keys())
|
||||
|
||||
@classmethod
|
||||
def safe_get_fields(cls):
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return list(cls.__fields__.keys())
|
||||
return list(cls.model_fields.keys())
|
||||
return cls.__fields__
|
||||
return cls.model_fields
|
||||
|
||||
@classmethod
|
||||
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the embedding functions from this model.
|
||||
"""
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
vec_and_function = []
|
||||
for name, field_info in cls.safe_get_fields().items():
|
||||
func = get_extras(field_info, "vector_column_for")
|
||||
if func is not None:
|
||||
vec_and_function.append([name, func])
|
||||
|
||||
configs = []
|
||||
for vec, func in vec_and_function:
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
src_func = get_extras(field_info, "source_column_for")
|
||||
if src_func == func:
|
||||
configs.append(
|
||||
EmbeddingFunctionConfig(
|
||||
source_column=source, vector_column=vec, function=func
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
if PYDANTIC_VERSION.major >= 2:
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -37,6 +38,9 @@ class Query(pydantic.BaseModel):
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# if True then apply the filter before vector search
|
||||
prefilter: bool = False
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
@@ -54,44 +58,112 @@ class Query(pydantic.BaseModel):
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
class LanceQueryBuilder:
|
||||
"""
|
||||
A builder for nearest neighbor queries for LanceDB.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [{"vector": [1.1, 1.2], "b": 2},
|
||||
... {"vector": [0.5, 1.3], "b": 4},
|
||||
... {"vector": [0.4, 0.4], "b": 6},
|
||||
... {"vector": [0.4, 0.4], "b": 10}]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data=data)
|
||||
>>> (table.search([0.4, 0.4])
|
||||
... .metric("cosine")
|
||||
... .where("b < 10")
|
||||
... .select(["b"])
|
||||
... .limit(2)
|
||||
... .to_df())
|
||||
b vector _distance
|
||||
0 6 [0.4, 0.4] 0.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
class LanceQueryBuilder(ABC):
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, str],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# convert "auto" query_type to "vector" or "fts"
|
||||
# and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(table, query)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query, dtype=np.float32)
|
||||
elif isinstance(query, np.ndarray):
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
|
||||
return LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
@classmethod
|
||||
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
||||
# If query_type is fts, then query must be a string.
|
||||
# otherwise raise TypeError
|
||||
if query_type == "fts":
|
||||
if not isinstance(query, str):
|
||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
return query, query_type
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||
)
|
||||
|
||||
def __init__(self, table: "lancedb.table.Table"):
|
||||
self._table = table
|
||||
self._query = query
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._vector_column = vector_column
|
||||
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
|
||||
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vectors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[LanceModel]
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
@@ -125,7 +197,7 @@ class LanceQueryBuilder:
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where: str) -> LanceQueryBuilder:
|
||||
def where(self, where) -> LanceQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
@@ -141,7 +213,45 @@ class LanceQueryBuilder:
|
||||
self._where = where
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
A builder for nearest neighbor queries for LanceDB.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [{"vector": [1.1, 1.2], "b": 2},
|
||||
... {"vector": [0.5, 1.3], "b": 4},
|
||||
... {"vector": [0.4, 0.4], "b": 6},
|
||||
... {"vector": [0.4, 0.4], "b": 10}]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data=data)
|
||||
>>> (table.search([0.4, 0.4])
|
||||
... .metric("cosine")
|
||||
... .where("b < 10")
|
||||
... .select(["b"])
|
||||
... .limit(2)
|
||||
... .to_df())
|
||||
b vector _distance
|
||||
0 6 [0.4, 0.4] 0.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
Parameters
|
||||
@@ -151,13 +261,13 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._metric = metric
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
|
||||
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of probes to use.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
@@ -173,13 +283,13 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
|
||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||
|
||||
As an example, a refine factor of 2 will sample 2x as many vectors as
|
||||
@@ -195,22 +305,12 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
@@ -224,6 +324,7 @@ class LanceQueryBuilder:
|
||||
query = Query(
|
||||
vector=vector,
|
||||
filter=self._where,
|
||||
prefilter=self._prefilter,
|
||||
k=self._limit,
|
||||
metric=self._metric,
|
||||
columns=self._columns,
|
||||
@@ -233,25 +334,36 @@ class LanceQueryBuilder:
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
where: str
|
||||
The where clause.
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||
without warning in the future. Currently this is only supported
|
||||
in OSS and can only be used with a table that does not have an ANN
|
||||
index.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[LanceModel]
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
self._where = where
|
||||
self._prefilter = prefilter
|
||||
return self
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
try:
|
||||
import tantivy
|
||||
@@ -275,3 +387,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("score", scores)
|
||||
return output_tbl
|
||||
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
import abc
|
||||
from typing import List, Optional
|
||||
|
||||
import attr
|
||||
import attrs
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -44,7 +44,7 @@ class VectorQuery(BaseModel):
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
@attrs.define
|
||||
class VectorQueryResult:
|
||||
# for now the response is directly seralized into a pandas dataframe
|
||||
tbl: pa.Table
|
||||
|
||||
@@ -16,7 +16,7 @@ import functools
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
import attrs
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -43,14 +43,14 @@ async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
@attr.define(slots=False)
|
||||
@attrs.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
db_name: str
|
||||
region: str
|
||||
api_key: Credential
|
||||
host_override: Optional[str] = attr.field(default=None)
|
||||
host_override: Optional[str] = attrs.field(default=None)
|
||||
|
||||
closed: bool = attr.field(default=False, init=False)
|
||||
closed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
|
||||
@@ -18,10 +18,9 @@ from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.table import Table, _sanitize_data
|
||||
|
||||
from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..table import Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
|
||||
from ..query import LanceQueryBuilder
|
||||
from ..query import LanceVectorQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||
@@ -73,7 +73,11 @@ class RemoteTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
data = _sanitize_data(
|
||||
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
@@ -89,11 +93,13 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceQueryBuilder:
|
||||
return LanceQueryBuilder(self, query, vector_column)
|
||||
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceVectorQueryBuilder:
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
if query.prefilter:
|
||||
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -28,54 +28,89 @@ from lance.dataset import ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
from .embeddings.functions import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
|
||||
def _sanitize_data(
|
||||
data,
|
||||
schema: Optional[pa.Schema],
|
||||
metadata: Optional[dict],
|
||||
on_bad_vectors: str,
|
||||
fill_value: Any,
|
||||
):
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
data = [dict(d) for d in data]
|
||||
data = pa.Table.from_pylist(data)
|
||||
data = _sanitize_schema(
|
||||
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
elif isinstance(data, dict):
|
||||
data = vec_to_table(data)
|
||||
if pd is not None and isinstance(data, pd.DataFrame):
|
||||
elif pd is not None and isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data, preserve_index=False)
|
||||
# Do not serialize Pandas metadata
|
||||
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||
data = data.replace_schema_metadata(meta)
|
||||
|
||||
if isinstance(data, pa.Table):
|
||||
if metadata:
|
||||
data = _append_vector_col(data, metadata, schema)
|
||||
metadata.update(data.schema.metadata or {})
|
||||
data = data.replace_schema_metadata(metadata)
|
||||
data = _sanitize_schema(
|
||||
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
# Do not serialize Pandas metadata
|
||||
metadata = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
metadata = {k: v for k, v in metadata.items() if k != b"pandas"}
|
||||
schema = data.schema.with_metadata(metadata)
|
||||
data = pa.Table.from_arrays(data.columns, schema=schema)
|
||||
if isinstance(data, Iterable):
|
||||
data = _to_record_batch_generator(data, schema, on_bad_vectors, fill_value)
|
||||
if not isinstance(data, (pa.Table, Iterable)):
|
||||
elif isinstance(data, Iterable):
|
||||
data = _to_record_batch_generator(
|
||||
data, schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Unsupported data type: {type(data)}")
|
||||
return data
|
||||
|
||||
|
||||
def _to_record_batch_generator(data: Iterable, schema, on_bad_vectors, fill_value):
|
||||
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
|
||||
"""
|
||||
Use the embedding function to automatically embed the source column and add the
|
||||
vector column to the table.
|
||||
"""
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
data = data.append_column(
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _to_record_batch_generator(
|
||||
data: Iterable, schema, metadata, on_bad_vectors, fill_value
|
||||
):
|
||||
for batch in data:
|
||||
if not isinstance(batch, pa.RecordBatch):
|
||||
table = _sanitize_data(batch, schema, on_bad_vectors, fill_value)
|
||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for batch in table.to_batches():
|
||||
yield batch
|
||||
yield batch
|
||||
else:
|
||||
yield batch
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
"""
|
||||
A [Table](Table) is a collection of Records in a LanceDB [Database](Database).
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -196,17 +231,30 @@ class Table(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list, np.ndarray
|
||||
The query vector.
|
||||
vector_column: str, default "vector"
|
||||
query: str, list, np.ndarray, PIL.Image.Image, default None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -325,14 +373,14 @@ class LanceTable(Table):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
1
|
||||
2
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.checkout(1)
|
||||
3
|
||||
>>> table.checkout(2)
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
@@ -361,19 +409,19 @@ class LanceTable(Table):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
1
|
||||
2
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.restore(1)
|
||||
3
|
||||
>>> table.restore(2)
|
||||
>>> table.to_pandas()
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> len(table.list_versions())
|
||||
3
|
||||
4
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
if version is None:
|
||||
@@ -480,6 +528,9 @@ class LanceTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add data to the table.
|
||||
If vector columns are missing and the table
|
||||
has embedding functions, then the vector columns
|
||||
are automatically computed and added.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -501,7 +552,11 @@ class LanceTable(Table):
|
||||
"""
|
||||
# TODO: manage table listing and metadata separately
|
||||
data = _sanitize_data(
|
||||
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||
self._reset_dataset()
|
||||
@@ -511,7 +566,7 @@ class LanceTable(Table):
|
||||
other_table: Union[LanceTable, ReaderLike],
|
||||
left_on: str,
|
||||
right_on: Optional[str] = None,
|
||||
schema: Optional[pa.Schema, LanceModel] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
):
|
||||
"""Merge another table into this table.
|
||||
|
||||
@@ -569,18 +624,46 @@ class LanceTable(Table):
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
"""
|
||||
Get the embedding functions for the table
|
||||
|
||||
Returns
|
||||
-------
|
||||
funcs: dict
|
||||
A mapping of the vector column to the embedding function
|
||||
or empty dict if not configured.
|
||||
"""
|
||||
return EmbeddingFunctionRegistry.get_instance().parse_functions(
|
||||
self.schema.metadata
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list, np.ndarray
|
||||
The query vector.
|
||||
query: str, list, np.ndarray, a PIL Image or None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If the query is a string, then the query type is "vector" if the
|
||||
table has embedding functions, else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -590,17 +673,9 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query)
|
||||
if isinstance(query, np.ndarray):
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
return LanceQueryBuilder(self, query, vector_column_name)
|
||||
return LanceQueryBuilder.create(
|
||||
self, query, query_type, vector_column_name=vector_column_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -612,6 +687,7 @@ class LanceTable(Table):
|
||||
mode="create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: List[EmbeddingFunctionConfig] = None,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -649,20 +725,58 @@ class LanceTable(Table):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
embedding_functions: list of EmbeddingFunctionModel, default None
|
||||
The embedding functions to use when creating the table.
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
# If we passed in embedding functions explicitly
|
||||
# then we'll override any schema metadata that
|
||||
# may was implicitly specified by the LanceModel schema
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
else:
|
||||
if schema is None:
|
||||
|
||||
if schema is None:
|
||||
if data is None:
|
||||
raise ValueError("Either data or schema must be provided")
|
||||
data = pa.Table.from_pylist([], schema=schema)
|
||||
lance.write_dataset(data, tbl._dataset_uri, schema=schema, mode=mode)
|
||||
return LanceTable(db, name)
|
||||
elif hasattr(data, "schema"):
|
||||
schema = data.schema
|
||||
elif isinstance(data, Iterable):
|
||||
if metadata:
|
||||
raise TypeError(
|
||||
(
|
||||
"Persistent embedding functions not yet "
|
||||
"supported for generator data input"
|
||||
)
|
||||
)
|
||||
|
||||
if metadata:
|
||||
schema = schema.with_metadata(metadata)
|
||||
|
||||
empty = pa.Table.from_pylist([], schema=schema)
|
||||
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
|
||||
table = LanceTable(db, name)
|
||||
|
||||
if data is not None:
|
||||
table.add(data)
|
||||
|
||||
return table
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name):
|
||||
@@ -678,11 +792,68 @@ class LanceTable(Table):
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
|
||||
def update(self, where: str, values: dict):
|
||||
"""
|
||||
EXPERIMENTAL: Update rows in the table (not threadsafe).
|
||||
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
values: dict
|
||||
The values to update. The keys are the column names and the values
|
||||
are the values to set.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 3 [5.0, 6.0]
|
||||
2 2 [10.0, 10.0]
|
||||
|
||||
"""
|
||||
orig_data = self._dataset.to_table(filter=where).combine_chunks()
|
||||
if len(orig_data) == 0:
|
||||
return
|
||||
for col, val in values.items():
|
||||
i = orig_data.column_names.index(col)
|
||||
if i < 0:
|
||||
raise ValueError(f"Column {col} does not exist")
|
||||
orig_data = orig_data.set_column(
|
||||
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
|
||||
)
|
||||
self.delete(where)
|
||||
self.add(orig_data, mode="append")
|
||||
self._reset_dataset()
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
if query.prefilter:
|
||||
for idx in ds.list_indices():
|
||||
if query.vector_column in idx["fields"]:
|
||||
raise NotImplementedError(
|
||||
"Prefiltering for indexed vector column is coming soon."
|
||||
)
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
@@ -720,22 +891,38 @@ def _sanitize_schema(
|
||||
return data
|
||||
# cast the columns to the expected types
|
||||
data = data.combine_chunks()
|
||||
data = _sanitize_vector_column(
|
||||
for field in schema:
|
||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_float32(field.type.value_type)
|
||||
and field.type.list_size >= 10
|
||||
)
|
||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
||||
if field.name in data.column_names and (
|
||||
likely_vector_col or is_default_vector_col
|
||||
):
|
||||
data = _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=field.name,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
|
||||
# just check the vector column
|
||||
if VECTOR_COLUMN_NAME in data.column_names:
|
||||
return _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
# just check the vector column
|
||||
return _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_vector_column(
|
||||
@@ -759,8 +946,6 @@ def _sanitize_vector_column(
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if vector_column_name not in data.column_names:
|
||||
raise ValueError(f"Missing vector column: {vector_column_name}")
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if pa.types.is_list(data[vector_column_name].type):
|
||||
|
||||
@@ -70,7 +70,11 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
|
||||
Get a PyArrow FileSystem from a URI, handling extra environment variables.
|
||||
"""
|
||||
if get_uri_scheme(uri) == "s3":
|
||||
fs = pa_fs.S3FileSystem(endpoint_override=os.environ.get("AWS_ENDPOINT"))
|
||||
fs = pa_fs.S3FileSystem(
|
||||
endpoint_override=os.environ.get("AWS_ENDPOINT"),
|
||||
request_timeout=30,
|
||||
connect_timeout=30,
|
||||
)
|
||||
path = get_uri_location(uri)
|
||||
return fs, path
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.2.2"
|
||||
version = "0.2.6"
|
||||
dependencies = [
|
||||
"pylance==0.6.5",
|
||||
"ratelimiter",
|
||||
"retry",
|
||||
"tqdm",
|
||||
"pylance==0.8.0",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.1.0",
|
||||
"aiohttp",
|
||||
"pydantic",
|
||||
"attr",
|
||||
"semver>=3.0"
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
"semver>=3.0",
|
||||
"cachetools"
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
@@ -43,9 +44,11 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
@@ -53,3 +56,10 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio"
|
||||
]
|
||||
|
||||
@@ -17,7 +17,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
def test_basic(tmp_path):
|
||||
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
|
||||
|
||||
def test_ingest_iterator(tmp_path):
|
||||
class PydanticSchema(LanceModel):
|
||||
vector: vector(2)
|
||||
vector: Vector(2)
|
||||
item: str
|
||||
price: float
|
||||
|
||||
@@ -136,15 +136,14 @@ def test_ingest_iterator(tmp_path):
|
||||
def run_tests(schema):
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
||||
|
||||
tbl.to_pandas()
|
||||
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
|
||||
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
|
||||
|
||||
tbl_len = len(tbl)
|
||||
tbl.add(make_batches())
|
||||
assert tbl_len == 50
|
||||
assert len(tbl) == tbl_len * 2
|
||||
assert len(tbl.list_versions()) == 2
|
||||
assert len(tbl.list_versions()) == 3
|
||||
db.drop_database()
|
||||
|
||||
run_tests(arrow_schema)
|
||||
|
||||
@@ -12,10 +12,16 @@
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.embeddings import with_embeddings
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -40,3 +46,40 @@ def test_with_embeddings():
|
||||
assert data.column_names == ["text", "price", "vector"]
|
||||
assert data.column("text").to_pylist() == ["foo", "bar"]
|
||||
assert data.column("price").to_pylist() == [10.0, 20.0]
|
||||
|
||||
|
||||
def test_embedding_function(tmp_path):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
# let's create a table
|
||||
table = pa.table(
|
||||
{
|
||||
"text": pa.array(["hello world", "goodbye world"]),
|
||||
"vector": [np.random.randn(10), np.random.randn(10)],
|
||||
}
|
||||
)
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
table = table.replace_schema_metadata(metadata)
|
||||
|
||||
# Write it to disk
|
||||
lance.write_dataset(table, tmp_path / "test.lance")
|
||||
|
||||
# Load this back
|
||||
ds = lance.dataset(tmp_path / "test.lance")
|
||||
|
||||
# can we get the serialized version back out?
|
||||
configs = registry.parse_functions(ds.schema.metadata)
|
||||
|
||||
conf = configs["vector"]
|
||||
func = conf.function
|
||||
actual = func.compute_query_embeddings("hello world")
|
||||
|
||||
# And we make sure we can call it
|
||||
expected = func.compute_query_embeddings("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
125
python/tests/test_embeddings_slow.py
Normal file
125
python/tests/test_embeddings_slow.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import lancedb
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
# or connection to external api
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
def test_sentence_transformer(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get(alias).create()
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words)
|
||||
table.add(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"text": [
|
||||
"hello world",
|
||||
"goodbye world",
|
||||
"fizz",
|
||||
"buzz",
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
|
||||
vec = func.compute_query_embeddings(query)[0]
|
||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||
assert actual.text == expected.text
|
||||
assert actual.text == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get("open-clip").create()
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
# get each uri as bytes
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
table.add(
|
||||
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
|
||||
)
|
||||
|
||||
# text search
|
||||
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
|
||||
assert actual.label == "dog"
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == frombytes.label
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# image search
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(io.BytesIO(image_bytes))
|
||||
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
|
||||
assert actual.label == "dog"
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == other.label
|
||||
|
||||
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
|
||||
assert np.allclose(
|
||||
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
||||
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
||||
)
|
||||
@@ -19,8 +19,9 @@ from typing import List, Optional
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
|
||||
|
||||
def test_fixed_size_list_field():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: vector(16)
|
||||
vec: Vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
|
||||
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
|
||||
|
||||
def test_fixed_size_list_validation():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: vector(8)
|
||||
vec: Vector(8)
|
||||
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
TestModel(vec=range(9))
|
||||
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
|
||||
|
||||
def test_lance_model():
|
||||
class TestModel(LanceModel):
|
||||
vec: vector(16)
|
||||
li: List[int]
|
||||
vector: Vector(16) = Field(default=[0.0] * 16)
|
||||
li: List[int] = Field(default=[1, 2, 3])
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["vec", "li"]
|
||||
assert TestModel.field_names() == ["vector", "li"]
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
@@ -20,8 +20,8 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.query import LanceQueryBuilder, Query
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ class MockTable:
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
@@ -67,12 +68,12 @@ def table(tmp_path) -> MockTable:
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(2)
|
||||
vector: Vector(2)
|
||||
id: int
|
||||
str_field: str
|
||||
float_field: float
|
||||
|
||||
q = LanceQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
|
||||
results = q.to_pydantic(TestModel)
|
||||
assert len(results) == 1
|
||||
r0 = results[0]
|
||||
@@ -84,13 +85,34 @@ def test_cast(table):
|
||||
|
||||
|
||||
def test_query_builder(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
)
|
||||
assert df["id"].values[0] == 1
|
||||
assert all(df["vector"].values[0] == [1, 2])
|
||||
|
||||
|
||||
def test_query_builder_with_filter(table):
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
|
||||
def test_query_builder_with_prefilter(table):
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.where("id = 2")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
)
|
||||
assert len(df) == 0
|
||||
|
||||
df = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.where("id = 2", prefilter=True)
|
||||
.limit(1)
|
||||
.to_df()
|
||||
)
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
@@ -98,12 +120,14 @@ def test_query_builder_with_filter(table):
|
||||
def test_query_builder_with_metric(table):
|
||||
query = [4, 8]
|
||||
vector_column_name = "vector"
|
||||
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = (
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
)
|
||||
tm.assert_frame_equal(df_default, df_l2)
|
||||
|
||||
df_cosine = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
@@ -120,7 +144,7 @@ def test_query_builder_with_different_vector_column():
|
||||
query = [4, 8]
|
||||
vector_column_name = "foo_vector"
|
||||
builder = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.where("b < 10")
|
||||
.select(["b"])
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import attr
|
||||
import attrs
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
@@ -21,10 +21,10 @@ from aiohttp import web
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@attr.define
|
||||
@attrs.define
|
||||
class MockLanceDBServer:
|
||||
runner: web.AppRunner = attr.field(init=False)
|
||||
site: web.TCPSite = attr.field(init=False)
|
||||
runner: web.AppRunner = attrs.field(init=False)
|
||||
site: web.TCPSite = attrs.field(init=False)
|
||||
|
||||
async def query_handler(self, request: web.Request) -> web.Response:
|
||||
table_name = request.match_info["table_name"]
|
||||
|
||||
@@ -22,8 +22,10 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -139,7 +141,7 @@ def test_add(db):
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(16)
|
||||
vector: Vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
|
||||
@@ -178,16 +180,16 @@ def test_versioning(db):
|
||||
],
|
||||
)
|
||||
|
||||
assert len(table.list_versions()) == 1
|
||||
assert table.version == 1
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
|
||||
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 3
|
||||
|
||||
table.checkout(1)
|
||||
assert table.version == 1
|
||||
table.checkout(2)
|
||||
assert table.version == 2
|
||||
assert len(table) == 2
|
||||
|
||||
|
||||
@@ -278,21 +280,21 @@ def test_restore(db):
|
||||
data=[{"vector": [1.1, 0.9], "type": "vector"}],
|
||||
)
|
||||
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
table.restore(1)
|
||||
assert len(table.list_versions()) == 3
|
||||
table.restore(2)
|
||||
assert len(table.list_versions()) == 4
|
||||
assert len(table) == 1
|
||||
|
||||
expected = table.to_arrow()
|
||||
table.checkout(1)
|
||||
table.checkout(2)
|
||||
table.restore()
|
||||
assert len(table.list_versions()) == 4
|
||||
assert len(table.list_versions()) == 5
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
table.restore(4) # latest version should be no-op
|
||||
assert len(table.list_versions()) == 4
|
||||
table.restore(5) # latest version should be no-op
|
||||
assert len(table.list_versions()) == 5
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore(5)
|
||||
table.restore(6)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.restore(0)
|
||||
@@ -306,7 +308,7 @@ def test_merge(db, tmp_path):
|
||||
)
|
||||
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
|
||||
table.merge(other_table, left_on="id")
|
||||
assert len(table.list_versions()) == 2
|
||||
assert len(table.list_versions()) == 3
|
||||
expected = pa.table(
|
||||
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
|
||||
schema=table.schema,
|
||||
@@ -316,3 +318,126 @@ def test_merge(db, tmp_path):
|
||||
other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
|
||||
table.restore(1)
|
||||
table.merge(other_dataset, left_on="id")
|
||||
|
||||
|
||||
def test_delete(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 2
|
||||
table.delete("id=0")
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 1
|
||||
assert table.to_pandas()["id"].tolist() == [1]
|
||||
|
||||
|
||||
def test_update(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 2
|
||||
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
assert len(table.list_versions()) == 4
|
||||
assert table.version == 4
|
||||
assert len(table) == 2
|
||||
v = table.to_arrow()["vector"].combine_chunks()
|
||||
v = v.values.to_numpy().reshape(2, 2)
|
||||
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockTextEmbeddingFunction()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text", vector_column="vector", function=func
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[conf],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
table.add(df)
|
||||
|
||||
texts = ["the quick brown fox", "jumped over the lazy dog"]
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = emb.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_multiple_vector_columns(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: Vector(10)
|
||||
vector2: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
|
||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
|
||||
|
||||
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
||||
|
||||
|
||||
def test_empty_query(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
|
||||
val = df.id.iloc[0]
|
||||
assert val == 1
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.2.4"
|
||||
version = "0.2.6"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
@@ -18,6 +18,7 @@ once_cell = "1"
|
||||
futures = "0.3"
|
||||
half = { workspace = true }
|
||||
lance = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
vectordb = { path = "../../vectordb" }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
|
||||
|
||||
@@ -28,7 +28,9 @@ fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
|
||||
record_batch
|
||||
.column_by_name(VECTOR_COLUMN_NAME)
|
||||
.map(|_| ())
|
||||
.context(MissingColumnSnafu { name: VECTOR_COLUMN_NAME })
|
||||
.context(MissingColumnSnafu {
|
||||
name: VECTOR_COLUMN_NAME,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::MetricType;
|
||||
use lance_linalg::distance::MetricType;
|
||||
use neon::context::FunctionContext;
|
||||
use neon::prelude::*;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
@@ -183,11 +183,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let aws_region = get_aws_region(&mut cx, 4)?;
|
||||
|
||||
let params = ReadParams {
|
||||
store_options: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
store_options: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
..ReadParams::default()
|
||||
};
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::ops::Deref;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use lance::index::vector::MetricType;
|
||||
use lance_linalg::distance::MetricType;
|
||||
use neon::context::FunctionContext;
|
||||
use neon::handle::Handle;
|
||||
use neon::prelude::*;
|
||||
|
||||
@@ -43,7 +43,8 @@ impl JsTable {
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let buffer = cx.argument::<JsBuffer>(1)?;
|
||||
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
let (batches, schema) =
|
||||
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
|
||||
// Write mode
|
||||
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
|
||||
@@ -65,11 +66,9 @@ impl JsTable {
|
||||
let aws_region = get_aws_region(&mut cx, 6)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
mode: mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
@@ -92,7 +91,8 @@ impl JsTable {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let buffer = cx.argument::<JsBuffer>(0)?;
|
||||
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
|
||||
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
let (batches, schema) =
|
||||
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let mut table = js_table.table.clone();
|
||||
@@ -108,11 +108,9 @@ impl JsTable {
|
||||
let aws_region = get_aws_region(&mut cx, 5)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
mode: write_mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.2.4"
|
||||
version = "0.2.6"
|
||||
edition = "2021"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
keywords = ["lancedb", "lance", "database", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
arrow = { workspace = true }
|
||||
arrow-array = { workspace = true }
|
||||
arrow-data = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
arrow-ord = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
half = { workspace = true }
|
||||
lance = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
log = { workspace = true }
|
||||
num-traits = "0"
|
||||
url = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.5.0"
|
||||
|
||||
3
rust/vectordb/README.md
Normal file
3
rust/vectordb/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# LanceDB Rust
|
||||
|
||||
Rust client for LanceDB, a serverless vector database. Read more at: https://lancedb.com/
|
||||
15
rust/vectordb/src/arrow.rs
Normal file
15
rust/vectordb/src/arrow.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub use lance::arrow::*;
|
||||
18
rust/vectordb/src/data.rs
Normal file
18
rust/vectordb/src/data.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Data types, schema coercion, and data cleaning and etc.
|
||||
|
||||
pub mod inspect;
|
||||
pub mod sanitize;
|
||||
180
rust/vectordb/src/data/inspect.rs
Normal file
180
rust/vectordb/src/data/inspect.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use arrow::compute::kernels::{aggregate::bool_and, length::length};
|
||||
use arrow_array::{
|
||||
cast::AsArray,
|
||||
types::{ArrowPrimitiveType, Int32Type, Int64Type},
|
||||
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
|
||||
};
|
||||
use arrow_ord::comparison::eq_dyn_scalar;
|
||||
use arrow_schema::DataType;
|
||||
use num_traits::{ToPrimitive, Zero};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
pub(crate) fn infer_dimension<T: ArrowPrimitiveType>(
|
||||
list_arr: &GenericListArray<T::Native>,
|
||||
) -> Result<Option<T::Native>>
|
||||
where
|
||||
T::Native: OffsetSizeTrait + ToPrimitive,
|
||||
{
|
||||
let len_arr = length(list_arr)?;
|
||||
if len_arr.is_empty() {
|
||||
return Ok(Some(Zero::zero()));
|
||||
}
|
||||
|
||||
let dim = len_arr.as_primitive::<T>().value(0);
|
||||
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(dim))
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer the vector columns from a dataset.
|
||||
///
|
||||
/// Parameters
|
||||
/// ----------
|
||||
/// - reader: RecordBatchReader
|
||||
/// - strict: if set true, only fixed_size_list<float> is considered as vector column. If set to false,
|
||||
/// a list<float> column with same length is also considered as vector column.
|
||||
pub fn infer_vector_columns(
|
||||
reader: impl RecordBatchReader + Send,
|
||||
strict: bool,
|
||||
) -> Result<Vec<String>> {
|
||||
let mut columns = vec![];
|
||||
|
||||
let mut columns_to_infer: HashMap<String, Option<i64>> = HashMap::new();
|
||||
for field in reader.schema().fields() {
|
||||
match field.data_type() {
|
||||
DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => {
|
||||
columns.push(field.name().to_string());
|
||||
}
|
||||
DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => {
|
||||
columns_to_infer.insert(field.name().to_string(), None);
|
||||
}
|
||||
DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => {
|
||||
columns_to_infer.insert(field.name().to_string(), None);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for batch in reader {
|
||||
let batch = batch?;
|
||||
let col_names = columns_to_infer.keys().cloned().collect::<Vec<_>>();
|
||||
for col_name in col_names {
|
||||
let col = batch.column_by_name(&col_name).ok_or(Error::Schema {
|
||||
message: format!("Column {} not found", col_name),
|
||||
})?;
|
||||
if let Some(dim) = match *col.data_type() {
|
||||
DataType::List(_) => {
|
||||
infer_dimension::<Int32Type>(col.as_list::<i32>())?.map(|d| d as i64)
|
||||
}
|
||||
DataType::LargeList(_) => infer_dimension::<Int64Type>(col.as_list::<i64>())?,
|
||||
_ => {
|
||||
return Err(Error::Schema {
|
||||
message: format!("Column {} is not a list", col_name),
|
||||
})
|
||||
}
|
||||
} {
|
||||
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
|
||||
if prev_dim != &dim {
|
||||
columns_to_infer.remove(&col_name);
|
||||
}
|
||||
} else {
|
||||
columns_to_infer.insert(col_name, Some(dim));
|
||||
}
|
||||
} else {
|
||||
columns_to_infer.remove(&col_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
columns.extend(columns_to_infer.keys().cloned());
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use arrow_array::{
|
||||
types::{Float32Type, Float64Type},
|
||||
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use std::{sync::Arc, vec};
|
||||
|
||||
#[test]
|
||||
fn test_infer_vector_columns() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("f", DataType::Float32, false),
|
||||
Field::new("s", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"l1",
|
||||
DataType::List(Arc::new(Field::new("item", DataType::Float32, true))),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"l2",
|
||||
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"fl",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 32),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])),
|
||||
Arc::new(StringArray::from(vec!["a", "b", "c"])),
|
||||
Arc::new(ListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..3).map(|_| Some(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)])),
|
||||
)),
|
||||
// Var-length list
|
||||
Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
|
||||
Some(vec![Some(1.0_f64)]),
|
||||
Some(vec![Some(2.0_f64), Some(3.0_f64)]),
|
||||
Some(vec![Some(4.0_f64), Some(5.0_f64), Some(6.0_f64)]),
|
||||
])),
|
||||
Arc::new(
|
||||
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
vec![
|
||||
Some(vec![Some(1.0); 32]),
|
||||
Some(vec![Some(2.0); 32]),
|
||||
Some(vec![Some(3.0); 32]),
|
||||
],
|
||||
32,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader =
|
||||
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
|
||||
|
||||
let cols = infer_vector_columns(reader, false).unwrap();
|
||||
assert_eq!(cols, vec!["fl", "l1"]);
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
|
||||
let cols = infer_vector_columns(reader, true).unwrap();
|
||||
assert_eq!(cols, vec!["fl"]);
|
||||
}
|
||||
}
|
||||
284
rust/vectordb/src/data/sanitize.rs
Normal file
284
rust/vectordb/src/data/sanitize.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{iter::repeat_with, sync::Arc};
|
||||
|
||||
use arrow_array::{
|
||||
cast::AsArray,
|
||||
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
|
||||
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
};
|
||||
use arrow_cast::{can_cast_types, cast};
|
||||
use arrow_schema::{ArrowError, DataType, Field, Schema};
|
||||
use half::f16;
|
||||
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
|
||||
use log::warn;
|
||||
use num_traits::cast::AsPrimitive;
|
||||
|
||||
use super::inspect::infer_dimension;
|
||||
use crate::error::Result;
|
||||
|
||||
fn cast_array<I: ArrowNumericType, O: ArrowNumericType>(
|
||||
arr: &PrimitiveArray<I>,
|
||||
) -> Arc<PrimitiveArray<O>>
|
||||
where
|
||||
I::Native: AsPrimitive<O::Native>,
|
||||
{
|
||||
Arc::new(PrimitiveArray::<O>::from_iter_values(
|
||||
arr.values().iter().map(|v| (*v).as_()),
|
||||
))
|
||||
}
|
||||
|
||||
fn cast_float_array<I: ArrowNumericType>(
|
||||
arr: &PrimitiveArray<I>,
|
||||
dt: &DataType,
|
||||
) -> std::result::Result<Arc<dyn Array>, ArrowError>
|
||||
where
|
||||
I::Native: AsPrimitive<f64> + AsPrimitive<f32> + AsPrimitive<f16>,
|
||||
{
|
||||
match dt {
|
||||
DataType::Float16 => Ok(cast_array::<I, Float16Type>(arr)),
|
||||
DataType::Float32 => Ok(cast_array::<I, Float32Type>(arr)),
|
||||
DataType::Float64 => Ok(cast_array::<I, Float64Type>(arr)),
|
||||
_ => Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible change field: unable to coerce {:?} to {:?}",
|
||||
arr.data_type(),
|
||||
dt
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn coerce_array(
|
||||
array: &Arc<dyn Array>,
|
||||
field: &Field,
|
||||
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
|
||||
if array.data_type() == field.data_type() {
|
||||
return Ok(array.clone());
|
||||
}
|
||||
match (array.data_type(), field.data_type()) {
|
||||
// Normal cast-able types.
|
||||
(adt, dt) if can_cast_types(adt, dt) => cast(&array, dt),
|
||||
// Casting between f16/f32/f64 can be lossy.
|
||||
(adt, dt) if (adt.is_floating() || dt.is_floating()) => {
|
||||
if adt.byte_width() > dt.byte_width() {
|
||||
warn!(
|
||||
"Coercing field {} {:?} to {:?} might lose precision",
|
||||
field.name(),
|
||||
adt,
|
||||
dt
|
||||
);
|
||||
}
|
||||
match adt {
|
||||
DataType::Float16 => cast_float_array(array.as_primitive::<Float16Type>(), dt),
|
||||
DataType::Float32 => cast_float_array(array.as_primitive::<Float32Type>(), dt),
|
||||
DataType::Float64 => cast_float_array(array.as_primitive::<Float64Type>(), dt),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
(adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt {
|
||||
// Cast a float fixed size array with same dimension to the expected type.
|
||||
DataType::FixedSizeList(_, dim) if dim == exp_dim => {
|
||||
let actual_sub = array.as_fixed_size_list();
|
||||
let values = coerce_array(actual_sub.values(), exp_field)?;
|
||||
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
|
||||
values.clone(),
|
||||
*dim,
|
||||
)?) as Arc<dyn Array>)
|
||||
}
|
||||
DataType::List(_) | DataType::LargeList(_) => {
|
||||
let Some(dim) = (match adt {
|
||||
DataType::List(_) => infer_dimension::<Int32Type>(array.as_list::<i32>())
|
||||
.map_err(|e| {
|
||||
ArrowError::SchemaError(format!(
|
||||
"failed to infer dimension from list: {}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
.map(|d| d as i64),
|
||||
DataType::LargeList(_) => infer_dimension::<Int64Type>(array.as_list::<i64>())
|
||||
.map_err(|e| {
|
||||
ArrowError::SchemaError(format!(
|
||||
"failed to infer dimension from large list: {}",
|
||||
e
|
||||
))
|
||||
})?,
|
||||
_ => unreachable!(),
|
||||
}) else {
|
||||
return Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
|
||||
field,
|
||||
array.data_type()
|
||||
)));
|
||||
};
|
||||
|
||||
if dim != *exp_dim as i64 {
|
||||
return Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible coerce fixed size list: expected dimension {} but got {}",
|
||||
exp_dim, dim
|
||||
)));
|
||||
}
|
||||
|
||||
let values = coerce_array(array, exp_field)?;
|
||||
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
|
||||
values.clone(),
|
||||
*exp_dim,
|
||||
)?) as Arc<dyn Array>)
|
||||
}
|
||||
_ => Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
|
||||
field,
|
||||
array.data_type()
|
||||
)))?,
|
||||
},
|
||||
_ => Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible change field {}: unable to coerce {:?} to {:?}",
|
||||
field.name(),
|
||||
array.data_type(),
|
||||
field.data_type()
|
||||
)))?,
|
||||
}
|
||||
}
|
||||
|
||||
fn coerce_schema_batch(
|
||||
batch: RecordBatch,
|
||||
schema: Arc<Schema>,
|
||||
) -> std::result::Result<RecordBatch, ArrowError> {
|
||||
if batch.schema() == schema {
|
||||
return Ok(batch);
|
||||
}
|
||||
let columns = schema
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|field| {
|
||||
batch
|
||||
.column_by_name(field.name())
|
||||
.ok_or_else(|| {
|
||||
ArrowError::SchemaError(format!("Column {} not found", field.name()))
|
||||
})
|
||||
.and_then(|c| coerce_array(c, field))
|
||||
})
|
||||
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
|
||||
RecordBatch::try_new(schema, columns)
|
||||
}
|
||||
|
||||
/// Coerce the reader (input data) to match the given [Schema].
|
||||
///
|
||||
pub fn coerce_schema(
|
||||
reader: impl RecordBatchReader + Send + 'static,
|
||||
schema: Arc<Schema>,
|
||||
) -> Result<Box<dyn RecordBatchReader + Send>> {
|
||||
if reader.schema() == schema {
|
||||
return Ok(Box::new(RecordBatchIterator::new(reader, schema)));
|
||||
}
|
||||
let s = schema.clone();
|
||||
let batches = reader
|
||||
.zip(repeat_with(move || s.clone()))
|
||||
.map(|(batch, s)| coerce_schema_batch(batch?, s));
|
||||
Ok(Box::new(RecordBatchIterator::new(batches, schema)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
|
||||
RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::Field;
|
||||
use half::f16;
|
||||
use lance::arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[test]
|
||||
fn test_coerce_list_to_fixed_size_list() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new(
|
||||
"fl",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 64),
|
||||
true,
|
||||
),
|
||||
Field::new("s", DataType::Utf8, true),
|
||||
Field::new("f", DataType::Float16, true),
|
||||
Field::new("i", DataType::Int32, true),
|
||||
]));
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(
|
||||
FixedSizeListArray::try_new_from_values(
|
||||
Float32Array::from_iter_values((0..256).map(|v| v as f32)),
|
||||
64,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(StringArray::from(vec![
|
||||
Some("hello"),
|
||||
Some("world"),
|
||||
Some("from"),
|
||||
Some("lance"),
|
||||
])),
|
||||
Arc::new(Float16Array::from_iter_values(
|
||||
(0..4).map(|v| f16::from_f32(v as f32)),
|
||||
)),
|
||||
Arc::new(Int32Array::from_iter_values(0..4)),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader =
|
||||
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
|
||||
|
||||
let expected_schema = Arc::new(Schema::new(vec![
|
||||
Field::new(
|
||||
"fl",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 64),
|
||||
true,
|
||||
),
|
||||
Field::new("s", DataType::Utf8, true),
|
||||
Field::new("f", DataType::Float64, true),
|
||||
Field::new("i", DataType::Int8, true),
|
||||
]));
|
||||
let stream = coerce_schema(reader, expected_schema.clone()).unwrap();
|
||||
let batches = stream.collect::<Vec<_>>();
|
||||
assert_eq!(batches.len(), 1);
|
||||
let batch = batches[0].as_ref().unwrap();
|
||||
assert_eq!(batch.schema(), expected_schema);
|
||||
|
||||
let expected = RecordBatch::try_new(
|
||||
expected_schema,
|
||||
vec![
|
||||
Arc::new(
|
||||
FixedSizeListArray::try_new_from_values(
|
||||
Float16Array::from_iter_values((0..256).map(|v| f16::from_f32(v as f32))),
|
||||
64,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(StringArray::from(vec![
|
||||
Some("hello"),
|
||||
Some("world"),
|
||||
Some("from"),
|
||||
Some("lance"),
|
||||
])),
|
||||
Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))),
|
||||
Arc::new(Int8Array::from_iter_values(0..4)),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(batch, &expected);
|
||||
}
|
||||
}
|
||||
@@ -27,12 +27,14 @@ pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
|
||||
pub struct Database {
|
||||
object_store: ObjectStore,
|
||||
query_string: Option<String>,
|
||||
|
||||
pub(crate) uri: String,
|
||||
pub(crate) base_path: object_store::path::Path,
|
||||
}
|
||||
|
||||
const LANCE_EXTENSION: &str = "lance";
|
||||
const ENGINE: &str = "engine";
|
||||
|
||||
/// A connection to LanceDB
|
||||
impl Database {
|
||||
@@ -46,12 +48,73 @@ impl Database {
|
||||
///
|
||||
/// * A [Database] object.
|
||||
pub async fn connect(uri: &str) -> Result<Database> {
|
||||
let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
|
||||
let parse_res = url::Url::parse(uri);
|
||||
|
||||
match parse_res {
|
||||
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
|
||||
Ok(mut url) => {
|
||||
// iter thru the query params and extract the commit store param
|
||||
let mut engine = None;
|
||||
let mut filtered_querys = vec![];
|
||||
|
||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||
// THE API WILL CHANGE
|
||||
for (key, value) in url.query_pairs() {
|
||||
if key == ENGINE {
|
||||
engine = Some(value.to_string());
|
||||
} else {
|
||||
// to owned so we can modify the url
|
||||
filtered_querys.push((key.to_string(), value.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out the commit store query param -- it's a lancedb param
|
||||
url.query_pairs_mut().clear();
|
||||
url.query_pairs_mut().extend_pairs(filtered_querys);
|
||||
// Take a copy of the query string so we can propagate it to lance
|
||||
let query_string = url.query().map(|s| s.to_string());
|
||||
// clear the query string so we can use the url as the base uri
|
||||
// use .set_query(None) instead of .set_query("") because the latter
|
||||
// will add a trailing '?' to the url
|
||||
url.set_query(None);
|
||||
|
||||
let table_base_uri = if let Some(store) = engine {
|
||||
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
|
||||
WARN_ONCE.call_once(|| {
|
||||
log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
|
||||
});
|
||||
let old_scheme = url.scheme().to_string();
|
||||
let new_scheme = format!("{}+{}", old_scheme, store);
|
||||
url.to_string().replacen(&old_scheme, &new_scheme, 1)
|
||||
} else {
|
||||
url.to_string()
|
||||
};
|
||||
|
||||
let plain_uri = url.to_string();
|
||||
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
||||
}
|
||||
|
||||
Ok(Database {
|
||||
uri: table_base_uri,
|
||||
query_string,
|
||||
base_path,
|
||||
object_store,
|
||||
})
|
||||
}
|
||||
Err(_) => Self::open_path(uri).await,
|
||||
}
|
||||
Ok(Database {
|
||||
uri: uri.to_string(),
|
||||
}
|
||||
|
||||
async fn open_path(path: &str) -> Result<Database> {
|
||||
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
|
||||
}
|
||||
Ok(Self {
|
||||
uri: path.to_string(),
|
||||
query_string: None,
|
||||
base_path,
|
||||
object_store,
|
||||
})
|
||||
@@ -149,17 +212,26 @@ impl Database {
|
||||
let path = Path::new(&self.uri);
|
||||
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
|
||||
let uri = table_uri
|
||||
let mut uri = table_uri
|
||||
.as_path()
|
||||
.to_str()
|
||||
.context(InvalidTableNameSnafu { name })?;
|
||||
Ok(uri.to_string())
|
||||
.context(InvalidTableNameSnafu { name })?
|
||||
.to_string();
|
||||
|
||||
// If there are query string set on the connection, propagate to lance
|
||||
if let Some(query) = self.query_string.as_ref() {
|
||||
uri.push('?');
|
||||
uri.push_str(query.as_str());
|
||||
}
|
||||
|
||||
Ok(uri)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs::create_dir_all;
|
||||
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::database::Database;
|
||||
@@ -173,6 +245,28 @@ mod tests {
|
||||
assert_eq!(db.uri, uri);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn test_connect_relative() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
|
||||
|
||||
let mut relative_anacestors = vec![];
|
||||
let current_dir = std::env::current_dir().unwrap();
|
||||
let mut ancestors = current_dir.ancestors();
|
||||
while let Some(_) = ancestors.next() {
|
||||
relative_anacestors.push("..");
|
||||
}
|
||||
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
|
||||
let relative_uri = relative_root.join(&uri);
|
||||
|
||||
let db = Database::connect(relative_uri.to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow_schema::ArrowError;
|
||||
use snafu::Snafu;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
@@ -32,10 +33,20 @@ pub enum Error {
|
||||
Store { message: String },
|
||||
#[snafu(display("LanceDBError: {message}"))]
|
||||
Lance { message: String },
|
||||
#[snafu(display("LanceDB Schema Error: {message}"))]
|
||||
Schema { message: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
impl From<ArrowError> for Error {
|
||||
fn from(e: ArrowError) -> Self {
|
||||
Self::Lance {
|
||||
message: e.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lance::Error> for Error {
|
||||
fn from(e: lance::Error) -> Self {
|
||||
Self::Lance {
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::{MetricType, VectorIndexParams};
|
||||
use lance::index::vector::VectorIndexParams;
|
||||
use lance_linalg::distance::MetricType;
|
||||
|
||||
pub trait VectorIndexBuilder {
|
||||
fn get_column(&self) -> Option<String>;
|
||||
@@ -107,9 +108,11 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::{MetricType, StageParams};
|
||||
use lance::index::vector::StageParams;
|
||||
|
||||
use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod data;
|
||||
pub mod database;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::sync::Arc;
|
||||
use arrow_array::Float32Array;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
use lance_linalg::distance::MetricType;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
@@ -164,10 +164,10 @@ impl Query {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
|
||||
use crate::query::Query;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user