mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 05:49:57 +00:00
Compare commits
23 Commits
reproducib
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74004161ff | ||
|
|
34ddb1de6d | ||
|
|
1029fc9cb0 | ||
|
|
31c5df6d99 | ||
|
|
dbf37a0434 | ||
|
|
f20f19b804 | ||
|
|
55207ce844 | ||
|
|
c21f9cdda0 | ||
|
|
bc38abb781 | ||
|
|
731f86e44c | ||
|
|
31dad71c94 | ||
|
|
9585f550b3 | ||
|
|
8dc2315479 | ||
|
|
f6bfb5da11 | ||
|
|
661fcecf38 | ||
|
|
07fe284810 | ||
|
|
800bb691c3 | ||
|
|
ec24e09add | ||
|
|
0554db03b3 | ||
|
|
b315ea3978 | ||
|
|
aa7806cf0d | ||
|
|
6799613109 | ||
|
|
0f26915d22 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.4
|
||||
current_version = 0.2.6
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
54
.github/workflows/node.yml
vendored
54
.github/workflows/node.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
- node/**
|
||||
- rust/ffi/node/**
|
||||
- .github/workflows/node.yml
|
||||
- docker-compose.yml
|
||||
|
||||
env:
|
||||
# Disable full debug symbol generation to speed up CI build and keep memory down
|
||||
@@ -107,3 +108,56 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
npm run test
|
||||
aws-integtest:
|
||||
timeout-minutes: 45
|
||||
runs-on: "ubuntu-22.04"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: node
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ACCESSKEY
|
||||
AWS_SECRET_ACCESS_KEY: SECRETKEY
|
||||
AWS_DEFAULT_REGION: us-west-2
|
||||
# this one is for s3
|
||||
AWS_ENDPOINT: http://localhost:4566
|
||||
# this one is for dynamodb
|
||||
DYNAMODB_ENDPOINT: http://localhost:4566
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 18
|
||||
cache: 'npm'
|
||||
cache-dependency-path: node/package-lock.json
|
||||
- name: start local stack
|
||||
run: docker compose -f ../docker-compose.yml up -d --wait
|
||||
- name: create s3
|
||||
run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
|
||||
- name: create ddb
|
||||
run: |
|
||||
aws dynamodb create-table \
|
||||
--table-name lancedb-integtest \
|
||||
--attribute-definitions '[{"AttributeName": "base_uri", "AttributeType": "S"}, {"AttributeName": "version", "AttributeType": "N"}]' \
|
||||
--key-schema '[{"AttributeName": "base_uri", "KeyType": "HASH"}, {"AttributeName": "version", "KeyType": "RANGE"}]' \
|
||||
--provisioned-throughput '{"ReadCapacityUnits": 10, "WriteCapacityUnits": 10}' \
|
||||
--endpoint-url $DYNAMODB_ENDPOINT
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- name: Build
|
||||
run: |
|
||||
npm ci
|
||||
npm run tsc
|
||||
npm run build
|
||||
npm run pack-build
|
||||
npm install --no-save ./dist/lancedb-vectordb-*.tgz
|
||||
# Remove index.node to test with dependency installed
|
||||
rm index.node
|
||||
- name: Test
|
||||
run: npm run integration-test
|
||||
|
||||
34
.github/workflows/python.yml
vendored
34
.github/workflows/python.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: isort
|
||||
run: isort --check --diff --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
run: pytest --doctest-modules lancedb
|
||||
mac:
|
||||
@@ -65,4 +65,34 @@ jobs:
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
pydantic1x:
|
||||
timeout-minutes: 30
|
||||
runs-on: "ubuntu-22.04"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: python
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Install lancedb
|
||||
run: |
|
||||
pip install "pydantic<2"
|
||||
pip install -e .[tests]
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock black isort
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: isort
|
||||
run: isort --check --diff --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
run: pytest --doctest-modules lancedb
|
||||
23
Cargo.toml
23
Cargo.toml
@@ -1,16 +1,25 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"rust/vectordb",
|
||||
"rust/ffi/node"
|
||||
]
|
||||
members = ["rust/ffi/node", "rust/vectordb"]
|
||||
# Python package needs to be built by maturin.
|
||||
exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = "=0.6.5"
|
||||
lance = { "version" = "=0.7.5", "features" = ["dynamodb"] }
|
||||
lance-linalg = { "version" = "=0.7.5" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "43.0.0", optional = false }
|
||||
arrow-array = "43.0"
|
||||
arrow-data = "43.0"
|
||||
arrow-schema = "43.0"
|
||||
arrow-ipc = "43.0"
|
||||
half = { "version" = "=2.2.1", default-features = false }
|
||||
arrow-ord = "43.0"
|
||||
arrow-schema = "43.0"
|
||||
arrow-arith = "43.0"
|
||||
arrow-cast = "43.0"
|
||||
half = { "version" = "=2.2.1", default-features = false, features = [
|
||||
"num-traits"
|
||||
] }
|
||||
log = "0.4"
|
||||
object_store = "0.6.1"
|
||||
snafu = "0.7.4"
|
||||
url = "2"
|
||||
|
||||
18
docker-compose.yml
Normal file
18
docker-compose.yml
Normal file
@@ -0,0 +1,18 @@
|
||||
version: "3.9"
|
||||
services:
|
||||
localstack:
|
||||
image: localstack/localstack:0.14
|
||||
ports:
|
||||
- 4566:4566
|
||||
environment:
|
||||
- SERVICES=s3,dynamodb
|
||||
- DEBUG=1
|
||||
- LS_LOG=trace
|
||||
- DOCKER_HOST=unix:///var/run/docker.sock
|
||||
- AWS_ACCESS_KEY_ID=ACCESSKEY
|
||||
- AWS_SECRET_ACCESS_KEY=SECRETKEY
|
||||
healthcheck:
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
|
||||
interval: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
@@ -49,11 +49,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
|
||||
db.create_table("table2", data)
|
||||
|
||||
db["table2"].head()
|
||||
db["table2"].head()
|
||||
```
|
||||
!!! info "Note"
|
||||
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
|
||||
|
||||
|
||||
```python
|
||||
custom_schema = pa.schema([
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
@@ -66,7 +66,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
|
||||
### From PyArrow Tables
|
||||
You can also create LanceDB tables directly from pyarrow tables
|
||||
|
||||
|
||||
```python
|
||||
table = pa.Table.from_arrays(
|
||||
[
|
||||
@@ -84,18 +84,28 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
```
|
||||
|
||||
### From Pydantic Models
|
||||
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method.
|
||||
When you create an empty table without data, you must specify the table schema.
|
||||
LanceDB supports creating tables by specifying a pyarrow schema or a specialized
|
||||
pydantic model called `LanceModel`.
|
||||
|
||||
For example, the following Content model specifies a table with 5 columns:
|
||||
movie_id, vector, genres, title, and imdb_id. When you create a table, you can
|
||||
pass the class as the value of the `schema` parameter to `create_table`.
|
||||
The `vector` column is a `Vector` type, which is a specialized pydantic type that
|
||||
can be configured with the vector dimensions. It is also important to note that
|
||||
LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
|
||||
(which itself derives from `pydantic.BaseModel`).
|
||||
|
||||
```python
|
||||
from lancedb.pydantic import vector, LanceModel
|
||||
from lancedb.pydantic import Vector, LanceModel
|
||||
|
||||
class Content(LanceModel):
|
||||
movie_id: int
|
||||
vector: vector(128)
|
||||
vector: Vector(128)
|
||||
genres: str
|
||||
title: str
|
||||
imdb_id: int
|
||||
|
||||
|
||||
@property
|
||||
def imdb_url(self) -> str:
|
||||
return f"https://www.imdb.com/title/tt{self.imdb_id}"
|
||||
@@ -103,7 +113,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
import pyarrow as pa
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
table_name = "movielens_small"
|
||||
table = db.create_table(table_name, schema=Content.to_arrow_schema())
|
||||
table = db.create_table(table_name, schema=Content)
|
||||
```
|
||||
|
||||
### Using Iterators / Writing Large Datasets
|
||||
@@ -113,7 +123,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types.
|
||||
|
||||
Here's an example using using `RecordBatch` iterator for creating tables.
|
||||
|
||||
|
||||
```python
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -142,11 +152,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
|
||||
## Creating Empty Table
|
||||
You can also create empty tables in python. Initialize it with schema and later ingest data into it.
|
||||
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
@@ -168,8 +178,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
|
||||
class Model(LanceModel):
|
||||
vector: vector(2)
|
||||
|
||||
vector: Vector(2)
|
||||
|
||||
tbl = db.create_table("table5", schema=Model.to_arrow_schema())
|
||||
```
|
||||
|
||||
@@ -249,7 +259,7 @@ After a table has been created, you can always add more data to it using
|
||||
You can also add a large dataset batch in one go using Iterator of any supported data types.
|
||||
|
||||
### Adding to table using Iterator
|
||||
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
@@ -261,10 +271,10 @@ After a table has been created, you can always add more data to it using
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
})
|
||||
|
||||
|
||||
tbl.add(make_batches())
|
||||
```
|
||||
|
||||
|
||||
The other arguments accepted:
|
||||
|
||||
| Name | Type | Description | Default |
|
||||
@@ -274,7 +284,7 @@ After a table has been created, you can always add more data to it using
|
||||
| on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop |
|
||||
| fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 |
|
||||
|
||||
|
||||
|
||||
=== "Javascript/Typescript"
|
||||
|
||||
```javascript
|
||||
@@ -312,7 +322,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
|
||||
# x vector
|
||||
# 0 1 [1.0, 2.0]
|
||||
# 1 3 [5.0, 6.0]
|
||||
```
|
||||
```
|
||||
|
||||
### Delete from a list of values
|
||||
|
||||
@@ -325,7 +335,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
|
||||
# x vector
|
||||
# 0 3 [5.0, 6.0]
|
||||
```
|
||||
|
||||
|
||||
=== "Javascript/Typescript"
|
||||
|
||||
```javascript
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,62 +0,0 @@
|
||||
id,quote,author
|
||||
1,"Nobody exists on purpose. Nobody belongs anywhere.",Morty
|
||||
2,"We're all going to die. Come watch TV.",Morty
|
||||
3,"Losers look stuff up while the rest of us are carpin' all them diems.",Summer
|
||||
4,"He's not a hot girl. He can't just bail on his life and set up shop in someone else's.",Beth
|
||||
5,"When you are an a—hole, it doesn't matter how right you are. Nobody wants to give you the satisfaction.",Morty
|
||||
6,"God's turning people into insect monsters, Beth. I'm the one beating them to death. Thank me.",Jerry
|
||||
7,"Camping is just being homeless without the change.",Summer
|
||||
8,"This seems like a good time for a drink and a cold, calculated speech with sinister overtones. A speech about politics, about order, brotherhood, power ... but speeches are for campaigning. Now is the time for action.",Morty
|
||||
9,"Having a family doesn't mean that you stop being an individual. You know the best thing you can do for the people that depend on you? Be honest with them, even if it means setting them free.",Mr. Meeseeks
|
||||
10,"If I've learned one thing, it's that before you get anywhere in life, you gotta stop listening to yourself.",Jerry
|
||||
11,"I just want to go back to Hell, where everyone thinks I'm smart and funny.",Mr. Needful
|
||||
12,"Hi Mr. Jellybean, I'm Morty. I’m on an adventure with my grandpa.",Morty
|
||||
13,"You're not the cause of your parents' misery. You're just a symptom of it.",Summer
|
||||
14,"Don't deify the people who leave you.",Beth
|
||||
15,"Well, then get your s—t together, get it all together, and put it in a backpack, all your s—t, so it's together. And if you gotta take it somewhere, take it somewhere, you know, take it to the s—t store and sell it, or put it in the s—t museum. I don't care what you do, you just gotta get it together. Get your s—t together.",Morty
|
||||
16,"At least the devil has a job!",Summer
|
||||
17,"Life is effort and I'll stop when I die!",Jerry
|
||||
18,"I just killed my family! I don't care what they were!",Morty
|
||||
19,"It's funny to say they are small. It's funny to say they are big.",Shrimply Pibbles
|
||||
20,"You're holding me verbally hostage.",Summer
|
||||
21,"Honey, stop raising your father's cholesterol so you can take a hot funeral selfie.",Beth
|
||||
22,"Rick, when you say you made an exact replica of the house, did you mean, like, an exact replica?",Morty
|
||||
23,"Give a gun to the lady who got pregnant with me too early and constantly makes it our problem.",Summer
|
||||
24,"Say goodbye to your precious dry land! For soon it will be wet!",Mr. Nimbus
|
||||
25,"Nobody's smarter than Rick, but nobody else is my dad. You're a genius at that.",Morty
|
||||
26,"B—h, my generation gets traumatized for breakfast.",Summer
|
||||
27,"Inception made sense!",Morty
|
||||
28,"I realize now I'm attracted to you for the same reason I can't be with you: You can't change. And I have no problem with that, but it clearly means I have a problem with myself.",Unity
|
||||
29,"Mr. President, if I've learned one thing today, it's that sometimes you have to not give a f—k!",Morty
|
||||
30,"I didn't know freedom meant people doing stuff that sucks.",Summer
|
||||
31,"How many of these are just horrible mistakes I made? I mean, maybe I'd stop making so many if I let myself learn from them.",Morty
|
||||
32,"I'm a scientist because I invent, transform, create, and destroy for a living. And when I don't like something about the world, I change it.",Rick
|
||||
33,"Wubba lubba dub dub!",Rick
|
||||
34,"I turned myself into a pickle, Morty! I'm Pickle Rick!",Rick
|
||||
35,"I know about the Yosemite T-shirt, Morty.",Rick
|
||||
36,"The universe is basically an animal. It grazes on the ordinary. It creates infinite idiots just to eat them.",Rick
|
||||
37,"If I die in a cage, I lose a bet.",Rick
|
||||
38,"Sometimes science is more art than science.",Rick
|
||||
39,"To live is to risk it all—otherwise, you're just an inert chunk of randomly assembled molecules drifting wherever the universe blows you.",Rick
|
||||
40,"Welcome to the club, pal.",Rick
|
||||
41,"So I have an emo streak. It's part of what makes me so rad.",Rick
|
||||
42,"Listen, I'm not the nicest guy in the universe, because I'm the smartest, and being nice is something stupid people do to hedge their bets.",Rick
|
||||
43,"Wait a minute! Is that Mountain Dew in my quantum-transport-solution?",Rick
|
||||
44,"Listen, Morty, I hate to break it to you, but what people call 'love' is just a chemical reaction that compels animals to breed.",Rick
|
||||
45,"Break the cycle, Morty. Rise above. Focus on science.",Rick
|
||||
46,"Don't get drawn into the culture, Morty. Stealing stuff is about the stuff, not the stealing.",Rick
|
||||
47,"I'm sorry, but your opinion means very little to me.",Rick
|
||||
48,"You don't get to tell anyone what's sad. You’re like a one-man Mount Sadmore. So I guess like a Lincoln Sadmorial.",Rick
|
||||
49,"This pickle doesn't care about your children. I'm not gonna take their dreams. I'm gonna take their parents.",Rick
|
||||
50,"I programmed you to believe that.",Rick
|
||||
51,"Have fun with empowerment. It seems to make everyone that gets it really happy.",Rick
|
||||
52,"Thanks, Mr. Poopybutthole. I always could count on you.",Rick
|
||||
53,"Weddings are basically funerals with a cake.",Rick
|
||||
54,"I mean, if you spend all day shuffling words around, you can make anything sound bad, Morty.",Rick
|
||||
55,"It's your choice to take this personally.",Rick
|
||||
56,"Excuse me, coming through. What are you here for? Just kidding, I don't care.",Rick
|
||||
57,"If I let you make me nervous, then we can't get schwifty.",Rick
|
||||
58,"Oh, boy, so you actually learned something today? What is this, Full House?",Rick
|
||||
59,"I can't abide bureaucracy. I don't like being told where to go and what to do. I consider it a violation. Did you get those seeds all the way up your butt?",Rick
|
||||
60,"I think you have to think ahead and live in the moment.",Rick
|
||||
61,"I know that new situations can be intimidating. You're lookin' around and it's all scary and different, but you know, meeting them head-on, charging into 'em like a bull—that's how we grow as people.",Rick
|
||||
|
@@ -249,11 +249,11 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from lancedb.pydantic import vector, LanceModel\n",
|
||||
"from lancedb.pydantic import Vector, LanceModel\n",
|
||||
"\n",
|
||||
"class Content(LanceModel):\n",
|
||||
" movie_id: int\n",
|
||||
" vector: vector(128)\n",
|
||||
" vector: Vector(128)\n",
|
||||
" genres: str\n",
|
||||
" title: str\n",
|
||||
" imdb_id: int\n",
|
||||
@@ -359,7 +359,7 @@
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"class PydanticSchema(LanceModel):\n",
|
||||
" vector: vector(2)\n",
|
||||
" vector: Vector(2)\n",
|
||||
" item: str\n",
|
||||
" price: float\n",
|
||||
"\n",
|
||||
@@ -394,10 +394,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import lancedb\n",
|
||||
"from lancedb.pydantic import LanceModel, vector\n",
|
||||
"from lancedb.pydantic import LanceModel, Vector\n",
|
||||
"\n",
|
||||
"class Model(LanceModel):\n",
|
||||
" vector: vector(2)\n",
|
||||
" vector: Vector(2)\n",
|
||||
"\n",
|
||||
"tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())"
|
||||
]
|
||||
|
||||
@@ -13,10 +13,10 @@ via [pydantic_to_schema()](python.md##lancedb.pydantic.pydantic_to_schema) metho
|
||||
|
||||
## Vector Field
|
||||
|
||||
LanceDB provides a [`vector(dim)`](python.md#lancedb.pydantic.vector) method to define a
|
||||
LanceDB provides a [`Vector(dim)`](python.md#lancedb.pydantic.Vector) method to define a
|
||||
vector Field in a Pydantic Model.
|
||||
|
||||
::: lancedb.pydantic.vector
|
||||
::: lancedb.pydantic.Vector
|
||||
|
||||
## Type Conversion
|
||||
|
||||
@@ -33,4 +33,4 @@ Current supported type conversions:
|
||||
| `str` | `pyarrow.utf8()` |
|
||||
| `list` | `pyarrow.List` |
|
||||
| `BaseModel` | `pyarrow.Struct` |
|
||||
| `vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
|
||||
| `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
|
||||
|
||||
@@ -26,15 +26,19 @@ pip install lancedb
|
||||
|
||||
## Embeddings
|
||||
|
||||
::: lancedb.embeddings.with_embeddings
|
||||
|
||||
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
||||
|
||||
::: lancedb.embeddings.functions.EmbeddingFunctionModel
|
||||
::: lancedb.embeddings.functions.EmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.functions.TextEmbeddingFunctionModel
|
||||
|
||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddingFunction
|
||||
::: lancedb.embeddings.functions.TextEmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
|
||||
|
||||
::: lancedb.embeddings.functions.OpenAIEmbeddings
|
||||
|
||||
::: lancedb.embeddings.functions.OpenClipEmbeddings
|
||||
|
||||
::: lancedb.embeddings.with_embeddings
|
||||
|
||||
## Context
|
||||
|
||||
|
||||
105
node/package-lock.json
generated
105
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.2.4",
|
||||
"version": "0.2.6",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.2.4",
|
||||
"version": "0.2.6",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -31,6 +31,7 @@
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/sinon": "^10.0.15",
|
||||
"@types/temp": "^0.9.1",
|
||||
"@types/uuid": "^9.0.3",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
"cargo-cp-artifact": "^0.1",
|
||||
"chai": "^4.3.7",
|
||||
@@ -48,14 +49,15 @@
|
||||
"ts-node-dev": "^2.0.0",
|
||||
"typedoc": "^0.24.7",
|
||||
"typedoc-plugin-markdown": "^3.15.3",
|
||||
"typescript": "*"
|
||||
"typescript": "*",
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.4",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.4",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.4"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
||||
}
|
||||
},
|
||||
"node_modules/@apache-arrow/ts": {
|
||||
@@ -315,9 +317,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz",
|
||||
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
|
||||
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -327,9 +329,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz",
|
||||
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
|
||||
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -339,9 +341,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -351,9 +353,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -363,9 +365,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz",
|
||||
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
|
||||
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -596,6 +598,12 @@
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/uuid": {
|
||||
"version": "9.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
|
||||
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@typescript-eslint/eslint-plugin": {
|
||||
"version": "5.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
|
||||
@@ -4451,6 +4459,15 @@
|
||||
"punycode": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/uuid": {
|
||||
"version": "9.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
|
||||
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
|
||||
"dev": true,
|
||||
"bin": {
|
||||
"uuid": "dist/bin/uuid"
|
||||
}
|
||||
},
|
||||
"node_modules/v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
@@ -4852,33 +4869,33 @@
|
||||
}
|
||||
},
|
||||
"@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz",
|
||||
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
|
||||
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz",
|
||||
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
|
||||
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz",
|
||||
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
|
||||
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.2.4",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz",
|
||||
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==",
|
||||
"version": "0.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
|
||||
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@neon-rs/cli": {
|
||||
@@ -5093,6 +5110,12 @@
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"@types/uuid": {
|
||||
"version": "9.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
|
||||
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
|
||||
"dev": true
|
||||
},
|
||||
"@typescript-eslint/eslint-plugin": {
|
||||
"version": "5.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
|
||||
@@ -7844,6 +7867,12 @@
|
||||
"punycode": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"uuid": {
|
||||
"version": "9.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
|
||||
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
|
||||
"dev": true
|
||||
},
|
||||
"v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.2.4",
|
||||
"version": "0.2.6",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -9,6 +9,7 @@
|
||||
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
|
||||
"build-release": "npm run build -- --release",
|
||||
"test": "npm run tsc && mocha -recursive dist/test",
|
||||
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
||||
"lint": "eslint native.js src --ext .js,.ts",
|
||||
"clean": "rm -rf node_modules *.node dist/",
|
||||
"pack-build": "neon pack-build",
|
||||
@@ -34,6 +35,7 @@
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/sinon": "^10.0.15",
|
||||
"@types/temp": "^0.9.1",
|
||||
"@types/uuid": "^9.0.3",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
"cargo-cp-artifact": "^0.1",
|
||||
"chai": "^4.3.7",
|
||||
@@ -51,7 +53,8 @@
|
||||
"ts-node-dev": "^2.0.0",
|
||||
"typedoc": "^0.24.7",
|
||||
"typedoc-plugin-markdown": "^3.15.3",
|
||||
"typescript": "*"
|
||||
"typescript": "*",
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@apache-arrow/ts": "^12.0.0",
|
||||
@@ -78,10 +81,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.4",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.4",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.4",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.4"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
||||
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
||||
}
|
||||
}
|
||||
|
||||
43
node/src/integration_test/test.ts
Normal file
43
node/src/integration_test/test.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import * as chai from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
|
||||
const assert = chai.assert
|
||||
chai.use(chaiAsPromised)
|
||||
|
||||
describe('LanceDB AWS Integration test', function () {
|
||||
it('s3+ddb schema is processed correctly', async function () {
|
||||
this.timeout(15000)
|
||||
|
||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||
// THE API WILL CHANGE
|
||||
const conn = await lancedb.connect('s3://lancedb-integtest?engine=ddb&ddbTableName=lancedb-integtest')
|
||||
const data = [{ vector: Array(128).fill(1.0) }]
|
||||
|
||||
const tableName = uuidv4()
|
||||
let table = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
|
||||
|
||||
const futs = [table.add(data), table.add(data), table.add(data), table.add(data), table.add(data)]
|
||||
await Promise.allSettled(futs)
|
||||
|
||||
table = await conn.openTable(tableName)
|
||||
assert.equal(await table.countRows(), 6)
|
||||
})
|
||||
})
|
||||
@@ -19,7 +19,7 @@ import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
|
||||
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow'
|
||||
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
|
||||
|
||||
const expect = chai.expect
|
||||
const assert = chai.assert
|
||||
@@ -258,6 +258,36 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
})
|
||||
|
||||
describe('when searching an empty dataset', function () {
|
||||
it('should not fail', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||
assert.isEmpty(result)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when searching an empty-after-delete dataset', function () {
|
||||
it('should not fail', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
await table.add([{ vector: Array(128).fill(0.1) }])
|
||||
await table.delete('vector IS NOT NULL')
|
||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||
assert.isEmpty(result)
|
||||
})
|
||||
})
|
||||
|
||||
describe('when creating a vector index', function () {
|
||||
it('overwrite all records in a table', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.2
|
||||
current_version = 0.2.5
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -11,12 +11,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib.metadata
|
||||
from typing import Optional
|
||||
|
||||
from .db import URI, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
|
||||
def connect(
|
||||
uri: URI,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
|
||||
import pyarrow as pa
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
|
||||
|
||||
# import lancedb so we don't have to in every example
|
||||
|
||||
@@ -22,17 +22,19 @@ def doctest_setup(monkeypatch, tmpdir):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
|
||||
@registry.register()
|
||||
class MockEmbeddingFunction(EmbeddingFunctionModel):
|
||||
def __call__(self, data):
|
||||
if isinstance(data, str):
|
||||
data = [data]
|
||||
elif isinstance(data, pa.ChunkedArray):
|
||||
data = data.combine_chunks().to_pylist()
|
||||
elif isinstance(data, pa.Array):
|
||||
data = data.to_pylist()
|
||||
@registry.register("test")
|
||||
class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
Return the hash of the first 10 characters
|
||||
"""
|
||||
|
||||
return [self.embed(row) for row in data]
|
||||
def generate_embeddings(self, texts):
|
||||
return [self._compute_one_embedding(row) for row in texts]
|
||||
|
||||
def embed(self, row):
|
||||
return [float(hash(c)) for c in row[:10]]
|
||||
def _compute_one_embedding(self, row):
|
||||
emb = np.array([float(hash(c)) for c in row[:10]])
|
||||
emb /= np.linalg.norm(emb)
|
||||
return emb
|
||||
|
||||
def ndims(self):
|
||||
return 10
|
||||
|
||||
@@ -22,7 +22,7 @@ import pyarrow as pa
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionModel
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||
@@ -290,7 +290,7 @@ class LanceDBConnection(DBConnection):
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
|
||||
|
||||
from .functions import (
|
||||
REGISTRY,
|
||||
EmbeddingFunctionModel,
|
||||
EmbeddingFunction,
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
TextEmbeddingFunctionModel,
|
||||
OpenAIEmbeddings,
|
||||
OpenClipEmbeddings,
|
||||
SentenceTransformerEmbeddings,
|
||||
TextEmbeddingFunction,
|
||||
)
|
||||
from .utils import with_embeddings
|
||||
|
||||
@@ -10,43 +10,78 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import urllib.error
|
||||
import urllib.parse as urlparse
|
||||
import urllib.request
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
This is a singleton class used to register embedding functions
|
||||
and fetch them by name. It also handles serializing and deserializing
|
||||
and fetch them by name. It also handles serializing and deserializing.
|
||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||
or TextEmbeddingFunction and registering it with the registry.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||
>>> @registry.register("my-embedding-function")
|
||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||
... def ndims(self) -> int:
|
||||
... return 128
|
||||
...
|
||||
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
...
|
||||
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||
...
|
||||
>>> registry.get("my-embedding-function")
|
||||
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
return REGISTRY
|
||||
return __REGISTRY__
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
|
||||
def register(self):
|
||||
def register(self, alias: str = None):
|
||||
"""
|
||||
This creates a decorator that can be used to register
|
||||
an EmbeddingFunctionModel.
|
||||
an EmbeddingFunction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alias : Optional[str]
|
||||
a human friendly name for the embedding function. If not
|
||||
provided, the class name will be used.
|
||||
"""
|
||||
|
||||
# This is a decorator for a class that inherits from BaseModel
|
||||
# It adds the class to the registry
|
||||
def decorator(cls):
|
||||
if not issubclass(cls, EmbeddingFunctionModel):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
|
||||
if not issubclass(cls, EmbeddingFunction):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||
if cls.__name__ in self._functions:
|
||||
raise KeyError(f"{cls.__name__} was already registered")
|
||||
self._functions[cls.__name__] = cls
|
||||
key = alias or cls.__name__
|
||||
self._functions[key] = cls
|
||||
cls.__embedding_function_registry_alias__ = alias
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
@@ -57,13 +92,22 @@ class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
self._functions = {}
|
||||
|
||||
def load(self, name: str):
|
||||
def get(self, name: str):
|
||||
"""
|
||||
Fetch an embedding function class by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the embedding function to fetch
|
||||
Either the alias or the class name if no alias was provided
|
||||
during registration
|
||||
"""
|
||||
return self._functions[name]
|
||||
|
||||
def parse_functions(self, metadata: Optional[dict]) -> dict:
|
||||
def parse_functions(
|
||||
self, metadata: Optional[Dict[bytes, bytes]]
|
||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the metadata from an arrow table and
|
||||
return a mapping of the vector column to the
|
||||
@@ -71,9 +115,9 @@ class EmbeddingFunctionRegistry:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : Optional[dict]
|
||||
metadata : Optional[Dict[bytes, bytes]]
|
||||
The metadata from an arrow table. Note that
|
||||
the keys and values are bytes.
|
||||
the keys and values are bytes (pyarrow api)
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -86,68 +130,94 @@ class EmbeddingFunctionRegistry:
|
||||
return {}
|
||||
serialized = metadata[b"embedding_functions"]
|
||||
raw_list = json.loads(serialized.decode("utf-8"))
|
||||
functions = {}
|
||||
for obj in raw_list:
|
||||
model = self.load(obj["schema"]["title"])
|
||||
functions[obj["model"]["vector_column"]] = model(**obj["model"])
|
||||
return functions
|
||||
return {
|
||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||
vector_column=obj["vector_column"],
|
||||
source_column=obj["source_column"],
|
||||
function=self.get(obj["name"])(**obj["model"]),
|
||||
)
|
||||
for obj in raw_list
|
||||
}
|
||||
|
||||
def function_to_metadata(self, func):
|
||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||
"""
|
||||
Convert the given embedding function and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
schema = func.model_json_schema()
|
||||
json_data = func.model_dump()
|
||||
func = conf.function
|
||||
name = getattr(
|
||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||
)
|
||||
json_data = func.safe_model_dump()
|
||||
return {
|
||||
"schema": schema,
|
||||
"name": name,
|
||||
"model": json_data,
|
||||
"source_column": conf.source_column,
|
||||
"vector_column": conf.vector_column,
|
||||
}
|
||||
|
||||
def get_table_metadata(self, func_list):
|
||||
"""
|
||||
Convert a list of embedding functions and source / vector column configs
|
||||
Convert a list of embedding functions and source / vector configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
if func_list is None or len(func_list) == 0:
|
||||
return None
|
||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
|
||||
# Note that metadata dictionary values must be bytes
|
||||
# so we need to json dump then utf8 encode
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
|
||||
REGISTRY = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
class EmbeddingFunctionModel(BaseModel, ABC):
|
||||
"""
|
||||
A callable ABC for embedding functions
|
||||
"""
|
||||
|
||||
source_column: Optional[str]
|
||||
vector_column: str
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> List[np.array]:
|
||||
pass
|
||||
# Global instance
|
||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
|
||||
|
||||
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
An ABC for embedding functions.
|
||||
|
||||
All concrete embedding functions must implement the following:
|
||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||
For text data, the two will be the same. For multi-modal data, the source column
|
||||
might be images and the vector column might be text.
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Create an instance of the embedding function
|
||||
"""
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
"""
|
||||
pass
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function. This is called
|
||||
before generate_embeddings() and is useful for stripping
|
||||
whitespace, lowercasing, etc.
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
@@ -157,6 +227,78 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return dict(self)
|
||||
return self.model_dump()
|
||||
|
||||
@abstractmethod
|
||||
def ndims(self):
|
||||
"""
|
||||
Return the dimensions of the vector column
|
||||
"""
|
||||
pass
|
||||
|
||||
def SourceField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the source column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||
|
||||
def VectorField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the target vector column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
This model encapsulates the configuration for a embedding function
|
||||
in a lancedb table. It holds the embedding function, the source column,
|
||||
and the vector column
|
||||
"""
|
||||
|
||||
vector_column: str
|
||||
source_column: str
|
||||
function: EmbeddingFunction
|
||||
|
||||
|
||||
class TextEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
@@ -167,15 +309,25 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
pass
|
||||
|
||||
|
||||
@REGISTRY.register()
|
||||
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
@@ -186,6 +338,11 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
@@ -220,9 +377,201 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
sentence_transformers = cls.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
|
||||
return SentenceTransformer(name, device=device)
|
||||
except ImportError:
|
||||
raise ValueError("Please install sentence_transformers")
|
||||
|
||||
@register("openai")
|
||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenAI API
|
||||
|
||||
https://platform.openai.com/docs/guides/embeddings
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
openai = self.safe_import("openai")
|
||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||
return [v["embedding"] for v in rs]
|
||||
|
||||
|
||||
@register("open-clip")
|
||||
class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenClip API
|
||||
For multi-modal text-to-image search
|
||||
|
||||
https://github.com/mlfoundations/open_clip
|
||||
"""
|
||||
|
||||
name: str = "ViT-B-32"
|
||||
pretrained: str = "laion2b_s34b_b79k"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
_model = PrivateAttr()
|
||||
_preprocess = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
self._ndims = None
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
with torch.no_grad():
|
||||
text_features = self._model.encode_text(text.to(self.device))
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given images
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in futures]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor)
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
URL to download from
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url) as conn:
|
||||
return conn.read()
|
||||
except (socket.gaierror, urllib.error.URLError) as err:
|
||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||
|
||||
@@ -26,6 +26,8 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
@@ -46,7 +48,19 @@ class FixedSizeListMixin(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def vector(
|
||||
def vector(dim: int, value_type: pa.DataType = pa.float32()):
|
||||
# TODO: remove in future release
|
||||
from warnings import warn
|
||||
|
||||
warn(
|
||||
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
|
||||
"This function will be removed in future release",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return Vector(dim, value_type)
|
||||
|
||||
|
||||
def Vector(
|
||||
dim: int, value_type: pa.DataType = pa.float32()
|
||||
) -> Type[FixedSizeListMixin]:
|
||||
"""Pydantic Vector Type.
|
||||
@@ -65,12 +79,12 @@ def vector(
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import vector
|
||||
>>> from lancedb.pydantic import Vector
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... id: int
|
||||
... url: str
|
||||
... embeddings: vector(768)
|
||||
... embeddings: Vector(768)
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
@@ -114,7 +128,7 @@ def vector(
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
|
||||
raise TypeError("A list of numbers or numpy.ndarray is needed")
|
||||
return v
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@@ -224,27 +238,18 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
|
||||
>>> from typing import List, Optional
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import pydantic_to_schema
|
||||
...
|
||||
>>> class InnerModel(pydantic.BaseModel):
|
||||
... a: str
|
||||
... b: Optional[float]
|
||||
>>>
|
||||
>>> class FooModel(pydantic.BaseModel):
|
||||
... id: int
|
||||
... s: Optional[str] = None
|
||||
... s: str
|
||||
... vec: List[float]
|
||||
... li: List[int]
|
||||
... inner: InnerModel
|
||||
...
|
||||
>>> schema = pydantic_to_schema(FooModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
... pa.field("s", pa.utf8(), True),
|
||||
... pa.field("s", pa.utf8(), False),
|
||||
... pa.field("vec", pa.list_(pa.float64()), False),
|
||||
... pa.field("li", pa.list_(pa.int64()), False),
|
||||
... pa.field("inner", pa.struct([
|
||||
... pa.field("a", pa.utf8(), False),
|
||||
... pa.field("b", pa.float64(), True),
|
||||
... ]), False),
|
||||
... ])
|
||||
"""
|
||||
fields = _pydantic_model_to_fields(model)
|
||||
@@ -258,11 +263,11 @@ class LanceModel(pydantic.BaseModel):
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> from lancedb.pydantic import LanceModel, vector
|
||||
>>> from lancedb.pydantic import LanceModel, Vector
|
||||
>>>
|
||||
>>> class TestModel(LanceModel):
|
||||
... name: str
|
||||
... vector: vector(2)
|
||||
... vector: Vector(2)
|
||||
...
|
||||
>>> db = lancedb.connect("/tmp")
|
||||
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
||||
@@ -278,13 +283,58 @@ class LanceModel(pydantic.BaseModel):
|
||||
"""
|
||||
Get the Arrow Schema for this model.
|
||||
"""
|
||||
return pydantic_to_schema(cls)
|
||||
schema = pydantic_to_schema(cls)
|
||||
functions = cls.parse_embedding_functions()
|
||||
if len(functions) > 0:
|
||||
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
|
||||
functions
|
||||
)
|
||||
schema = schema.with_metadata(metadata)
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def field_names(cls) -> List[str]:
|
||||
"""
|
||||
Get the field names of this model.
|
||||
"""
|
||||
return list(cls.safe_get_fields().keys())
|
||||
|
||||
@classmethod
|
||||
def safe_get_fields(cls):
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return list(cls.__fields__.keys())
|
||||
return list(cls.model_fields.keys())
|
||||
return cls.__fields__
|
||||
return cls.model_fields
|
||||
|
||||
@classmethod
|
||||
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the embedding functions from this model.
|
||||
"""
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
vec_and_function = []
|
||||
for name, field_info in cls.safe_get_fields().items():
|
||||
func = get_extras(field_info, "vector_column_for")
|
||||
if func is not None:
|
||||
vec_and_function.append([name, func])
|
||||
|
||||
configs = []
|
||||
for vec, func in vec_and_function:
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
src_func = get_extras(field_info, "source_column_for")
|
||||
if src_func == func:
|
||||
configs.append(
|
||||
EmbeddingFunctionConfig(
|
||||
source_column=source, vector_column=vec, function=func
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
if PYDANTIC_VERSION.major >= 2:
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
@@ -60,13 +60,15 @@ class LanceQueryBuilder(ABC):
|
||||
def create(
|
||||
cls,
|
||||
table: "lancedb.table.Table",
|
||||
query: Optional[Union[np.ndarray, str]],
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# convert "auto" query_type to "vector" or "fts"
|
||||
# and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
@@ -90,30 +92,27 @@ class LanceQueryBuilder(ABC):
|
||||
# otherwise raise TypeError
|
||||
if query_type == "fts":
|
||||
if not isinstance(query, str):
|
||||
raise TypeError(
|
||||
f"Query type is 'fts' but query is not a string: {type(query)}"
|
||||
)
|
||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
# If query_type is vector, then query must be a list or np.ndarray.
|
||||
# otherwise raise TypeError
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
|
||||
)
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
return query, query_type
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
elif isinstance(query, str):
|
||||
func = table.embedding_functions.get(vector_column_name, None)
|
||||
if func is not None:
|
||||
query = func(query)[0]
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
else:
|
||||
raise TypeError("Query must be a list, np.ndarray, or str")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||
@@ -238,7 +237,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, list],
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
super().__init__(table)
|
||||
|
||||
@@ -18,10 +18,9 @@ from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.table import Table, _sanitize_data
|
||||
|
||||
from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..table import Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ from lance.dataset import ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
from .embeddings.functions import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
@@ -81,15 +82,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
vector column to the table.
|
||||
"""
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
for vector_col, func in functions.items():
|
||||
if vector_col not in data.column_names:
|
||||
col_data = func(data[func.source_column])
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_col).type
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
data = data.append_column(
|
||||
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -102,7 +104,8 @@ def _to_record_batch_generator(
|
||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for batch in table.to_batches():
|
||||
yield batch
|
||||
yield batch
|
||||
else:
|
||||
yield batch
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
@@ -229,7 +232,7 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -238,7 +241,7 @@ class Table(ABC):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, default None
|
||||
query: str, list, np.ndarray, PIL.Image.Image, default None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
@@ -248,6 +251,8 @@ class Table(ABC):
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
@@ -523,6 +528,9 @@ class LanceTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add data to the table.
|
||||
If vector columns are missing and the table
|
||||
has embedding functions, then the vector columns
|
||||
are automatically computed and added.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -616,12 +624,6 @@ class LanceTable(Table):
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
def _get_embedding_function_for_source_col(self, column_name: str):
|
||||
for k, v in self.embedding_functions.items():
|
||||
if v.source_column == column_name:
|
||||
return v
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
"""
|
||||
@@ -639,7 +641,7 @@ class LanceTable(Table):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -648,7 +650,7 @@ class LanceTable(Table):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, or None
|
||||
query: str, list, np.ndarray, a PIL Image or None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
@@ -657,9 +659,11 @@ class LanceTable(Table):
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If the query is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If the query is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
table has embedding functions, else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -683,7 +687,7 @@ class LanceTable(Table):
|
||||
mode="create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: List[EmbeddingFunctionModel] = None,
|
||||
embedding_functions: List[EmbeddingFunctionConfig] = None,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -726,10 +730,16 @@ class LanceTable(Table):
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
# If we passed in embedding functions explicitly
|
||||
# then we'll override any schema metadata that
|
||||
# may was implicitly specified by the LanceModel schema
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
|
||||
@@ -70,7 +70,11 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
|
||||
Get a PyArrow FileSystem from a URI, handling extra environment variables.
|
||||
"""
|
||||
if get_uri_scheme(uri) == "s3":
|
||||
fs = pa_fs.S3FileSystem(endpoint_override=os.environ.get("AWS_ENDPOINT"))
|
||||
fs = pa_fs.S3FileSystem(
|
||||
endpoint_override=os.environ.get("AWS_ENDPOINT"),
|
||||
request_timeout=30,
|
||||
connect_timeout=30,
|
||||
)
|
||||
path = get_uri_location(uri)
|
||||
return fs, path
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.2.2"
|
||||
version = "0.2.5"
|
||||
dependencies = [
|
||||
"pylance==0.6.5",
|
||||
"pylance==0.7.4",
|
||||
"ratelimiter",
|
||||
"retry",
|
||||
"tqdm",
|
||||
@@ -44,9 +44,11 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
@@ -54,3 +56,10 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio"
|
||||
]
|
||||
@@ -17,7 +17,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
def test_basic(tmp_path):
|
||||
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
|
||||
|
||||
def test_ingest_iterator(tmp_path):
|
||||
class PydanticSchema(LanceModel):
|
||||
vector: vector(2)
|
||||
vector: Vector(2)
|
||||
item: str
|
||||
price: float
|
||||
|
||||
@@ -136,13 +136,12 @@ def test_ingest_iterator(tmp_path):
|
||||
def run_tests(schema):
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
||||
|
||||
tbl.to_pandas()
|
||||
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
|
||||
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
|
||||
|
||||
tbl_len = len(tbl)
|
||||
tbl.add(make_batches())
|
||||
assert tbl_len == 50
|
||||
assert len(tbl) == tbl_len * 2
|
||||
assert len(tbl.list_versions()) == 3
|
||||
db.drop_database()
|
||||
|
||||
@@ -16,8 +16,12 @@ import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -54,8 +58,12 @@ def test_embedding_function(tmp_path):
|
||||
"vector": [np.random.randn(10), np.random.randn(10)],
|
||||
}
|
||||
)
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
metadata = registry.get_table_metadata([func])
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
table = table.replace_schema_metadata(metadata)
|
||||
|
||||
# Write it to disk
|
||||
@@ -65,14 +73,13 @@ def test_embedding_function(tmp_path):
|
||||
ds = lance.dataset(tmp_path / "test.lance")
|
||||
|
||||
# can we get the serialized version back out?
|
||||
functions = registry.parse_functions(ds.schema.metadata)
|
||||
configs = registry.parse_functions(ds.schema.metadata)
|
||||
|
||||
func = functions["vector"]
|
||||
actual = func("hello world")
|
||||
conf = configs["vector"]
|
||||
func = conf.function
|
||||
actual = func.compute_query_embeddings("hello world")
|
||||
|
||||
# We create an instance
|
||||
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
# And we make sure we can call it
|
||||
expected = expected_func("hello world")
|
||||
expected = func.compute_query_embeddings("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
125
python/tests/test_embeddings_slow.py
Normal file
125
python/tests/test_embeddings_slow.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import lancedb
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
# or connection to external api
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
def test_sentence_transformer(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get(alias).create()
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words)
|
||||
table.add(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"text": [
|
||||
"hello world",
|
||||
"goodbye world",
|
||||
"fizz",
|
||||
"buzz",
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
|
||||
vec = func.compute_query_embeddings(query)[0]
|
||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||
assert actual.text == expected.text
|
||||
assert actual.text == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get("open-clip").create()
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
# get each uri as bytes
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
table.add(
|
||||
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
|
||||
)
|
||||
|
||||
# text search
|
||||
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
|
||||
assert actual.label == "dog"
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == frombytes.label
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# image search
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(io.BytesIO(image_bytes))
|
||||
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
|
||||
assert actual.label == "dog"
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == other.label
|
||||
|
||||
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
|
||||
assert np.allclose(
|
||||
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
||||
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
||||
)
|
||||
@@ -19,8 +19,9 @@ from typing import List, Optional
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
|
||||
|
||||
def test_fixed_size_list_field():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: vector(16)
|
||||
vec: Vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
|
||||
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
|
||||
|
||||
def test_fixed_size_list_validation():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: vector(8)
|
||||
vec: Vector(8)
|
||||
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
TestModel(vec=range(9))
|
||||
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
|
||||
|
||||
def test_lance_model():
|
||||
class TestModel(LanceModel):
|
||||
vec: vector(16)
|
||||
li: List[int]
|
||||
vector: Vector(16) = Field(default=[0.0] * 16)
|
||||
li: List[int] = Field(default=[1, 2, 3])
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["vec", "li"]
|
||||
assert TestModel.field_names() == ["vector", "li"]
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
@@ -20,7 +20,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -67,7 +67,7 @@ def table(tmp_path) -> MockTable:
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(2)
|
||||
vector: Vector(2)
|
||||
id: int
|
||||
str_field: str
|
||||
float_field: float
|
||||
|
||||
@@ -22,9 +22,10 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -140,7 +141,7 @@ def test_add(db):
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(16)
|
||||
vector: Vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
|
||||
@@ -354,22 +355,25 @@ def test_update(db):
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
func = MockTextEmbeddingFunction()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func(texts)})
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text", vector_column="vector", function=func
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
embedding_functions=[conf],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
query_vector = func.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
@@ -377,17 +381,13 @@ def test_create_with_embedding_function(db):
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
@@ -397,7 +397,7 @@ def test_add_with_embedding_function(db):
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
query_vector = emb.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
@@ -407,8 +407,8 @@ def test_add_with_embedding_function(db):
|
||||
def test_multiple_vector_columns(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: vector(10)
|
||||
vector2: vector(10)
|
||||
vector1: Vector(10)
|
||||
vector2: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.2.4"
|
||||
version = "0.2.6"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
@@ -18,6 +18,7 @@ once_cell = "1"
|
||||
futures = "0.3"
|
||||
half = { workspace = true }
|
||||
lance = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
vectordb = { path = "../../vectordb" }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
|
||||
|
||||
@@ -28,7 +28,9 @@ fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
|
||||
record_batch
|
||||
.column_by_name(VECTOR_COLUMN_NAME)
|
||||
.map(|_| ())
|
||||
.context(MissingColumnSnafu { name: VECTOR_COLUMN_NAME })
|
||||
.context(MissingColumnSnafu {
|
||||
name: VECTOR_COLUMN_NAME,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::MetricType;
|
||||
use lance_linalg::distance::MetricType;
|
||||
use neon::context::FunctionContext;
|
||||
use neon::prelude::*;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
@@ -183,11 +183,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let aws_region = get_aws_region(&mut cx, 4)?;
|
||||
|
||||
let params = ReadParams {
|
||||
store_options: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
store_options: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
..ReadParams::default()
|
||||
};
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::ops::Deref;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use lance::index::vector::MetricType;
|
||||
use lance_linalg::distance::MetricType;
|
||||
use neon::context::FunctionContext;
|
||||
use neon::handle::Handle;
|
||||
use neon::prelude::*;
|
||||
|
||||
@@ -43,7 +43,8 @@ impl JsTable {
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let buffer = cx.argument::<JsBuffer>(1)?;
|
||||
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
let (batches, schema) =
|
||||
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
|
||||
// Write mode
|
||||
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
|
||||
@@ -65,11 +66,9 @@ impl JsTable {
|
||||
let aws_region = get_aws_region(&mut cx, 6)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
mode: mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
@@ -92,7 +91,8 @@ impl JsTable {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let buffer = cx.argument::<JsBuffer>(0)?;
|
||||
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
|
||||
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
let (batches, schema) =
|
||||
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let mut table = js_table.table.clone();
|
||||
@@ -108,11 +108,9 @@ impl JsTable {
|
||||
let aws_region = get_aws_region(&mut cx, 5)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
mode: write_mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.2.4"
|
||||
version = "0.2.6"
|
||||
edition = "2021"
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
@@ -10,14 +10,21 @@ categories = ["database-implementations"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
arrow = { workspace = true }
|
||||
arrow-array = { workspace = true }
|
||||
arrow-data = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
arrow-ord = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
half = { workspace = true }
|
||||
lance = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
log = { workspace = true }
|
||||
num-traits = "0"
|
||||
url = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.5.0"
|
||||
|
||||
15
rust/vectordb/src/arrow.rs
Normal file
15
rust/vectordb/src/arrow.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub use lance::arrow::*;
|
||||
18
rust/vectordb/src/data.rs
Normal file
18
rust/vectordb/src/data.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Data types, schema coercion, and data cleaning and etc.
|
||||
|
||||
pub mod inspect;
|
||||
pub mod sanitize;
|
||||
180
rust/vectordb/src/data/inspect.rs
Normal file
180
rust/vectordb/src/data/inspect.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use arrow::compute::kernels::{aggregate::bool_and, length::length};
|
||||
use arrow_array::{
|
||||
cast::AsArray,
|
||||
types::{ArrowPrimitiveType, Int32Type, Int64Type},
|
||||
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
|
||||
};
|
||||
use arrow_ord::comparison::eq_dyn_scalar;
|
||||
use arrow_schema::DataType;
|
||||
use num_traits::{ToPrimitive, Zero};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
pub(crate) fn infer_dimension<T: ArrowPrimitiveType>(
|
||||
list_arr: &GenericListArray<T::Native>,
|
||||
) -> Result<Option<T::Native>>
|
||||
where
|
||||
T::Native: OffsetSizeTrait + ToPrimitive,
|
||||
{
|
||||
let len_arr = length(list_arr)?;
|
||||
if len_arr.is_empty() {
|
||||
return Ok(Some(Zero::zero()));
|
||||
}
|
||||
|
||||
let dim = len_arr.as_primitive::<T>().value(0);
|
||||
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(dim))
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer the vector columns from a dataset.
|
||||
///
|
||||
/// Parameters
|
||||
/// ----------
|
||||
/// - reader: RecordBatchReader
|
||||
/// - strict: if set true, only fixed_size_list<float> is considered as vector column. If set to false,
|
||||
/// a list<float> column with same length is also considered as vector column.
|
||||
pub fn infer_vector_columns(
|
||||
reader: impl RecordBatchReader + Send,
|
||||
strict: bool,
|
||||
) -> Result<Vec<String>> {
|
||||
let mut columns = vec![];
|
||||
|
||||
let mut columns_to_infer: HashMap<String, Option<i64>> = HashMap::new();
|
||||
for field in reader.schema().fields() {
|
||||
match field.data_type() {
|
||||
DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => {
|
||||
columns.push(field.name().to_string());
|
||||
}
|
||||
DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => {
|
||||
columns_to_infer.insert(field.name().to_string(), None);
|
||||
}
|
||||
DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => {
|
||||
columns_to_infer.insert(field.name().to_string(), None);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for batch in reader {
|
||||
let batch = batch?;
|
||||
let col_names = columns_to_infer.keys().cloned().collect::<Vec<_>>();
|
||||
for col_name in col_names {
|
||||
let col = batch.column_by_name(&col_name).ok_or(Error::Schema {
|
||||
message: format!("Column {} not found", col_name),
|
||||
})?;
|
||||
if let Some(dim) = match *col.data_type() {
|
||||
DataType::List(_) => {
|
||||
infer_dimension::<Int32Type>(col.as_list::<i32>())?.map(|d| d as i64)
|
||||
}
|
||||
DataType::LargeList(_) => infer_dimension::<Int64Type>(col.as_list::<i64>())?,
|
||||
_ => {
|
||||
return Err(Error::Schema {
|
||||
message: format!("Column {} is not a list", col_name),
|
||||
})
|
||||
}
|
||||
} {
|
||||
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
|
||||
if prev_dim != &dim {
|
||||
columns_to_infer.remove(&col_name);
|
||||
}
|
||||
} else {
|
||||
columns_to_infer.insert(col_name, Some(dim));
|
||||
}
|
||||
} else {
|
||||
columns_to_infer.remove(&col_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
columns.extend(columns_to_infer.keys().cloned());
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use arrow_array::{
|
||||
types::{Float32Type, Float64Type},
|
||||
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use std::{sync::Arc, vec};
|
||||
|
||||
#[test]
|
||||
fn test_infer_vector_columns() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("f", DataType::Float32, false),
|
||||
Field::new("s", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"l1",
|
||||
DataType::List(Arc::new(Field::new("item", DataType::Float32, true))),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"l2",
|
||||
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"fl",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 32),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])),
|
||||
Arc::new(StringArray::from(vec!["a", "b", "c"])),
|
||||
Arc::new(ListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..3).map(|_| Some(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)])),
|
||||
)),
|
||||
// Var-length list
|
||||
Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
|
||||
Some(vec![Some(1.0_f64)]),
|
||||
Some(vec![Some(2.0_f64), Some(3.0_f64)]),
|
||||
Some(vec![Some(4.0_f64), Some(5.0_f64), Some(6.0_f64)]),
|
||||
])),
|
||||
Arc::new(
|
||||
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
vec![
|
||||
Some(vec![Some(1.0); 32]),
|
||||
Some(vec![Some(2.0); 32]),
|
||||
Some(vec![Some(3.0); 32]),
|
||||
],
|
||||
32,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader =
|
||||
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
|
||||
|
||||
let cols = infer_vector_columns(reader, false).unwrap();
|
||||
assert_eq!(cols, vec!["fl", "l1"]);
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
|
||||
let cols = infer_vector_columns(reader, true).unwrap();
|
||||
assert_eq!(cols, vec!["fl"]);
|
||||
}
|
||||
}
|
||||
284
rust/vectordb/src/data/sanitize.rs
Normal file
284
rust/vectordb/src/data/sanitize.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{iter::repeat_with, sync::Arc};
|
||||
|
||||
use arrow_array::{
|
||||
cast::AsArray,
|
||||
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
|
||||
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
};
|
||||
use arrow_cast::{can_cast_types, cast};
|
||||
use arrow_schema::{ArrowError, DataType, Field, Schema};
|
||||
use half::f16;
|
||||
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
|
||||
use log::warn;
|
||||
use num_traits::cast::AsPrimitive;
|
||||
|
||||
use super::inspect::infer_dimension;
|
||||
use crate::error::Result;
|
||||
|
||||
fn cast_array<I: ArrowNumericType, O: ArrowNumericType>(
|
||||
arr: &PrimitiveArray<I>,
|
||||
) -> Arc<PrimitiveArray<O>>
|
||||
where
|
||||
I::Native: AsPrimitive<O::Native>,
|
||||
{
|
||||
Arc::new(PrimitiveArray::<O>::from_iter_values(
|
||||
arr.values().iter().map(|v| (*v).as_()),
|
||||
))
|
||||
}
|
||||
|
||||
fn cast_float_array<I: ArrowNumericType>(
|
||||
arr: &PrimitiveArray<I>,
|
||||
dt: &DataType,
|
||||
) -> std::result::Result<Arc<dyn Array>, ArrowError>
|
||||
where
|
||||
I::Native: AsPrimitive<f64> + AsPrimitive<f32> + AsPrimitive<f16>,
|
||||
{
|
||||
match dt {
|
||||
DataType::Float16 => Ok(cast_array::<I, Float16Type>(arr)),
|
||||
DataType::Float32 => Ok(cast_array::<I, Float32Type>(arr)),
|
||||
DataType::Float64 => Ok(cast_array::<I, Float64Type>(arr)),
|
||||
_ => Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible change field: unable to coerce {:?} to {:?}",
|
||||
arr.data_type(),
|
||||
dt
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn coerce_array(
|
||||
array: &Arc<dyn Array>,
|
||||
field: &Field,
|
||||
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
|
||||
if array.data_type() == field.data_type() {
|
||||
return Ok(array.clone());
|
||||
}
|
||||
match (array.data_type(), field.data_type()) {
|
||||
// Normal cast-able types.
|
||||
(adt, dt) if can_cast_types(adt, dt) => cast(&array, dt),
|
||||
// Casting between f16/f32/f64 can be lossy.
|
||||
(adt, dt) if (adt.is_floating() || dt.is_floating()) => {
|
||||
if adt.byte_width() > dt.byte_width() {
|
||||
warn!(
|
||||
"Coercing field {} {:?} to {:?} might lose precision",
|
||||
field.name(),
|
||||
adt,
|
||||
dt
|
||||
);
|
||||
}
|
||||
match adt {
|
||||
DataType::Float16 => cast_float_array(array.as_primitive::<Float16Type>(), dt),
|
||||
DataType::Float32 => cast_float_array(array.as_primitive::<Float32Type>(), dt),
|
||||
DataType::Float64 => cast_float_array(array.as_primitive::<Float64Type>(), dt),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
(adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt {
|
||||
// Cast a float fixed size array with same dimension to the expected type.
|
||||
DataType::FixedSizeList(_, dim) if dim == exp_dim => {
|
||||
let actual_sub = array.as_fixed_size_list();
|
||||
let values = coerce_array(actual_sub.values(), exp_field)?;
|
||||
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
|
||||
values.clone(),
|
||||
*dim,
|
||||
)?) as Arc<dyn Array>)
|
||||
}
|
||||
DataType::List(_) | DataType::LargeList(_) => {
|
||||
let Some(dim) = (match adt {
|
||||
DataType::List(_) => infer_dimension::<Int32Type>(array.as_list::<i32>())
|
||||
.map_err(|e| {
|
||||
ArrowError::SchemaError(format!(
|
||||
"failed to infer dimension from list: {}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
.map(|d| d as i64),
|
||||
DataType::LargeList(_) => infer_dimension::<Int64Type>(array.as_list::<i64>())
|
||||
.map_err(|e| {
|
||||
ArrowError::SchemaError(format!(
|
||||
"failed to infer dimension from large list: {}",
|
||||
e
|
||||
))
|
||||
})?,
|
||||
_ => unreachable!(),
|
||||
}) else {
|
||||
return Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
|
||||
field,
|
||||
array.data_type()
|
||||
)));
|
||||
};
|
||||
|
||||
if dim != *exp_dim as i64 {
|
||||
return Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible coerce fixed size list: expected dimension {} but got {}",
|
||||
exp_dim, dim
|
||||
)));
|
||||
}
|
||||
|
||||
let values = coerce_array(array, exp_field)?;
|
||||
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
|
||||
values.clone(),
|
||||
*exp_dim,
|
||||
)?) as Arc<dyn Array>)
|
||||
}
|
||||
_ => Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
|
||||
field,
|
||||
array.data_type()
|
||||
)))?,
|
||||
},
|
||||
_ => Err(ArrowError::SchemaError(format!(
|
||||
"Incompatible change field {}: unable to coerce {:?} to {:?}",
|
||||
field.name(),
|
||||
array.data_type(),
|
||||
field.data_type()
|
||||
)))?,
|
||||
}
|
||||
}
|
||||
|
||||
fn coerce_schema_batch(
|
||||
batch: RecordBatch,
|
||||
schema: Arc<Schema>,
|
||||
) -> std::result::Result<RecordBatch, ArrowError> {
|
||||
if batch.schema() == schema {
|
||||
return Ok(batch);
|
||||
}
|
||||
let columns = schema
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|field| {
|
||||
batch
|
||||
.column_by_name(field.name())
|
||||
.ok_or_else(|| {
|
||||
ArrowError::SchemaError(format!("Column {} not found", field.name()))
|
||||
})
|
||||
.and_then(|c| coerce_array(c, field))
|
||||
})
|
||||
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
|
||||
RecordBatch::try_new(schema, columns)
|
||||
}
|
||||
|
||||
/// Coerce the reader (input data) to match the given [Schema].
|
||||
///
|
||||
pub fn coerce_schema(
|
||||
reader: impl RecordBatchReader + Send + 'static,
|
||||
schema: Arc<Schema>,
|
||||
) -> Result<Box<dyn RecordBatchReader + Send>> {
|
||||
if reader.schema() == schema {
|
||||
return Ok(Box::new(RecordBatchIterator::new(reader, schema)));
|
||||
}
|
||||
let s = schema.clone();
|
||||
let batches = reader
|
||||
.zip(repeat_with(move || s.clone()))
|
||||
.map(|(batch, s)| coerce_schema_batch(batch?, s));
|
||||
Ok(Box::new(RecordBatchIterator::new(batches, schema)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
|
||||
RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::Field;
|
||||
use half::f16;
|
||||
use lance::arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[test]
|
||||
fn test_coerce_list_to_fixed_size_list() {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new(
|
||||
"fl",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 64),
|
||||
true,
|
||||
),
|
||||
Field::new("s", DataType::Utf8, true),
|
||||
Field::new("f", DataType::Float16, true),
|
||||
Field::new("i", DataType::Int32, true),
|
||||
]));
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(
|
||||
FixedSizeListArray::try_new_from_values(
|
||||
Float32Array::from_iter_values((0..256).map(|v| v as f32)),
|
||||
64,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(StringArray::from(vec![
|
||||
Some("hello"),
|
||||
Some("world"),
|
||||
Some("from"),
|
||||
Some("lance"),
|
||||
])),
|
||||
Arc::new(Float16Array::from_iter_values(
|
||||
(0..4).map(|v| f16::from_f32(v as f32)),
|
||||
)),
|
||||
Arc::new(Int32Array::from_iter_values(0..4)),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader =
|
||||
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
|
||||
|
||||
let expected_schema = Arc::new(Schema::new(vec![
|
||||
Field::new(
|
||||
"fl",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 64),
|
||||
true,
|
||||
),
|
||||
Field::new("s", DataType::Utf8, true),
|
||||
Field::new("f", DataType::Float64, true),
|
||||
Field::new("i", DataType::Int8, true),
|
||||
]));
|
||||
let stream = coerce_schema(reader, expected_schema.clone()).unwrap();
|
||||
let batches = stream.collect::<Vec<_>>();
|
||||
assert_eq!(batches.len(), 1);
|
||||
let batch = batches[0].as_ref().unwrap();
|
||||
assert_eq!(batch.schema(), expected_schema);
|
||||
|
||||
let expected = RecordBatch::try_new(
|
||||
expected_schema,
|
||||
vec![
|
||||
Arc::new(
|
||||
FixedSizeListArray::try_new_from_values(
|
||||
Float16Array::from_iter_values((0..256).map(|v| f16::from_f32(v as f32))),
|
||||
64,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(StringArray::from(vec![
|
||||
Some("hello"),
|
||||
Some("world"),
|
||||
Some("from"),
|
||||
Some("lance"),
|
||||
])),
|
||||
Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))),
|
||||
Arc::new(Int8Array::from_iter_values(0..4)),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(batch, &expected);
|
||||
}
|
||||
}
|
||||
@@ -27,12 +27,14 @@ pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
|
||||
pub struct Database {
|
||||
object_store: ObjectStore,
|
||||
query_string: Option<String>,
|
||||
|
||||
pub(crate) uri: String,
|
||||
pub(crate) base_path: object_store::path::Path,
|
||||
}
|
||||
|
||||
const LANCE_EXTENSION: &str = "lance";
|
||||
const ENGINE: &str = "engine";
|
||||
|
||||
/// A connection to LanceDB
|
||||
impl Database {
|
||||
@@ -46,12 +48,73 @@ impl Database {
|
||||
///
|
||||
/// * A [Database] object.
|
||||
pub async fn connect(uri: &str) -> Result<Database> {
|
||||
let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
|
||||
let parse_res = url::Url::parse(uri);
|
||||
|
||||
match parse_res {
|
||||
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
|
||||
Ok(mut url) => {
|
||||
// iter thru the query params and extract the commit store param
|
||||
let mut engine = None;
|
||||
let mut filtered_querys = vec![];
|
||||
|
||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||
// THE API WILL CHANGE
|
||||
for (key, value) in url.query_pairs() {
|
||||
if key == ENGINE {
|
||||
engine = Some(value.to_string());
|
||||
} else {
|
||||
// to owned so we can modify the url
|
||||
filtered_querys.push((key.to_string(), value.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out the commit store query param -- it's a lancedb param
|
||||
url.query_pairs_mut().clear();
|
||||
url.query_pairs_mut().extend_pairs(filtered_querys);
|
||||
// Take a copy of the query string so we can propagate it to lance
|
||||
let query_string = url.query().map(|s| s.to_string());
|
||||
// clear the query string so we can use the url as the base uri
|
||||
// use .set_query(None) instead of .set_query("") because the latter
|
||||
// will add a trailing '?' to the url
|
||||
url.set_query(None);
|
||||
|
||||
let table_base_uri = if let Some(store) = engine {
|
||||
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
|
||||
WARN_ONCE.call_once(|| {
|
||||
log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
|
||||
});
|
||||
let old_scheme = url.scheme().to_string();
|
||||
let new_scheme = format!("{}+{}", old_scheme, store);
|
||||
url.to_string().replacen(&old_scheme, &new_scheme, 1)
|
||||
} else {
|
||||
url.to_string()
|
||||
};
|
||||
|
||||
let plain_uri = url.to_string();
|
||||
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
||||
}
|
||||
|
||||
Ok(Database {
|
||||
uri: table_base_uri,
|
||||
query_string,
|
||||
base_path,
|
||||
object_store,
|
||||
})
|
||||
}
|
||||
Err(_) => Self::open_path(uri).await,
|
||||
}
|
||||
Ok(Database {
|
||||
uri: uri.to_string(),
|
||||
}
|
||||
|
||||
async fn open_path(path: &str) -> Result<Database> {
|
||||
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
|
||||
}
|
||||
Ok(Self {
|
||||
uri: path.to_string(),
|
||||
query_string: None,
|
||||
base_path,
|
||||
object_store,
|
||||
})
|
||||
@@ -149,17 +212,26 @@ impl Database {
|
||||
let path = Path::new(&self.uri);
|
||||
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
|
||||
let uri = table_uri
|
||||
let mut uri = table_uri
|
||||
.as_path()
|
||||
.to_str()
|
||||
.context(InvalidTableNameSnafu { name })?;
|
||||
Ok(uri.to_string())
|
||||
.context(InvalidTableNameSnafu { name })?
|
||||
.to_string();
|
||||
|
||||
// If there are query string set on the connection, propagate to lance
|
||||
if let Some(query) = self.query_string.as_ref() {
|
||||
uri.push('?');
|
||||
uri.push_str(query.as_str());
|
||||
}
|
||||
|
||||
Ok(uri)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs::create_dir_all;
|
||||
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::database::Database;
|
||||
@@ -173,6 +245,28 @@ mod tests {
|
||||
assert_eq!(db.uri, uri);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn test_connect_relative() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
|
||||
|
||||
let mut relative_anacestors = vec![];
|
||||
let current_dir = std::env::current_dir().unwrap();
|
||||
let mut ancestors = current_dir.ancestors();
|
||||
while let Some(_) = ancestors.next() {
|
||||
relative_anacestors.push("..");
|
||||
}
|
||||
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
|
||||
let relative_uri = relative_root.join(&uri);
|
||||
|
||||
let db = Database::connect(relative_uri.to_str().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow_schema::ArrowError;
|
||||
use snafu::Snafu;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
@@ -32,10 +33,20 @@ pub enum Error {
|
||||
Store { message: String },
|
||||
#[snafu(display("LanceDBError: {message}"))]
|
||||
Lance { message: String },
|
||||
#[snafu(display("LanceDB Schema Error: {message}"))]
|
||||
Schema { message: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
impl From<ArrowError> for Error {
|
||||
fn from(e: ArrowError) -> Self {
|
||||
Self::Lance {
|
||||
message: e.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lance::Error> for Error {
|
||||
fn from(e: lance::Error) -> Self {
|
||||
Self::Lance {
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::{MetricType, VectorIndexParams};
|
||||
use lance::index::vector::VectorIndexParams;
|
||||
use lance_linalg::distance::MetricType;
|
||||
|
||||
pub trait VectorIndexBuilder {
|
||||
fn get_column(&self) -> Option<String>;
|
||||
@@ -107,9 +108,11 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::{MetricType, StageParams};
|
||||
use lance::index::vector::StageParams;
|
||||
|
||||
use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod data;
|
||||
pub mod database;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::sync::Arc;
|
||||
use arrow_array::Float32Array;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
use lance_linalg::distance::MetricType;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
@@ -164,10 +164,10 @@ impl Query {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
|
||||
use crate::query::Query;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user