Compare commits

..

1 Commits

Author SHA1 Message Date
Chang She
9441fde2bb reproducibility 2023-09-06 02:24:19 -07:00
49 changed files with 1533 additions and 1708 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.6
current_version = 0.2.4
commit = True
message = Bump version: {current_version} → {new_version}
tag = True

View File

@@ -9,7 +9,6 @@ on:
- node/**
- rust/ffi/node/**
- .github/workflows/node.yml
- docker-compose.yml
env:
# Disable full debug symbol generation to speed up CI build and keep memory down
@@ -108,56 +107,3 @@ jobs:
- name: Test
run: |
npm run test
aws-integtest:
timeout-minutes: 45
runs-on: "ubuntu-22.04"
defaults:
run:
shell: bash
working-directory: node
env:
AWS_ACCESS_KEY_ID: ACCESSKEY
AWS_SECRET_ACCESS_KEY: SECRETKEY
AWS_DEFAULT_REGION: us-west-2
# this one is for s3
AWS_ENDPOINT: http://localhost:4566
# this one is for dynamodb
DYNAMODB_ENDPOINT: http://localhost:4566
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
with:
node-version: 18
cache: 'npm'
cache-dependency-path: node/package-lock.json
- name: start local stack
run: docker compose -f ../docker-compose.yml up -d --wait
- name: create s3
run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
- name: create ddb
run: |
aws dynamodb create-table \
--table-name lancedb-integtest \
--attribute-definitions '[{"AttributeName": "base_uri", "AttributeType": "S"}, {"AttributeName": "version", "AttributeType": "N"}]' \
--key-schema '[{"AttributeName": "base_uri", "KeyType": "HASH"}, {"AttributeName": "version", "KeyType": "RANGE"}]' \
--provisioned-throughput '{"ReadCapacityUnits": 10, "WriteCapacityUnits": 10}' \
--endpoint-url $DYNAMODB_ENDPOINT
- uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Build
run: |
npm ci
npm run tsc
npm run build
npm run pack-build
npm install --no-save ./dist/lancedb-vectordb-*.tgz
# Remove index.node to test with dependency installed
rm index.node
- name: Test
run: npm run integration-test

View File

@@ -38,7 +38,7 @@ jobs:
- name: isort
run: isort --check --diff --quiet .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
run: pytest -x -v --durations=30 tests
- name: doctest
run: pytest --doctest-modules lancedb
mac:
@@ -65,34 +65,4 @@ jobs:
- name: Black
run: black --check --diff --no-color --quiet .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
pydantic1x:
timeout-minutes: 30
runs-on: "ubuntu-22.04"
defaults:
run:
shell: bash
working-directory: python
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install lancedb
run: |
pip install "pydantic<2"
pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock black isort
- name: Black
run: black --check --diff --no-color --quiet .
- name: isort
run: isort --check --diff --quiet .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest
run: pytest --doctest-modules lancedb
run: pytest -x -v --durations=30 tests

View File

@@ -1,25 +1,16 @@
[workspace]
members = ["rust/ffi/node", "rust/vectordb"]
# Python package needs to be built by maturin.
exclude = ["python"]
members = [
"rust/vectordb",
"rust/ffi/node"
]
resolver = "2"
[workspace.dependencies]
lance = { "version" = "=0.7.5", "features" = ["dynamodb"] }
lance-linalg = { "version" = "=0.7.5" }
# Note that this one does not include pyarrow
arrow = { version = "43.0.0", optional = false }
lance = "=0.6.5"
arrow-array = "43.0"
arrow-data = "43.0"
arrow-ipc = "43.0"
arrow-ord = "43.0"
arrow-schema = "43.0"
arrow-arith = "43.0"
arrow-cast = "43.0"
half = { "version" = "=2.2.1", default-features = false, features = [
"num-traits"
] }
log = "0.4"
arrow-ipc = "43.0"
half = { "version" = "=2.2.1", default-features = false }
object_store = "0.6.1"
snafu = "0.7.4"
url = "2"

View File

@@ -1,18 +0,0 @@
version: "3.9"
services:
localstack:
image: localstack/localstack:0.14
ports:
- 4566:4566
environment:
- SERVICES=s3,dynamodb
- DEBUG=1
- LS_LOG=trace
- DOCKER_HOST=unix:///var/run/docker.sock
- AWS_ACCESS_KEY_ID=ACCESSKEY
- AWS_SECRET_ACCESS_KEY=SECRETKEY
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
interval: 5s
retries: 3
start_period: 10s

View File

@@ -49,11 +49,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
db.create_table("table2", data)
db["table2"].head()
db["table2"].head()
```
!!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
```python
custom_schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -66,7 +66,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
### From PyArrow Tables
You can also create LanceDB tables directly from pyarrow tables
```python
table = pa.Table.from_arrays(
[
@@ -84,28 +84,18 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
```
### From Pydantic Models
When you create an empty table without data, you must specify the table schema.
LanceDB supports creating tables by specifying a pyarrow schema or a specialized
pydantic model called `LanceModel`.
For example, the following Content model specifies a table with 5 columns:
movie_id, vector, genres, title, and imdb_id. When you create a table, you can
pass the class as the value of the `schema` parameter to `create_table`.
The `vector` column is a `Vector` type, which is a specialized pydantic type that
can be configured with the vector dimensions. It is also important to note that
LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
(which itself derives from `pydantic.BaseModel`).
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method.
```python
from lancedb.pydantic import Vector, LanceModel
from lancedb.pydantic import vector, LanceModel
class Content(LanceModel):
movie_id: int
vector: Vector(128)
vector: vector(128)
genres: str
title: str
imdb_id: int
@property
def imdb_url(self) -> str:
return f"https://www.imdb.com/title/tt{self.imdb_id}"
@@ -113,7 +103,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
import pyarrow as pa
db = lancedb.connect("~/.lancedb")
table_name = "movielens_small"
table = db.create_table(table_name, schema=Content)
table = db.create_table(table_name, schema=Content.to_arrow_schema())
```
### Using Iterators / Writing Large Datasets
@@ -123,7 +113,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types.
Here's an example using using `RecordBatch` iterator for creating tables.
```python
import pyarrow as pa
@@ -152,11 +142,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
## Creating Empty Table
You can also create empty tables in python. Initialize it with schema and later ingest data into it.
```python
import lancedb
import pyarrow as pa
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -178,8 +168,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
from lancedb.pydantic import LanceModel, vector
class Model(LanceModel):
vector: Vector(2)
vector: vector(2)
tbl = db.create_table("table5", schema=Model.to_arrow_schema())
```
@@ -259,7 +249,7 @@ After a table has been created, you can always add more data to it using
You can also add a large dataset batch in one go using Iterator of any supported data types.
### Adding to table using Iterator
```python
import pandas as pd
@@ -271,10 +261,10 @@ After a table has been created, you can always add more data to it using
"item": ["foo", "bar"],
"price": [10.0, 20.0],
})
tbl.add(make_batches())
```
The other arguments accepted:
| Name | Type | Description | Default |
@@ -284,7 +274,7 @@ After a table has been created, you can always add more data to it using
| on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop |
| fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 |
=== "Javascript/Typescript"
```javascript
@@ -322,7 +312,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector
# 0 1 [1.0, 2.0]
# 1 3 [5.0, 6.0]
```
```
### Delete from a list of values
@@ -335,7 +325,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector
# 0 3 [5.0, 6.0]
```
=== "Javascript/Typescript"
```javascript

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
id,quote,author
1,"Nobody exists on purpose. Nobody belongs anywhere.",Morty
2,"We're all going to die. Come watch TV.",Morty
3,"Losers look stuff up while the rest of us are carpin' all them diems.",Summer
4,"He's not a hot girl. He can't just bail on his life and set up shop in someone else's.",Beth
5,"When you are an a—hole, it doesn't matter how right you are. Nobody wants to give you the satisfaction.",Morty
6,"God's turning people into insect monsters, Beth. I'm the one beating them to death. Thank me.",Jerry
7,"Camping is just being homeless without the change.",Summer
8,"This seems like a good time for a drink and a cold, calculated speech with sinister overtones. A speech about politics, about order, brotherhood, power ... but speeches are for campaigning. Now is the time for action.",Morty
9,"Having a family doesn't mean that you stop being an individual. You know the best thing you can do for the people that depend on you? Be honest with them, even if it means setting them free.",Mr. Meeseeks
10,"If I've learned one thing, it's that before you get anywhere in life, you gotta stop listening to yourself.",Jerry
11,"I just want to go back to Hell, where everyone thinks I'm smart and funny.",Mr. Needful
12,"Hi Mr. Jellybean, I'm Morty. Im on an adventure with my grandpa.",Morty
13,"You're not the cause of your parents' misery. You're just a symptom of it.",Summer
14,"Don't deify the people who leave you.",Beth
15,"Well, then get your s—t together, get it all together, and put it in a backpack, all your s—t, so it's together. And if you gotta take it somewhere, take it somewhere, you know, take it to the s—t store and sell it, or put it in the s—t museum. I don't care what you do, you just gotta get it together. Get your s—t together.",Morty
16,"At least the devil has a job!",Summer
17,"Life is effort and I'll stop when I die!",Jerry
18,"I just killed my family! I don't care what they were!",Morty
19,"It's funny to say they are small. It's funny to say they are big.",Shrimply Pibbles
20,"You're holding me verbally hostage.",Summer
21,"Honey, stop raising your father's cholesterol so you can take a hot funeral selfie.",Beth
22,"Rick, when you say you made an exact replica of the house, did you mean, like, an exact replica?",Morty
23,"Give a gun to the lady who got pregnant with me too early and constantly makes it our problem.",Summer
24,"Say goodbye to your precious dry land! For soon it will be wet!",Mr. Nimbus
25,"Nobody's smarter than Rick, but nobody else is my dad. You're a genius at that.",Morty
26,"B—h, my generation gets traumatized for breakfast.",Summer
27,"Inception made sense!",Morty
28,"I realize now I'm attracted to you for the same reason I can't be with you: You can't change. And I have no problem with that, but it clearly means I have a problem with myself.",Unity
29,"Mr. President, if I've learned one thing today, it's that sometimes you have to not give a f—k!",Morty
30,"I didn't know freedom meant people doing stuff that sucks.",Summer
31,"How many of these are just horrible mistakes I made? I mean, maybe I'd stop making so many if I let myself learn from them.",Morty
32,"I'm a scientist because I invent, transform, create, and destroy for a living. And when I don't like something about the world, I change it.",Rick
33,"Wubba lubba dub dub!",Rick
34,"I turned myself into a pickle, Morty! I'm Pickle Rick!",Rick
35,"I know about the Yosemite T-shirt, Morty.",Rick
36,"The universe is basically an animal. It grazes on the ordinary. It creates infinite idiots just to eat them.",Rick
37,"If I die in a cage, I lose a bet.",Rick
38,"Sometimes science is more art than science.",Rick
39,"To live is to risk it all—otherwise, you're just an inert chunk of randomly assembled molecules drifting wherever the universe blows you.",Rick
40,"Welcome to the club, pal.",Rick
41,"So I have an emo streak. It's part of what makes me so rad.",Rick
42,"Listen, I'm not the nicest guy in the universe, because I'm the smartest, and being nice is something stupid people do to hedge their bets.",Rick
43,"Wait a minute! Is that Mountain Dew in my quantum-transport-solution?",Rick
44,"Listen, Morty, I hate to break it to you, but what people call 'love' is just a chemical reaction that compels animals to breed.",Rick
45,"Break the cycle, Morty. Rise above. Focus on science.",Rick
46,"Don't get drawn into the culture, Morty. Stealing stuff is about the stuff, not the stealing.",Rick
47,"I'm sorry, but your opinion means very little to me.",Rick
48,"You don't get to tell anyone what's sad. Youre like a one-man Mount Sadmore. So I guess like a Lincoln Sadmorial.",Rick
49,"This pickle doesn't care about your children. I'm not gonna take their dreams. I'm gonna take their parents.",Rick
50,"I programmed you to believe that.",Rick
51,"Have fun with empowerment. It seems to make everyone that gets it really happy.",Rick
52,"Thanks, Mr. Poopybutthole. I always could count on you.",Rick
53,"Weddings are basically funerals with a cake.",Rick
54,"I mean, if you spend all day shuffling words around, you can make anything sound bad, Morty.",Rick
55,"It's your choice to take this personally.",Rick
56,"Excuse me, coming through. What are you here for? Just kidding, I don't care.",Rick
57,"If I let you make me nervous, then we can't get schwifty.",Rick
58,"Oh, boy, so you actually learned something today? What is this, Full House?",Rick
59,"I can't abide bureaucracy. I don't like being told where to go and what to do. I consider it a violation. Did you get those seeds all the way up your butt?",Rick
60,"I think you have to think ahead and live in the moment.",Rick
61,"I know that new situations can be intimidating. You're lookin' around and it's all scary and different, but you know, meeting them head-on, charging into 'em like a bull—that's how we grow as people.",Rick
1 id quote author
2 1 Nobody exists on purpose. Nobody belongs anywhere. Morty
3 2 We're all going to die. Come watch TV. Morty
4 3 Losers look stuff up while the rest of us are carpin' all them diems. Summer
5 4 He's not a hot girl. He can't just bail on his life and set up shop in someone else's. Beth
6 5 When you are an a—hole, it doesn't matter how right you are. Nobody wants to give you the satisfaction. Morty
7 6 God's turning people into insect monsters, Beth. I'm the one beating them to death. Thank me. Jerry
8 7 Camping is just being homeless without the change. Summer
9 8 This seems like a good time for a drink and a cold, calculated speech with sinister overtones. A speech about politics, about order, brotherhood, power ... but speeches are for campaigning. Now is the time for action. Morty
10 9 Having a family doesn't mean that you stop being an individual. You know the best thing you can do for the people that depend on you? Be honest with them, even if it means setting them free. Mr. Meeseeks
11 10 If I've learned one thing, it's that before you get anywhere in life, you gotta stop listening to yourself. Jerry
12 11 I just want to go back to Hell, where everyone thinks I'm smart and funny. Mr. Needful
13 12 Hi Mr. Jellybean, I'm Morty. I’m on an adventure with my grandpa. Morty
14 13 You're not the cause of your parents' misery. You're just a symptom of it. Summer
15 14 Don't deify the people who leave you. Beth
16 15 Well, then get your s—t together, get it all together, and put it in a backpack, all your s—t, so it's together. And if you gotta take it somewhere, take it somewhere, you know, take it to the s—t store and sell it, or put it in the s—t museum. I don't care what you do, you just gotta get it together. Get your s—t together. Morty
17 16 At least the devil has a job! Summer
18 17 Life is effort and I'll stop when I die! Jerry
19 18 I just killed my family! I don't care what they were! Morty
20 19 It's funny to say they are small. It's funny to say they are big. Shrimply Pibbles
21 20 You're holding me verbally hostage. Summer
22 21 Honey, stop raising your father's cholesterol so you can take a hot funeral selfie. Beth
23 22 Rick, when you say you made an exact replica of the house, did you mean, like, an exact replica? Morty
24 23 Give a gun to the lady who got pregnant with me too early and constantly makes it our problem. Summer
25 24 Say goodbye to your precious dry land! For soon it will be wet! Mr. Nimbus
26 25 Nobody's smarter than Rick, but nobody else is my dad. You're a genius at that. Morty
27 26 B—h, my generation gets traumatized for breakfast. Summer
28 27 Inception made sense! Morty
29 28 I realize now I'm attracted to you for the same reason I can't be with you: You can't change. And I have no problem with that, but it clearly means I have a problem with myself. Unity
30 29 Mr. President, if I've learned one thing today, it's that sometimes you have to not give a f—k! Morty
31 30 I didn't know freedom meant people doing stuff that sucks. Summer
32 31 How many of these are just horrible mistakes I made? I mean, maybe I'd stop making so many if I let myself learn from them. Morty
33 32 I'm a scientist because I invent, transform, create, and destroy for a living. And when I don't like something about the world, I change it. Rick
34 33 Wubba lubba dub dub! Rick
35 34 I turned myself into a pickle, Morty! I'm Pickle Rick! Rick
36 35 I know about the Yosemite T-shirt, Morty. Rick
37 36 The universe is basically an animal. It grazes on the ordinary. It creates infinite idiots just to eat them. Rick
38 37 If I die in a cage, I lose a bet. Rick
39 38 Sometimes science is more art than science. Rick
40 39 To live is to risk it all—otherwise, you're just an inert chunk of randomly assembled molecules drifting wherever the universe blows you. Rick
41 40 Welcome to the club, pal. Rick
42 41 So I have an emo streak. It's part of what makes me so rad. Rick
43 42 Listen, I'm not the nicest guy in the universe, because I'm the smartest, and being nice is something stupid people do to hedge their bets. Rick
44 43 Wait a minute! Is that Mountain Dew in my quantum-transport-solution? Rick
45 44 Listen, Morty, I hate to break it to you, but what people call 'love' is just a chemical reaction that compels animals to breed. Rick
46 45 Break the cycle, Morty. Rise above. Focus on science. Rick
47 46 Don't get drawn into the culture, Morty. Stealing stuff is about the stuff, not the stealing. Rick
48 47 I'm sorry, but your opinion means very little to me. Rick
49 48 You don't get to tell anyone what's sad. You’re like a one-man Mount Sadmore. So I guess like a Lincoln Sadmorial. Rick
50 49 This pickle doesn't care about your children. I'm not gonna take their dreams. I'm gonna take their parents. Rick
51 50 I programmed you to believe that. Rick
52 51 Have fun with empowerment. It seems to make everyone that gets it really happy. Rick
53 52 Thanks, Mr. Poopybutthole. I always could count on you. Rick
54 53 Weddings are basically funerals with a cake. Rick
55 54 I mean, if you spend all day shuffling words around, you can make anything sound bad, Morty. Rick
56 55 It's your choice to take this personally. Rick
57 56 Excuse me, coming through. What are you here for? Just kidding, I don't care. Rick
58 57 If I let you make me nervous, then we can't get schwifty. Rick
59 58 Oh, boy, so you actually learned something today? What is this, Full House? Rick
60 59 I can't abide bureaucracy. I don't like being told where to go and what to do. I consider it a violation. Did you get those seeds all the way up your butt? Rick
61 60 I think you have to think ahead and live in the moment. Rick
62 61 I know that new situations can be intimidating. You're lookin' around and it's all scary and different, but you know, meeting them head-on, charging into 'em like a bull—that's how we grow as people. Rick

View File

@@ -249,11 +249,11 @@
}
],
"source": [
"from lancedb.pydantic import Vector, LanceModel\n",
"from lancedb.pydantic import vector, LanceModel\n",
"\n",
"class Content(LanceModel):\n",
" movie_id: int\n",
" vector: Vector(128)\n",
" vector: vector(128)\n",
" genres: str\n",
" title: str\n",
" imdb_id: int\n",
@@ -359,7 +359,7 @@
"import pandas as pd\n",
"\n",
"class PydanticSchema(LanceModel):\n",
" vector: Vector(2)\n",
" vector: vector(2)\n",
" item: str\n",
" price: float\n",
"\n",
@@ -394,10 +394,10 @@
"outputs": [],
"source": [
"import lancedb\n",
"from lancedb.pydantic import LanceModel, Vector\n",
"from lancedb.pydantic import LanceModel, vector\n",
"\n",
"class Model(LanceModel):\n",
" vector: Vector(2)\n",
" vector: vector(2)\n",
"\n",
"tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())"
]

View File

@@ -13,10 +13,10 @@ via [pydantic_to_schema()](python.md##lancedb.pydantic.pydantic_to_schema) metho
## Vector Field
LanceDB provides a [`Vector(dim)`](python.md#lancedb.pydantic.Vector) method to define a
LanceDB provides a [`vector(dim)`](python.md#lancedb.pydantic.vector) method to define a
vector Field in a Pydantic Model.
::: lancedb.pydantic.Vector
::: lancedb.pydantic.vector
## Type Conversion
@@ -33,4 +33,4 @@ Current supported type conversions:
| `str` | `pyarrow.utf8()` |
| `list` | `pyarrow.List` |
| `BaseModel` | `pyarrow.Struct` |
| `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
| `vector(n)` | `pyarrow.FixedSizeList(float32, n)` |

View File

@@ -26,19 +26,15 @@ pip install lancedb
## Embeddings
::: lancedb.embeddings.with_embeddings
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
::: lancedb.embeddings.functions.EmbeddingFunction
::: lancedb.embeddings.functions.EmbeddingFunctionModel
::: lancedb.embeddings.functions.TextEmbeddingFunction
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
::: lancedb.embeddings.functions.OpenAIEmbeddings
::: lancedb.embeddings.functions.OpenClipEmbeddings
::: lancedb.embeddings.with_embeddings
::: lancedb.embeddings.functions.TextEmbeddingFunctionModel
::: lancedb.embeddings.functions.SentenceTransformerEmbeddingFunction
## Context

105
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.2.6",
"version": "0.2.4",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.2.6",
"version": "0.2.4",
"cpu": [
"x64",
"arm64"
@@ -31,7 +31,6 @@
"@types/node": "^18.16.2",
"@types/sinon": "^10.0.15",
"@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1",
"cargo-cp-artifact": "^0.1",
"chai": "^4.3.7",
@@ -49,15 +48,14 @@
"ts-node-dev": "^2.0.0",
"typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3",
"typescript": "*",
"uuid": "^9.0.0"
"typescript": "*"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.2.6",
"@lancedb/vectordb-darwin-x64": "0.2.6",
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
"@lancedb/vectordb-darwin-arm64": "0.2.4",
"@lancedb/vectordb-darwin-x64": "0.2.4",
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4",
"@lancedb/vectordb-linux-x64-gnu": "0.2.4",
"@lancedb/vectordb-win32-x64-msvc": "0.2.4"
}
},
"node_modules/@apache-arrow/ts": {
@@ -317,9 +315,9 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz",
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==",
"cpu": [
"arm64"
],
@@ -329,9 +327,9 @@
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz",
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==",
"cpu": [
"x64"
],
@@ -341,9 +339,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz",
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==",
"cpu": [
"arm64"
],
@@ -353,9 +351,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz",
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==",
"cpu": [
"x64"
],
@@ -365,9 +363,9 @@
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz",
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==",
"cpu": [
"x64"
],
@@ -598,12 +596,6 @@
"@types/node": "*"
}
},
"node_modules/@types/uuid": {
"version": "9.0.3",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
"dev": true
},
"node_modules/@typescript-eslint/eslint-plugin": {
"version": "5.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
@@ -4459,15 +4451,6 @@
"punycode": "^2.1.0"
}
},
"node_modules/uuid": {
"version": "9.0.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
"dev": true,
"bin": {
"uuid": "dist/bin/uuid"
}
},
"node_modules/v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
@@ -4869,33 +4852,33 @@
}
},
"@lancedb/vectordb-darwin-arm64": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.4.tgz",
"integrity": "sha512-MqiZXamHYEOfguPsHWLBQ56IabIN6Az8u2Hx8LCyXcxW9gcyJZMSAfJc+CcA4KYHKotv0KsVBhgxZ3kaZQQyiw==",
"optional": true
},
"@lancedb/vectordb-darwin-x64": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.4.tgz",
"integrity": "sha512-DzL+mw5WhKDwXdEFlPh8M9zSDhGnfks7NvEh6ZqKbU6znH206YB7g3OA4WfFyV579IIEQ8jd4v/XDthNzQKuSA==",
"optional": true
},
"@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.4.tgz",
"integrity": "sha512-LP1nNfIpFxCgcCMlIQdseDX9dZU27TNhCL41xar8euqcetY5uKvi0YqhiVlpNO85Ss1FRQBgQ/GtnOM6Bo7oBQ==",
"optional": true
},
"@lancedb/vectordb-linux-x64-gnu": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.4.tgz",
"integrity": "sha512-m4RhOI5JJWPU9Ip2LlRIzXu4mwIv9M//OyAuTLiLKRm8726jQHhYi5VFUEtNzqY0o0p6pS0b3XbifYQ+cyJn3Q==",
"optional": true
},
"@lancedb/vectordb-win32-x64-msvc": {
"version": "0.2.6",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
"version": "0.2.4",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.4.tgz",
"integrity": "sha512-lMF/2e3YkKWnTYv0R7cUCfjMkAqepNaHSc/dvJzCNsFVEhfDsFdScQFLToARs5GGxnq4fOf+MKpaHg/W6QTxiA==",
"optional": true
},
"@neon-rs/cli": {
@@ -5110,12 +5093,6 @@
"@types/node": "*"
}
},
"@types/uuid": {
"version": "9.0.3",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.3.tgz",
"integrity": "sha512-taHQQH/3ZyI3zP8M/puluDEIEvtQHVYcC6y3N8ijFtAd28+Ey/G4sg1u2gB01S8MwybLOKAp9/yCMu/uR5l3Ug==",
"dev": true
},
"@typescript-eslint/eslint-plugin": {
"version": "5.59.1",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz",
@@ -7867,12 +7844,6 @@
"punycode": "^2.1.0"
}
},
"uuid": {
"version": "9.0.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.0.tgz",
"integrity": "sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==",
"dev": true
},
"v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.2.6",
"version": "0.2.4",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
@@ -9,7 +9,6 @@
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
"build-release": "npm run build -- --release",
"test": "npm run tsc && mocha -recursive dist/test",
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
"lint": "eslint native.js src --ext .js,.ts",
"clean": "rm -rf node_modules *.node dist/",
"pack-build": "neon pack-build",
@@ -35,7 +34,6 @@
"@types/node": "^18.16.2",
"@types/sinon": "^10.0.15",
"@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1",
"cargo-cp-artifact": "^0.1",
"chai": "^4.3.7",
@@ -53,8 +51,7 @@
"ts-node-dev": "^2.0.0",
"typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3",
"typescript": "*",
"uuid": "^9.0.0"
"typescript": "*"
},
"dependencies": {
"@apache-arrow/ts": "^12.0.0",
@@ -81,10 +78,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.2.6",
"@lancedb/vectordb-darwin-x64": "0.2.6",
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
"@lancedb/vectordb-darwin-arm64": "0.2.4",
"@lancedb/vectordb-darwin-x64": "0.2.4",
"@lancedb/vectordb-linux-arm64-gnu": "0.2.4",
"@lancedb/vectordb-linux-x64-gnu": "0.2.4",
"@lancedb/vectordb-win32-x64-msvc": "0.2.4"
}
}

View File

@@ -1,43 +0,0 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import { v4 as uuidv4 } from 'uuid'
import * as lancedb from '../index'
const assert = chai.assert
chai.use(chaiAsPromised)
describe('LanceDB AWS Integration test', function () {
it('s3+ddb schema is processed correctly', async function () {
this.timeout(15000)
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
// THE API WILL CHANGE
const conn = await lancedb.connect('s3://lancedb-integtest?engine=ddb&ddbTableName=lancedb-integtest')
const data = [{ vector: Array(128).fill(1.0) }]
const tableName = uuidv4()
let table = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
const futs = [table.add(data), table.add(data), table.add(data), table.add(data), table.add(data)]
await Promise.allSettled(futs)
table = await conn.openTable(tableName)
assert.equal(await table.countRows(), 6)
})
})

View File

@@ -19,7 +19,7 @@ import * as chaiAsPromised from 'chai-as-promised'
import * as lancedb from '../index'
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow'
const expect = chai.expect
const assert = chai.assert
@@ -258,36 +258,6 @@ describe('LanceDB client', function () {
})
})
describe('when searching an empty dataset', function () {
it('should not fail', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
const result = await table.search(Array(128).fill(0.1)).execute()
assert.isEmpty(result)
})
})
describe('when searching an empty-after-delete dataset', function () {
it('should not fail', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const schema = new Schema(
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
)
const table = await con.createTable({ name: 'vectors', schema })
await table.add([{ vector: Array(128).fill(0.1) }])
await table.delete('vector IS NOT NULL')
const result = await table.search(Array(128).fill(0.1)).execute()
assert.isEmpty(result)
})
})
describe('when creating a vector index', function () {
it('overwrite all records in a table', async function () {
const uri = await createTestDB(32, 300)

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.5
current_version = 0.2.2
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -11,15 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
from typing import Optional
from .db import URI, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector
__version__ = importlib.metadata.version("lancedb")
def connect(
uri: URI,

View File

@@ -1,9 +1,9 @@
import os
import numpy as np
import pyarrow as pa
import pytest
from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
# import lancedb so we don't have to in every example
@@ -22,19 +22,17 @@ def doctest_setup(monkeypatch, tmpdir):
registry = EmbeddingFunctionRegistry.get_instance()
@registry.register("test")
class MockTextEmbeddingFunction(TextEmbeddingFunction):
"""
Return the hash of the first 10 characters
"""
@registry.register()
class MockEmbeddingFunction(EmbeddingFunctionModel):
def __call__(self, data):
if isinstance(data, str):
data = [data]
elif isinstance(data, pa.ChunkedArray):
data = data.combine_chunks().to_pylist()
elif isinstance(data, pa.Array):
data = data.to_pylist()
def generate_embeddings(self, texts):
return [self._compute_one_embedding(row) for row in texts]
return [self.embed(row) for row in data]
def _compute_one_embedding(self, row):
emb = np.array([float(hash(c)) for c in row[:10]])
emb /= np.linalg.norm(emb)
return emb
def ndims(self):
return 10
def embed(self, row):
return [float(hash(c)) for c in row[:10]]

View File

@@ -22,7 +22,7 @@ import pyarrow as pa
from pyarrow import fs
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .embeddings import EmbeddingFunctionModel
from .pydantic import LanceModel
from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
@@ -290,7 +290,7 @@ class LanceDBConnection(DBConnection):
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
) -> LanceTable:
"""Create a table in the database.

View File

@@ -13,12 +13,10 @@
from .functions import (
EmbeddingFunction,
EmbeddingFunctionConfig,
REGISTRY,
EmbeddingFunctionModel,
EmbeddingFunctionRegistry,
OpenAIEmbeddings,
OpenClipEmbeddings,
SentenceTransformerEmbeddings,
TextEmbeddingFunction,
SentenceTransformerEmbeddingFunction,
TextEmbeddingFunctionModel,
)
from .utils import with_embeddings

View File

@@ -10,78 +10,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import importlib
import io
import json
import os
import socket
import urllib.error
import urllib.parse as urlparse
import urllib.request
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union
import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel, Field, PrivateAttr
from pydantic import BaseModel
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing.
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.
Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()
>>> @registry.register("my-embedding-function")
... class MyEmbeddingFunction(EmbeddingFunction):
... def ndims(self) -> int:
... return 128
...
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
... return self.compute_source_embeddings(query, *args, **kwargs)
...
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
...
>>> registry.get("my-embedding-function")
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
and fetch them by name. It also handles serializing and deserializing
"""
@classmethod
def get_instance(cls):
return __REGISTRY__
return REGISTRY
def __init__(self):
self._functions = {}
def register(self, alias: str = None):
def register(self):
"""
This creates a decorator that can be used to register
an EmbeddingFunction.
Parameters
----------
alias : Optional[str]
a human friendly name for the embedding function. If not
provided, the class name will be used.
an EmbeddingFunctionModel.
"""
# This is a decorator for a class that inherits from BaseModel
# It adds the class to the registry
def decorator(cls):
if not issubclass(cls, EmbeddingFunction):
raise TypeError("Must be a subclass of EmbeddingFunction")
if not issubclass(cls, EmbeddingFunctionModel):
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
if cls.__name__ in self._functions:
raise KeyError(f"{cls.__name__} was already registered")
key = alias or cls.__name__
self._functions[key] = cls
cls.__embedding_function_registry_alias__ = alias
self._functions[cls.__name__] = cls
return cls
return decorator
@@ -92,22 +57,13 @@ class EmbeddingFunctionRegistry:
"""
self._functions = {}
def get(self, name: str):
def load(self, name: str):
"""
Fetch an embedding function class by name
Parameters
----------
name : str
The name of the embedding function to fetch
Either the alias or the class name if no alias was provided
during registration
"""
return self._functions[name]
def parse_functions(
self, metadata: Optional[Dict[bytes, bytes]]
) -> Dict[str, "EmbeddingFunctionConfig"]:
def parse_functions(self, metadata: Optional[dict]) -> dict:
"""
Parse the metadata from an arrow table and
return a mapping of the vector column to the
@@ -115,9 +71,9 @@ class EmbeddingFunctionRegistry:
Parameters
----------
metadata : Optional[Dict[bytes, bytes]]
metadata : Optional[dict]
The metadata from an arrow table. Note that
the keys and values are bytes (pyarrow api)
the keys and values are bytes.
Returns
-------
@@ -130,94 +86,68 @@ class EmbeddingFunctionRegistry:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
return {
obj["vector_column"]: EmbeddingFunctionConfig(
vector_column=obj["vector_column"],
source_column=obj["source_column"],
function=self.get(obj["name"])(**obj["model"]),
)
for obj in raw_list
}
functions = {}
for obj in raw_list:
model = self.load(obj["schema"]["title"])
functions[obj["model"]["vector_column"]] = model(**obj["model"])
return functions
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
def function_to_metadata(self, func):
"""
Convert the given embedding function and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
func = conf.function
name = getattr(
func, "__embedding_function_registry_alias__", func.__class__.__name__
)
json_data = func.safe_model_dump()
schema = func.model_json_schema()
json_data = func.model_dump()
return {
"name": name,
"schema": schema,
"model": json_data,
"source_column": conf.source_column,
"vector_column": conf.vector_column,
}
def get_table_metadata(self, func_list):
"""
Convert a list of embedding functions and source / vector configs
Convert a list of embedding functions and source / vector column configs
into a config dictionary that can be serialized into arrow metadata
"""
if func_list is None or len(func_list) == 0:
return None
json_data = [self.function_to_metadata(func) for func in func_list]
# Note that metadata dictionary values must be bytes
# so we need to json dump then utf8 encode
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
metadata = json.dumps(json_data, indent=2).encode("utf-8")
return {"embedding_functions": metadata}
# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()
REGISTRY = EmbeddingFunctionRegistry()
class EmbeddingFunctionModel(BaseModel, ABC):
"""
A callable ABC for embedding functions
"""
source_column: Optional[str]
vector_column: str
@abstractmethod
def __call__(self, *args, **kwargs) -> List[np.array]:
pass
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
class EmbeddingFunction(BaseModel, ABC):
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
"""
An ABC for embedding functions.
All concrete embedding functions must implement the following:
1. compute_query_embeddings() which takes a query and returns a list of embeddings
2. get_source_embeddings() which returns a list of embeddings for the source column
For text data, the two will be the same. For multi-modal data, the source column
might be images and the vector column might be text.
3. ndims method which returns the number of dimensions of the vector column
A callable ABC for embedding functions that take text as input
"""
_ndims: int = PrivateAttr()
@classmethod
def create(cls, **kwargs):
"""
Create an instance of the embedding function
"""
return cls(**kwargs)
@abstractmethod
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for a given user query
"""
pass
@abstractmethod
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
"""
Compute the embeddings for the source column in the database
"""
pass
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
"""
Sanitize the input to the embedding function.
Sanitize the input to the embedding function. This is called
before generate_embeddings() and is useful for stripping
whitespace, lowercasing, etc.
"""
if isinstance(texts, str):
texts = [texts]
@@ -227,78 +157,6 @@ class EmbeddingFunction(BaseModel, ABC):
texts = texts.combine_chunks().to_pylist()
return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION
if PYDANTIC_VERSION.major < 2:
return dict(self)
return self.model_dump()
@abstractmethod
def ndims(self):
"""
Return the dimensions of the vector column
"""
pass
def SourceField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the source column for this embedding function
"""
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
def VectorField(self, **kwargs):
"""
Creates a pydantic Field that can automatically annotate
the target vector column for this embedding function
"""
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
class EmbeddingFunctionConfig(BaseModel):
"""
This model encapsulates the configuration for a embedding function
in a lancedb table. It holds the embedding function, the source column,
and the vector column
"""
vector_column: str
source_column: str
function: EmbeddingFunction
class TextEmbeddingFunction(EmbeddingFunction):
"""
A callable ABC for embedding functions that take text as input
"""
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, *args, **kwargs)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
@abstractmethod
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
@@ -309,25 +167,15 @@ class TextEmbeddingFunction(EmbeddingFunction):
pass
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
@REGISTRY.register()
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
"""
An embedding function that uses the sentence-transformers library
https://huggingface.co/sentence-transformers
"""
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
normalize: bool = False
@property
def embedding_model(self):
@@ -338,11 +186,6 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
"""
return self.__class__.get_embedding_model(self.name, self.device)
def ndims(self):
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
@@ -377,201 +220,9 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = cls.safe_import(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(name, device=device)
try:
from sentence_transformers import SentenceTransformer
@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses the OpenAI API
https://platform.openai.com/docs/guides/embeddings
"""
name: str = "text-embedding-ada-002"
def ndims(self):
# TODO don't hardcode this
return 1536
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
openai = self.safe_import("openai")
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
return [v["embedding"] for v in rs]
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the OpenClip API
For multi-modal text-to-image search
https://github.com/mlfoundations/open_clip
"""
name: str = "ViT-B-32"
pretrained: str = "laion2b_s34b_b79k"
device: str = "cpu"
batch_size: int = 64
normalize: bool = True
_model = PrivateAttr()
_preprocess = PrivateAttr()
_tokenizer = PrivateAttr()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
open_clip = self.safe_import("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained
)
model.to(self.device)
self._model, self._preprocess = model, preprocess
self._tokenizer = open_clip.get_tokenizer(self.name)
self._ndims = None
def ndims(self):
if self._ndims is None:
self._ndims = self.generate_text_embeddings("foo").shape[0]
return self._ndims
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return [self.generate_text_embeddings(query)]
else:
PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = self.safe_import("torch")
text = self.sanitize_input(text)
text = self._tokenizer(text)
text.to(self.device)
with torch.no_grad():
text_features = self._model.encode_text(text.to(self.device))
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(images, (str, bytes)):
images = [images]
elif isinstance(images, pa.Array):
images = images.to_pylist()
elif isinstance(images, pa.ChunkedArray):
images = images.combine_chunks().to_pylist()
return images
def compute_source_embeddings(
self, images: IMAGES, *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given images
"""
images = self.sanitize_input(images)
embeddings = []
for i in range(0, len(images), self.batch_size):
j = min(i + self.batch_size, len(images))
batch = images[i:j]
embeddings.extend(self._parallel_get(batch))
return embeddings
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
"""
Issue concurrent requests to retrieve the image data
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)
for image in images
]
return [future.result() for future in futures]
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
torch = self.safe_import("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0)
with torch.no_grad():
return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]):
PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image):
return image
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
return PIL.Image.open(parsed.path)
elif parsed.scheme == "":
return PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
"""
encode a single image tensor and optionally normalize the output
"""
image_features = self._model.encode_image(image_tensor)
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()
def url_retrieve(url: str):
"""
Parameters
----------
url: str
URL to download from
"""
try:
with urllib.request.urlopen(url) as conn:
return conn.read()
except (socket.gaierror, urllib.error.URLError) as err:
raise ConnectionError("could not download {} due to {}".format(url, err))
return SentenceTransformer(name, device=device)
except ImportError:
raise ValueError("Please install sentence_transformers")

View File

@@ -26,8 +26,6 @@ import pyarrow as pa
import pydantic
import semver
from .embeddings import EmbeddingFunctionRegistry
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
@@ -48,19 +46,7 @@ class FixedSizeListMixin(ABC):
raise NotImplementedError
def vector(dim: int, value_type: pa.DataType = pa.float32()):
# TODO: remove in future release
from warnings import warn
warn(
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
"This function will be removed in future release",
DeprecationWarning,
)
return Vector(dim, value_type)
def Vector(
def vector(
dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type.
@@ -79,12 +65,12 @@ def Vector(
--------
>>> import pydantic
>>> from lancedb.pydantic import Vector
>>> from lancedb.pydantic import vector
...
>>> class MyModel(pydantic.BaseModel):
... id: int
... url: str
... embeddings: Vector(768)
... embeddings: vector(768)
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
@@ -128,7 +114,7 @@ def Vector(
def validate(cls, v):
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
raise TypeError("A list of numbers or numpy.ndarray is needed")
return cls(v)
return v
if PYDANTIC_VERSION < (2, 0):
@@ -238,18 +224,27 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
>>> from typing import List, Optional
>>> import pydantic
>>> from lancedb.pydantic import pydantic_to_schema
...
>>> class InnerModel(pydantic.BaseModel):
... a: str
... b: Optional[float]
>>>
>>> class FooModel(pydantic.BaseModel):
... id: int
... s: str
... s: Optional[str] = None
... vec: List[float]
... li: List[int]
...
... inner: InnerModel
>>> schema = pydantic_to_schema(FooModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("s", pa.utf8(), False),
... pa.field("s", pa.utf8(), True),
... pa.field("vec", pa.list_(pa.float64()), False),
... pa.field("li", pa.list_(pa.int64()), False),
... pa.field("inner", pa.struct([
... pa.field("a", pa.utf8(), False),
... pa.field("b", pa.float64(), True),
... ]), False),
... ])
"""
fields = _pydantic_model_to_fields(model)
@@ -263,11 +258,11 @@ class LanceModel(pydantic.BaseModel):
Examples
--------
>>> import lancedb
>>> from lancedb.pydantic import LanceModel, Vector
>>> from lancedb.pydantic import LanceModel, vector
>>>
>>> class TestModel(LanceModel):
... name: str
... vector: Vector(2)
... vector: vector(2)
...
>>> db = lancedb.connect("/tmp")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
@@ -283,58 +278,13 @@ class LanceModel(pydantic.BaseModel):
"""
Get the Arrow Schema for this model.
"""
schema = pydantic_to_schema(cls)
functions = cls.parse_embedding_functions()
if len(functions) > 0:
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
functions
)
schema = schema.with_metadata(metadata)
return schema
return pydantic_to_schema(cls)
@classmethod
def field_names(cls) -> List[str]:
"""
Get the field names of this model.
"""
return list(cls.safe_get_fields().keys())
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2:
return cls.__fields__
return cls.model_fields
@classmethod
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
"""
Parse the embedding functions from this model.
"""
from .embeddings import EmbeddingFunctionConfig
vec_and_function = []
for name, field_info in cls.safe_get_fields().items():
func = get_extras(field_info, "vector_column_for")
if func is not None:
vec_and_function.append([name, func])
configs = []
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
src_func = get_extras(field_info, "source_column_for")
if src_func == func:
configs.append(
EmbeddingFunctionConfig(
source_column=source, vector_column=vec, function=func
)
)
return configs
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
if PYDANTIC_VERSION.major >= 2:
return (field_info.json_schema_extra or {}).get(key)
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
return list(cls.__fields__.keys())
return list(cls.model_fields.keys())

View File

@@ -60,15 +60,13 @@ class LanceQueryBuilder(ABC):
def create(
cls,
table: "lancedb.table.Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query: Optional[Union[np.ndarray, str]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
@@ -92,27 +90,30 @@ class LanceQueryBuilder(ABC):
# otherwise raise TypeError
if query_type == "fts":
if not isinstance(query, str):
raise TypeError(f"'fts' queries must be a string: {type(query)}")
raise TypeError(
f"Query type is 'fts' but query is not a string: {type(query)}"
)
return query, query_type
elif query_type == "vector":
# If query_type is vector, then query must be a list or np.ndarray.
# otherwise raise TypeError
if not isinstance(query, (list, np.ndarray)):
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
raise TypeError(
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
elif isinstance(query, str):
func = table.embedding_functions.get(vector_column_name, None)
if func is not None:
query = func(query)[0]
return query, "vector"
else:
return query, "fts"
else:
raise TypeError("Query must be a list, np.ndarray, or str")
else:
raise ValueError(
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
@@ -237,7 +238,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
def __init__(
self,
table: "lancedb.table.Table",
query: Union[np.ndarray, list, "PIL.Image.Image"],
query: Union[np.ndarray, list],
vector_column: str = VECTOR_COLUMN_NAME,
):
super().__init__(table)

View File

@@ -18,9 +18,10 @@ from urllib.parse import urlparse
import pyarrow as pa
from ..common import DATA
from ..db import DBConnection
from ..table import Table, _sanitize_data
from lancedb.common import DATA
from lancedb.db import DBConnection
from lancedb.table import Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient

View File

@@ -28,8 +28,7 @@ from lance.dataset import ReaderLike
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionRegistry
from .embeddings.functions import EmbeddingFunctionConfig
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
from .pydantic import LanceModel
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
@@ -82,16 +81,15 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
vector column to the table.
"""
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_column, conf in functions.items():
func = conf.function
if vector_column not in data.column_names:
col_data = func.compute_source_embeddings(data[conf.source_column])
for vector_col, func in functions.items():
if vector_col not in data.column_names:
col_data = func(data[func.source_column])
if schema is not None:
dtype = schema.field(vector_column).type
dtype = schema.field(vector_col).type
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
data = data.append_column(
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
)
return data
@@ -104,8 +102,7 @@ def _to_record_batch_generator(
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches():
yield batch
else:
yield batch
yield batch
class Table(ABC):
@@ -232,7 +229,7 @@ class Table(ABC):
@abstractmethod
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
query: Optional[Union[VEC, str]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -241,7 +238,7 @@ class Table(ABC):
Parameters
----------
query: str, list, np.ndarray, PIL.Image.Image, default None
query: str, list, np.ndarray, default None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
@@ -251,8 +248,6 @@ class Table(ABC):
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
@@ -528,9 +523,6 @@ class LanceTable(Table):
fill_value: float = 0.0,
):
"""Add data to the table.
If vector columns are missing and the table
has embedding functions, then the vector columns
are automatically computed and added.
Parameters
----------
@@ -624,6 +616,12 @@ class LanceTable(Table):
)
self._reset_dataset()
def _get_embedding_function_for_source_col(self, column_name: str):
for k, v in self.embedding_functions.items():
if v.source_column == column_name:
return v
return None
@cached_property
def embedding_functions(self) -> dict:
"""
@@ -641,7 +639,7 @@ class LanceTable(Table):
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
query: Optional[Union[VEC, str]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -650,7 +648,7 @@ class LanceTable(Table):
Parameters
----------
query: str, list, np.ndarray, a PIL Image or None
query: str, list, np.ndarray, or None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
@@ -659,11 +657,9 @@ class LanceTable(Table):
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If the query is a list/np.ndarray then the query type is "vector";
If the query is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts"
table has embedding functions else the query type is "fts"
Returns
-------
@@ -687,7 +683,7 @@ class LanceTable(Table):
mode="create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionConfig] = None,
embedding_functions: List[EmbeddingFunctionModel] = None,
):
"""
Create a new table.
@@ -730,16 +726,10 @@ class LanceTable(Table):
"""
tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
metadata = None
if embedding_functions is not None:
# If we passed in embedding functions explicitly
# then we'll override any schema metadata that
# may was implicitly specified by the LanceModel schema
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)

View File

@@ -70,11 +70,7 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
Get a PyArrow FileSystem from a URI, handling extra environment variables.
"""
if get_uri_scheme(uri) == "s3":
fs = pa_fs.S3FileSystem(
endpoint_override=os.environ.get("AWS_ENDPOINT"),
request_timeout=30,
connect_timeout=30,
)
fs = pa_fs.S3FileSystem(endpoint_override=os.environ.get("AWS_ENDPOINT"))
path = get_uri_location(uri)
return fs, path

View File

@@ -1,8 +1,8 @@
[project]
name = "lancedb"
version = "0.2.5"
version = "0.2.2"
dependencies = [
"pylance==0.7.4",
"pylance==0.6.5",
"ratelimiter",
"retry",
"tqdm",
@@ -44,11 +44,9 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
[build-system]
requires = ["setuptools", "wheel"]
@@ -56,10 +54,3 @@ build-backend = "setuptools.build_meta"
[tool.isort]
profile = "black"
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio"
]

View File

@@ -17,7 +17,7 @@ import pyarrow as pa
import pytest
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, vector
def test_basic(tmp_path):
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
def test_ingest_iterator(tmp_path):
class PydanticSchema(LanceModel):
vector: Vector(2)
vector: vector(2)
item: str
price: float
@@ -136,12 +136,13 @@ def test_ingest_iterator(tmp_path):
def run_tests(schema):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
tbl_len = len(tbl)
tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 3
db.drop_database()

View File

@@ -16,12 +16,8 @@ import lance
import numpy as np
import pyarrow as pa
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
from lancedb.conftest import MockEmbeddingFunction
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
def mock_embed_func(input_data):
@@ -58,12 +54,8 @@ def test_embedding_function(tmp_path):
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
metadata = registry.get_table_metadata([func])
table = table.replace_schema_metadata(metadata)
# Write it to disk
@@ -73,13 +65,14 @@ def test_embedding_function(tmp_path):
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
configs = registry.parse_functions(ds.schema.metadata)
functions = registry.parse_functions(ds.schema.metadata)
conf = configs["vector"]
func = conf.function
actual = func.compute_query_embeddings("hello world")
func = functions["vector"]
actual = func("hello world")
# We create an instance
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
# And we make sure we can call it
expected = func.compute_query_embeddings("hello world")
expected = expected_func("hello world")
assert np.allclose(actual, expected)

View File

@@ -1,125 +0,0 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import numpy as np
import pandas as pd
import pytest
import requests
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connection to external api
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get(alias).create()
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words)
table.add(
pd.DataFrame(
{
"text": [
"hello world",
"goodbye world",
"fizz",
"buzz",
"foo",
"bar",
"baz",
]
}
)
)
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
vec = func.compute_query_embeddings(query)[0]
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
assert actual.text == expected.text
assert actual.text == "hello world"
@pytest.mark.slow
def test_openclip(tmp_path):
from PIL import Image
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("open-clip").create()
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
)
# text search
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
assert actual.label == "dog"
frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == frombytes.label
assert np.allclose(actual.vector, frombytes.vector)
# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(io.BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
assert actual.label == "dog"
other = (
table.search(query_image, vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == other.label
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
assert np.allclose(
arrow_table["vector"].combine_chunks().values.to_numpy(),
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
)

View File

@@ -19,9 +19,8 @@ from typing import List, Optional
import pyarrow as pa
import pydantic
import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
@pytest.mark.skipif(
@@ -108,7 +107,7 @@ def test_pydantic_to_arrow_py38():
def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel):
vec: Vector(16)
vec: vector(16)
li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
@@ -155,7 +154,7 @@ def test_fixed_size_list_field():
def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel):
vec: Vector(8)
vec: vector(8)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9))
@@ -168,12 +167,9 @@ def test_fixed_size_list_validation():
def test_lance_model():
class TestModel(LanceModel):
vector: Vector(16) = Field(default=[0.0] * 16)
li: List[int] = Field(default=[1, 2, 3])
vec: vector(16)
li: List[int]
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vector", "li"]
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
assert TestModel.field_names() == ["vec", "li"]

View File

@@ -20,7 +20,7 @@ import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
@@ -67,7 +67,7 @@ def table(tmp_path) -> MockTable:
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
vector: vector(2)
id: int
str_field: str
float_field: float

View File

@@ -22,10 +22,9 @@ import pandas as pd
import pyarrow as pa
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.conftest import MockEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, vector
from lancedb.table import LanceTable
@@ -141,7 +140,7 @@ def test_add(db):
def test_add_pydantic_model(db):
class TestModel(LanceModel):
vector: Vector(16)
vector: vector(16)
li: List[int]
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
@@ -355,25 +354,22 @@ def test_update(db):
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: Vector(10)
vector: vector(10)
func = MockTextEmbeddingFunction()
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
df = pd.DataFrame({"text": texts, "vector": func(texts)})
conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func
)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[conf],
embedding_functions=[func],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func.compute_query_embeddings(query_str)[0]
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
@@ -381,13 +377,17 @@ def test_create_with_embedding_function(db):
def test_add_with_embedding_function(db):
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
text: str
vector: vector(10)
table = LanceTable.create(db, "my_table", schema=MyTable)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[func],
)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
@@ -397,7 +397,7 @@ def test_add_with_embedding_function(db):
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = emb.compute_query_embeddings(query_str)[0]
query_vector = func(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
@@ -407,8 +407,8 @@ def test_add_with_embedding_function(db):
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: Vector(10)
vector2: Vector(10)
vector1: vector(10)
vector2: vector(10)
table = LanceTable.create(
db,

View File

@@ -1,6 +1,6 @@
[package]
name = "vectordb-node"
version = "0.2.6"
version = "0.2.4"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
edition = "2018"
@@ -18,7 +18,6 @@ once_cell = "1"
futures = "0.3"
half = { workspace = true }
lance = { workspace = true }
lance-linalg = { workspace = true }
vectordb = { path = "../../vectordb" }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }

View File

@@ -28,9 +28,7 @@ fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
record_batch
.column_by_name(VECTOR_COLUMN_NAME)
.map(|_| ())
.context(MissingColumnSnafu {
name: VECTOR_COLUMN_NAME,
})
.context(MissingColumnSnafu { name: VECTOR_COLUMN_NAME })
}
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {

View File

@@ -14,7 +14,7 @@
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance_linalg::distance::MetricType;
use lance::index::vector::MetricType;
use neon::context::FunctionContext;
use neon::prelude::*;
use std::convert::TryFrom;

View File

@@ -183,9 +183,11 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let aws_region = get_aws_region(&mut cx, 4)?;
let params = ReadParams {
store_options: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
store_options: Some(ObjectStoreParams {
aws_credentials: aws_creds,
aws_region,
..ObjectStoreParams::default()
}),
..ReadParams::default()
};

View File

@@ -3,7 +3,7 @@ use std::ops::Deref;
use arrow_array::Float32Array;
use futures::{TryFutureExt, TryStreamExt};
use lance_linalg::distance::MetricType;
use lance::index::vector::MetricType;
use neon::context::FunctionContext;
use neon::handle::Handle;
use neon::prelude::*;

View File

@@ -43,8 +43,7 @@ impl JsTable {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let buffer = cx.argument::<JsBuffer>(1)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
// Write mode
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
@@ -66,9 +65,11 @@ impl JsTable {
let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams {
store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
aws_region,
..ObjectStoreParams::default()
}),
mode: mode,
..WriteParams::default()
};
@@ -91,8 +92,7 @@ impl JsTable {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
let (batches, schema) = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let mut table = js_table.table.clone();
@@ -108,9 +108,11 @@ impl JsTable {
let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams {
store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
aws_region,
..ObjectStoreParams::default()
}),
mode: write_mode,
..WriteParams::default()
};

View File

@@ -1,6 +1,6 @@
[package]
name = "vectordb"
version = "0.2.6"
version = "0.2.4"
edition = "2021"
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
@@ -10,21 +10,14 @@ categories = ["database-implementations"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lance = { workspace = true }
lance-linalg = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log = { workspace = true }
num-traits = "0"
url = { workspace = true }
[dev-dependencies]
tempfile = "3.5.0"

View File

@@ -1,15 +0,0 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use lance::arrow::*;

View File

@@ -1,18 +0,0 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Data types, schema coercion, and data cleaning and etc.
pub mod inspect;
pub mod sanitize;

View File

@@ -1,180 +0,0 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use arrow::compute::kernels::{aggregate::bool_and, length::length};
use arrow_array::{
cast::AsArray,
types::{ArrowPrimitiveType, Int32Type, Int64Type},
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
};
use arrow_ord::comparison::eq_dyn_scalar;
use arrow_schema::DataType;
use num_traits::{ToPrimitive, Zero};
use crate::error::{Error, Result};
pub(crate) fn infer_dimension<T: ArrowPrimitiveType>(
list_arr: &GenericListArray<T::Native>,
) -> Result<Option<T::Native>>
where
T::Native: OffsetSizeTrait + ToPrimitive,
{
let len_arr = length(list_arr)?;
if len_arr.is_empty() {
return Ok(Some(Zero::zero()));
}
let dim = len_arr.as_primitive::<T>().value(0);
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
Ok(None)
} else {
Ok(Some(dim))
}
}
/// Infer the vector columns from a dataset.
///
/// Parameters
/// ----------
/// - reader: RecordBatchReader
/// - strict: if set true, only fixed_size_list<float> is considered as vector column. If set to false,
/// a list<float> column with same length is also considered as vector column.
pub fn infer_vector_columns(
reader: impl RecordBatchReader + Send,
strict: bool,
) -> Result<Vec<String>> {
let mut columns = vec![];
let mut columns_to_infer: HashMap<String, Option<i64>> = HashMap::new();
for field in reader.schema().fields() {
match field.data_type() {
DataType::FixedSizeList(sub_field, _) if sub_field.data_type().is_floating() => {
columns.push(field.name().to_string());
}
DataType::List(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().to_string(), None);
}
DataType::LargeList(sub_field) if sub_field.data_type().is_floating() && !strict => {
columns_to_infer.insert(field.name().to_string(), None);
}
_ => {}
}
}
for batch in reader {
let batch = batch?;
let col_names = columns_to_infer.keys().cloned().collect::<Vec<_>>();
for col_name in col_names {
let col = batch.column_by_name(&col_name).ok_or(Error::Schema {
message: format!("Column {} not found", col_name),
})?;
if let Some(dim) = match *col.data_type() {
DataType::List(_) => {
infer_dimension::<Int32Type>(col.as_list::<i32>())?.map(|d| d as i64)
}
DataType::LargeList(_) => infer_dimension::<Int64Type>(col.as_list::<i64>())?,
_ => {
return Err(Error::Schema {
message: format!("Column {} is not a list", col_name),
})
}
} {
if let Some(Some(prev_dim)) = columns_to_infer.get(&col_name) {
if prev_dim != &dim {
columns_to_infer.remove(&col_name);
}
} else {
columns_to_infer.insert(col_name, Some(dim));
}
} else {
columns_to_infer.remove(&col_name);
}
}
}
columns.extend(columns_to_infer.keys().cloned());
Ok(columns)
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{
types::{Float32Type, Float64Type},
FixedSizeListArray, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use std::{sync::Arc, vec};
#[test]
fn test_infer_vector_columns() {
let schema = Arc::new(Schema::new(vec![
Field::new("f", DataType::Float32, false),
Field::new("s", DataType::Utf8, false),
Field::new(
"l1",
DataType::List(Arc::new(Field::new("item", DataType::Float32, true))),
false,
),
Field::new(
"l2",
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
),
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 32),
true,
),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
Arc::new(ListArray::from_iter_primitive::<Float32Type, _, _>(
(0..3).map(|_| Some(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)])),
)),
// Var-length list
Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(vec![Some(1.0_f64)]),
Some(vec![Some(2.0_f64), Some(3.0_f64)]),
Some(vec![Some(4.0_f64), Some(5.0_f64), Some(6.0_f64)]),
])),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
vec![
Some(vec![Some(1.0); 32]),
Some(vec![Some(2.0); 32]),
Some(vec![Some(3.0); 32]),
],
32,
),
),
],
)
.unwrap();
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
let cols = infer_vector_columns(reader, false).unwrap();
assert_eq!(cols, vec!["fl", "l1"]);
let reader = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
let cols = infer_vector_columns(reader, true).unwrap();
assert_eq!(cols, vec!["fl"]);
}
}

View File

@@ -1,284 +0,0 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{iter::repeat_with, sync::Arc};
use arrow_array::{
cast::AsArray,
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
RecordBatchReader,
};
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use half::f16;
use lance::arrow::{DataTypeExt, FixedSizeListArrayExt};
use log::warn;
use num_traits::cast::AsPrimitive;
use super::inspect::infer_dimension;
use crate::error::Result;
fn cast_array<I: ArrowNumericType, O: ArrowNumericType>(
arr: &PrimitiveArray<I>,
) -> Arc<PrimitiveArray<O>>
where
I::Native: AsPrimitive<O::Native>,
{
Arc::new(PrimitiveArray::<O>::from_iter_values(
arr.values().iter().map(|v| (*v).as_()),
))
}
fn cast_float_array<I: ArrowNumericType>(
arr: &PrimitiveArray<I>,
dt: &DataType,
) -> std::result::Result<Arc<dyn Array>, ArrowError>
where
I::Native: AsPrimitive<f64> + AsPrimitive<f32> + AsPrimitive<f16>,
{
match dt {
DataType::Float16 => Ok(cast_array::<I, Float16Type>(arr)),
DataType::Float32 => Ok(cast_array::<I, Float32Type>(arr)),
DataType::Float64 => Ok(cast_array::<I, Float64Type>(arr)),
_ => Err(ArrowError::SchemaError(format!(
"Incompatible change field: unable to coerce {:?} to {:?}",
arr.data_type(),
dt
))),
}
}
fn coerce_array(
array: &Arc<dyn Array>,
field: &Field,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
if array.data_type() == field.data_type() {
return Ok(array.clone());
}
match (array.data_type(), field.data_type()) {
// Normal cast-able types.
(adt, dt) if can_cast_types(adt, dt) => cast(&array, dt),
// Casting between f16/f32/f64 can be lossy.
(adt, dt) if (adt.is_floating() || dt.is_floating()) => {
if adt.byte_width() > dt.byte_width() {
warn!(
"Coercing field {} {:?} to {:?} might lose precision",
field.name(),
adt,
dt
);
}
match adt {
DataType::Float16 => cast_float_array(array.as_primitive::<Float16Type>(), dt),
DataType::Float32 => cast_float_array(array.as_primitive::<Float32Type>(), dt),
DataType::Float64 => cast_float_array(array.as_primitive::<Float64Type>(), dt),
_ => unreachable!(),
}
}
(adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt {
// Cast a float fixed size array with same dimension to the expected type.
DataType::FixedSizeList(_, dim) if dim == exp_dim => {
let actual_sub = array.as_fixed_size_list();
let values = coerce_array(actual_sub.values(), exp_field)?;
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
values.clone(),
*dim,
)?) as Arc<dyn Array>)
}
DataType::List(_) | DataType::LargeList(_) => {
let Some(dim) = (match adt {
DataType::List(_) => infer_dimension::<Int32Type>(array.as_list::<i32>())
.map_err(|e| {
ArrowError::SchemaError(format!(
"failed to infer dimension from list: {}",
e
))
})?
.map(|d| d as i64),
DataType::LargeList(_) => infer_dimension::<Int64Type>(array.as_list::<i64>())
.map_err(|e| {
ArrowError::SchemaError(format!(
"failed to infer dimension from large list: {}",
e
))
})?,
_ => unreachable!(),
}) else {
return Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
field,
array.data_type()
)));
};
if dim != *exp_dim as i64 {
return Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: expected dimension {} but got {}",
exp_dim, dim
)));
}
let values = coerce_array(array, exp_field)?;
Ok(Arc::new(FixedSizeListArray::try_new_from_values(
values.clone(),
*exp_dim,
)?) as Arc<dyn Array>)
}
_ => Err(ArrowError::SchemaError(format!(
"Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
field,
array.data_type()
)))?,
},
_ => Err(ArrowError::SchemaError(format!(
"Incompatible change field {}: unable to coerce {:?} to {:?}",
field.name(),
array.data_type(),
field.data_type()
)))?,
}
}
fn coerce_schema_batch(
batch: RecordBatch,
schema: Arc<Schema>,
) -> std::result::Result<RecordBatch, ArrowError> {
if batch.schema() == schema {
return Ok(batch);
}
let columns = schema
.fields()
.iter()
.map(|field| {
batch
.column_by_name(field.name())
.ok_or_else(|| {
ArrowError::SchemaError(format!("Column {} not found", field.name()))
})
.and_then(|c| coerce_array(c, field))
})
.collect::<std::result::Result<Vec<_>, ArrowError>>()?;
RecordBatch::try_new(schema, columns)
}
/// Coerce the reader (input data) to match the given [Schema].
///
pub fn coerce_schema(
reader: impl RecordBatchReader + Send + 'static,
schema: Arc<Schema>,
) -> Result<Box<dyn RecordBatchReader + Send>> {
if reader.schema() == schema {
return Ok(Box::new(RecordBatchIterator::new(reader, schema)));
}
let s = schema.clone();
let batches = reader
.zip(repeat_with(move || s.clone()))
.map(|(batch, s)| coerce_schema_batch(batch?, s));
Ok(Box::new(RecordBatchIterator::new(batches, schema)))
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use arrow_array::{
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int32Array, Int8Array,
RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::Field;
use half::f16;
use lance::arrow::FixedSizeListArrayExt;
#[test]
fn test_coerce_list_to_fixed_size_list() {
let schema = Arc::new(Schema::new(vec![
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 64),
true,
),
Field::new("s", DataType::Utf8, true),
Field::new("f", DataType::Float16, true),
Field::new("i", DataType::Int32, true),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(
FixedSizeListArray::try_new_from_values(
Float32Array::from_iter_values((0..256).map(|v| v as f32)),
64,
)
.unwrap(),
),
Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some("from"),
Some("lance"),
])),
Arc::new(Float16Array::from_iter_values(
(0..4).map(|v| f16::from_f32(v as f32)),
)),
Arc::new(Int32Array::from_iter_values(0..4)),
],
)
.unwrap();
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), schema.clone());
let expected_schema = Arc::new(Schema::new(vec![
Field::new(
"fl",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 64),
true,
),
Field::new("s", DataType::Utf8, true),
Field::new("f", DataType::Float64, true),
Field::new("i", DataType::Int8, true),
]));
let stream = coerce_schema(reader, expected_schema.clone()).unwrap();
let batches = stream.collect::<Vec<_>>();
assert_eq!(batches.len(), 1);
let batch = batches[0].as_ref().unwrap();
assert_eq!(batch.schema(), expected_schema);
let expected = RecordBatch::try_new(
expected_schema,
vec![
Arc::new(
FixedSizeListArray::try_new_from_values(
Float16Array::from_iter_values((0..256).map(|v| f16::from_f32(v as f32))),
64,
)
.unwrap(),
),
Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some("from"),
Some("lance"),
])),
Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))),
Arc::new(Int8Array::from_iter_values(0..4)),
],
)
.unwrap();
assert_eq!(batch, &expected);
}
}

View File

@@ -27,14 +27,12 @@ pub const LANCE_FILE_EXTENSION: &str = "lance";
pub struct Database {
object_store: ObjectStore,
query_string: Option<String>,
pub(crate) uri: String,
pub(crate) base_path: object_store::path::Path,
}
const LANCE_EXTENSION: &str = "lance";
const ENGINE: &str = "engine";
/// A connection to LanceDB
impl Database {
@@ -48,73 +46,12 @@ impl Database {
///
/// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Database> {
let parse_res = url::Url::parse(uri);
match parse_res {
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
Ok(mut url) => {
// iter thru the query params and extract the commit store param
let mut engine = None;
let mut filtered_querys = vec![];
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
// THE API WILL CHANGE
for (key, value) in url.query_pairs() {
if key == ENGINE {
engine = Some(value.to_string());
} else {
// to owned so we can modify the url
filtered_querys.push((key.to_string(), value.to_string()));
}
}
// Filter out the commit store query param -- it's a lancedb param
url.query_pairs_mut().clear();
url.query_pairs_mut().extend_pairs(filtered_querys);
// Take a copy of the query string so we can propagate it to lance
let query_string = url.query().map(|s| s.to_string());
// clear the query string so we can use the url as the base uri
// use .set_query(None) instead of .set_query("") because the latter
// will add a trailing '?' to the url
url.set_query(None);
let table_base_uri = if let Some(store) = engine {
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
WARN_ONCE.call_once(|| {
log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
});
let old_scheme = url.scheme().to_string();
let new_scheme = format!("{}+{}", old_scheme, store);
url.to_string().replacen(&old_scheme, &new_scheme, 1)
} else {
url.to_string()
};
let plain_uri = url.to_string();
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
if object_store.is_local() {
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
}
Ok(Database {
uri: table_base_uri,
query_string,
base_path,
object_store,
})
}
Err(_) => Self::open_path(uri).await,
}
}
async fn open_path(path: &str) -> Result<Database> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
}
Ok(Self {
uri: path.to_string(),
query_string: None,
Ok(Database {
uri: uri.to_string(),
base_path,
object_store,
})
@@ -212,26 +149,17 @@ impl Database {
let path = Path::new(&self.uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let mut uri = table_uri
let uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
// If there are query string set on the connection, propagate to lance
if let Some(query) = self.query_string.as_ref() {
uri.push('?');
uri.push_str(query.as_str());
}
Ok(uri)
.context(InvalidTableNameSnafu { name })?;
Ok(uri.to_string())
}
}
#[cfg(test)]
mod tests {
use std::fs::create_dir_all;
use tempfile::tempdir;
use crate::database::Database;
@@ -245,28 +173,6 @@ mod tests {
assert_eq!(db.uri, uri);
}
#[cfg(not(windows))]
#[tokio::test]
async fn test_connect_relative() {
let tmp_dir = tempdir().unwrap();
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
let mut relative_anacestors = vec![];
let current_dir = std::env::current_dir().unwrap();
let mut ancestors = current_dir.ancestors();
while let Some(_) = ancestors.next() {
relative_anacestors.push("..");
}
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
let relative_uri = relative_root.join(&uri);
let db = Database::connect(relative_uri.to_str().unwrap())
.await
.unwrap();
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
}
#[tokio::test]
async fn test_table_names() {
let tmp_dir = tempdir().unwrap();

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow_schema::ArrowError;
use snafu::Snafu;
#[derive(Debug, Snafu)]
@@ -33,20 +32,10 @@ pub enum Error {
Store { message: String },
#[snafu(display("LanceDBError: {message}"))]
Lance { message: String },
#[snafu(display("LanceDB Schema Error: {message}"))]
Schema { message: String },
}
pub type Result<T> = std::result::Result<T, Error>;
impl From<ArrowError> for Error {
fn from(e: ArrowError) -> Self {
Self::Lance {
message: e.to_string(),
}
}
}
impl From<lance::Error> for Error {
fn from(e: lance::Error) -> Self {
Self::Lance {

View File

@@ -14,8 +14,7 @@
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::VectorIndexParams;
use lance_linalg::distance::MetricType;
use lance::index::vector::{MetricType, VectorIndexParams};
pub trait VectorIndexBuilder {
fn get_column(&self) -> Option<String>;
@@ -108,11 +107,9 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
#[cfg(test)]
mod tests {
use super::*;
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::StageParams;
use lance::index::vector::{MetricType, StageParams};
use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod data;
pub mod database;
pub mod error;
pub mod index;

View File

@@ -17,7 +17,7 @@ use std::sync::Arc;
use arrow_array::Float32Array;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::Dataset;
use lance_linalg::distance::MetricType;
use lance::index::vector::MetricType;
use crate::error::Result;
@@ -164,10 +164,10 @@ impl Query {
mod tests {
use std::sync::Arc;
use super::*;
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use lance::dataset::Dataset;
use lance::index::vector::MetricType;
use crate::query::Query;