mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 22:09:58 +00:00
Compare commits
1 Commits
python-v0.
...
changhiskh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96a7c1ab42 |
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.9.18", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.18" }
|
||||
lance-linalg = { "version" = "=0.9.18" }
|
||||
lance-testing = { "version" = "=0.9.18" }
|
||||
lance = { "version" = "=0.9.16", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.16" }
|
||||
lance-linalg = { "version" = "=0.9.16" }
|
||||
lance-testing = { "version" = "=0.9.16" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "50.0", optional = false }
|
||||
arrow-array = "50.0"
|
||||
|
||||
@@ -99,9 +99,10 @@ nav:
|
||||
- Configuring Storage: guides/storage.md
|
||||
- 🧬 Managing embeddings:
|
||||
- Overview: embeddings/index.md
|
||||
- Embedding functions: embeddings/embedding_functions.md
|
||||
- Available models: embeddings/default_embedding_functions.md
|
||||
- User-defined embedding functions: embeddings/custom_embedding_function.md
|
||||
- Explicit management: embeddings/embedding_explicit.md
|
||||
- Implicit management: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Custom Embedding Functions: embeddings/api.md
|
||||
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
||||
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- 🔌 Integrations:
|
||||
@@ -163,9 +164,10 @@ nav:
|
||||
- Configuring Storage: guides/storage.md
|
||||
- Managing Embeddings:
|
||||
- Overview: embeddings/index.md
|
||||
- Embedding functions: embeddings/embedding_functions.md
|
||||
- Available models: embeddings/default_embedding_functions.md
|
||||
- User-defined embedding functions: embeddings/custom_embedding_function.md
|
||||
- Explicit management: embeddings/embedding_explicit.md
|
||||
- Implicit management: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Custom Embedding Functions: embeddings/api.md
|
||||
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
||||
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- Integrations:
|
||||
@@ -206,7 +208,6 @@ extra_css:
|
||||
|
||||
extra_javascript:
|
||||
- "extra_js/init_ask_ai_widget.js"
|
||||
- "extra_js/meta_tag.js"
|
||||
|
||||
extra:
|
||||
analytics:
|
||||
|
||||
141
docs/src/embeddings/embedding_explicit.md
Normal file
141
docs/src/embeddings/embedding_explicit.md
Normal file
@@ -0,0 +1,141 @@
|
||||
In this workflow, you define your own embedding function and pass it as a callable to LanceDB, invoking it in your code to generate the embeddings. Let's look at some examples.
|
||||
|
||||
### Hugging Face
|
||||
|
||||
!!! note
|
||||
Currently, the Hugging Face method is only supported in the Python SDK.
|
||||
|
||||
=== "Python"
|
||||
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
|
||||
library, which can be installed via pip.
|
||||
|
||||
```bash
|
||||
pip install sentence-transformers
|
||||
```
|
||||
|
||||
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
|
||||
for a given document.
|
||||
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
name="paraphrase-albert-small-v2"
|
||||
model = SentenceTransformer(name)
|
||||
|
||||
# used for both training and querying
|
||||
def embed_func(batch):
|
||||
return [model.encode(sentence) for sentence in batch]
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
import openai
|
||||
import os
|
||||
|
||||
# Configuring the environment variable OPENAI_API_KEY
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
|
||||
# verify that the API key is working
|
||||
assert len(openai.Model.list()["data"]) > 0
|
||||
|
||||
def embed_func(c):
|
||||
rs = openai.Embedding.create(input=c, engine="text-embedding-ada-002")
|
||||
return [record["embedding"] for record in rs["data"]]
|
||||
```
|
||||
|
||||
=== "JavaScript"
|
||||
```javascript
|
||||
const lancedb = require("vectordb");
|
||||
|
||||
// You need to provide an OpenAI API key
|
||||
const apiKey = "sk-..."
|
||||
// The embedding function will create embeddings for the 'text' column
|
||||
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
|
||||
```
|
||||
|
||||
## Applying an embedding function to data
|
||||
|
||||
=== "Python"
|
||||
Using an embedding function, you can apply it to raw data
|
||||
to generate embeddings for each record.
|
||||
|
||||
Say you have a pandas DataFrame with a `text` column that you want embedded,
|
||||
you can use the `with_embeddings` function to generate embeddings and add them to
|
||||
an existing table.
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
from lancedb.embeddings import with_embeddings
|
||||
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{"text": "pepperoni"},
|
||||
{"text": "pineapple"}
|
||||
]
|
||||
)
|
||||
data = with_embeddings(embed_func, df)
|
||||
|
||||
# The output is used to create / append to a table
|
||||
# db.create_table("my_table", data=data)
|
||||
```
|
||||
|
||||
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
|
||||
|
||||
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
|
||||
using the `batch_size` parameter to `with_embeddings`.
|
||||
|
||||
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
|
||||
API call is reliable.
|
||||
|
||||
=== "JavaScript"
|
||||
Using an embedding function, you can apply it to raw data
|
||||
to generate embeddings for each record.
|
||||
|
||||
Simply pass the embedding function created above and LanceDB will use it to generate
|
||||
embeddings for your data.
|
||||
|
||||
```javascript
|
||||
const db = await lancedb.connect("data/sample-lancedb");
|
||||
const data = [
|
||||
{ text: "pepperoni"},
|
||||
{ text: "pineapple"}
|
||||
]
|
||||
|
||||
const table = await db.createTable("vectors", data, embedding)
|
||||
```
|
||||
|
||||
## Querying using an embedding function
|
||||
|
||||
!!! warning
|
||||
At query time, you **must** use the same embedding function you used to vectorize your data.
|
||||
If you use a different embedding function, the embeddings will not reside in the same vector
|
||||
space and the results will be nonsensical.
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
query = "What's the best pizza topping?"
|
||||
query_vector = embed_func([query])[0]
|
||||
results = (
|
||||
tbl.search(query_vector)
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||
|
||||
=== "JavaScript"
|
||||
```javascript
|
||||
const results = await table
|
||||
.search("What's the best pizza topping?")
|
||||
.limit(10)
|
||||
.execute()
|
||||
```
|
||||
|
||||
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
|
||||
@@ -3,126 +3,61 @@ Representing multi-modal data as vector embeddings is becoming a standard practi
|
||||
For this purpose, LanceDB introduces an **embedding functions API**, that allow you simply set up once, during the configuration stage of your project. After this, the table remembers it, effectively making the embedding functions *disappear in the background* so you don't have to worry about manually passing callables, and instead, simply focus on the rest of your data engineering pipeline.
|
||||
|
||||
!!! warning
|
||||
Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself.
|
||||
However, if your embedding function changes, you'll have to re-configure your table with the new embedding function
|
||||
and regenerate the embeddings. In the future, we plan to support the ability to change the embedding function via
|
||||
table metadata and have LanceDB automatically take care of regenerating the embeddings.
|
||||
|
||||
Using the implicit embeddings management approach means that you can forget about the manually passing around embedding
|
||||
functions in your code, as long as you don't intend to change it at a later time. If your embedding function changes,
|
||||
you'll have to re-configure your table with the new embedding function and regenerate the embeddings.
|
||||
|
||||
## 1. Define the embedding function
|
||||
We have some pre-defined embedding functions in the global registry, with more coming soon. Here's let's an implementation of CLIP as example.
|
||||
```
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
clip = registry.get("open-clip").create()
|
||||
|
||||
=== "Python"
|
||||
In the LanceDB python SDK, we define a global embedding function registry with
|
||||
many different embedding models and even more coming soon.
|
||||
Here's let's an implementation of CLIP as example.
|
||||
|
||||
```python
|
||||
from lancedb.embeddings import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
clip = registry.get("open-clip").create()
|
||||
```
|
||||
|
||||
You can also define your own embedding function by implementing the `EmbeddingFunction`
|
||||
abstract base interface. It subclasses Pydantic Model which can be utilized to write complex schemas simply as we'll see next!
|
||||
|
||||
=== "JavaScript""
|
||||
In the TypeScript SDK, the choices are more limited. For now, only the OpenAI
|
||||
embedding function is available.
|
||||
|
||||
```javascript
|
||||
const lancedb = require("vectordb");
|
||||
|
||||
// You need to provide an OpenAI API key
|
||||
const apiKey = "sk-..."
|
||||
// The embedding function will create embeddings for the 'text' column
|
||||
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
|
||||
```
|
||||
```
|
||||
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses Pydantic Model which can be utilized to write complex schemas simply as we'll see next!
|
||||
|
||||
## 2. Define the data model or schema
|
||||
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
|
||||
|
||||
=== "Python"
|
||||
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
|
||||
```python
|
||||
class Pets(LanceModel):
|
||||
vector: Vector(clip.ndims) = clip.VectorField()
|
||||
image_uri: str = clip.SourceField()
|
||||
```
|
||||
|
||||
```python
|
||||
class Pets(LanceModel):
|
||||
vector: Vector(clip.ndims) = clip.VectorField()
|
||||
image_uri: str = clip.SourceField()
|
||||
```
|
||||
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
|
||||
|
||||
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
|
||||
## 3. Create LanceDB table
|
||||
Now that we have chosen/defined our embedding function and the schema, we can create the table:
|
||||
|
||||
=== "JavaScript"
|
||||
```python
|
||||
db = lancedb.connect("~/lancedb")
|
||||
table = db.create_table("pets", schema=Pets)
|
||||
|
||||
For the TypeScript SDK, a schema can be inferred from input data, or an explicit
|
||||
Arrow schema can be provided.
|
||||
```
|
||||
|
||||
## 3. Create table and add data
|
||||
That's it! We've provided all the information needed to embed the source and query inputs. We can now forget about the model and dimension details and start to build our VectorDB pipeline.
|
||||
|
||||
Now that we have chosen/defined our embedding function and the schema,
|
||||
we can create the table and ingest data without needing to explicitly generate
|
||||
the embeddings at all:
|
||||
## 4. Ingest lots of data and query your table
|
||||
Any new or incoming data can just be added and it'll be vectorized automatically.
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
db = lancedb.connect("~/lancedb")
|
||||
table = db.create_table("pets", schema=Pets)
|
||||
```python
|
||||
table.add([{"image_uri": u} for u in uris])
|
||||
```
|
||||
|
||||
table.add([{"image_uri": u} for u in uris])
|
||||
```
|
||||
Our OpenCLIP query embedding function supports querying via both text and images:
|
||||
|
||||
=== "JavaScript"
|
||||
```python
|
||||
result = table.search("dog")
|
||||
```
|
||||
|
||||
```javascript
|
||||
const db = await lancedb.connect("data/sample-lancedb");
|
||||
const data = [
|
||||
{ text: "pepperoni"},
|
||||
{ text: "pineapple"}
|
||||
]
|
||||
Let's query an image:
|
||||
|
||||
const table = await db.createTable("vectors", data, embedding)
|
||||
```
|
||||
|
||||
## 4. Querying your table
|
||||
Not only can you forget about the embeddings during ingestion, you also don't
|
||||
need to worry about it when you query the table:
|
||||
|
||||
=== "Python"
|
||||
|
||||
Our OpenCLIP query embedding function supports querying via both text and images:
|
||||
|
||||
```python
|
||||
results = (
|
||||
table.search("dog")
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
Or we can search using an image:
|
||||
|
||||
```python
|
||||
p = Path("path/to/images/samoyed_100.jpg")
|
||||
query_image = Image.open(p)
|
||||
results = (
|
||||
table.search(query_image)
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
Both of the above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||
|
||||
=== "JavaScript"
|
||||
|
||||
```javascript
|
||||
const results = await table
|
||||
.search("What's the best pizza topping?")
|
||||
.limit(10)
|
||||
.execute()
|
||||
```
|
||||
|
||||
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
|
||||
```python
|
||||
p = Path("path/to/images/samoyed_100.jpg")
|
||||
query_image = Image.open(p)
|
||||
table.search(query_image)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -165,5 +100,4 @@ rs[2].image
|
||||
|
||||

|
||||
|
||||
Now that you have the basic idea about LanceDB embedding functions and the embedding function registry,
|
||||
let's dive deeper into defining your own [custom functions](./custom_embedding_function.md).
|
||||
Now that you have the basic idea about implicit management via embedding functions, let's dive deeper into a [custom API](./api.md) that you can use to implement your own embedding functions.
|
||||
@@ -1,14 +1,8 @@
|
||||
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio.
|
||||
This makes them a very powerful tool for machine learning practitioners.
|
||||
However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs
|
||||
(both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
|
||||
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio. This makes them a very powerful tool for machine learning practitioners. However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs (both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
|
||||
|
||||
LanceDB supports 3 methods of working with embeddings.
|
||||
LanceDB supports 2 methods of vectorizing your raw data into embeddings.
|
||||
|
||||
1. You can manually generate embeddings for the data and queries. This is done outside of LanceDB.
|
||||
2. You can use the built-in [embedding functions](./embedding_functions.md) to embed the data and queries in the background.
|
||||
3. For python users, you can define your own [custom embedding function](./custom_embedding_function.md)
|
||||
that extends the default embedding functions.
|
||||
1. **Explicit**: By manually calling LanceDB's `with_embedding` function to vectorize your data via an `embed_func` of your choice
|
||||
2. **Implicit**: Allow LanceDB to embed the data and queries in the background as they come in, by using the table's `EmbeddingRegistry` information
|
||||
|
||||
For python users, there is also a legacy [with_embeddings API](./legacy.md).
|
||||
It is retained for compatibility and will be removed in a future version.
|
||||
See the [explicit](embedding_explicit.md) and [implicit](embedding_functions.md) embedding sections for more details.
|
||||
@@ -1,99 +0,0 @@
|
||||
The legacy `with_embeddings` API is for Python only and is deprecated.
|
||||
|
||||
### Hugging Face
|
||||
|
||||
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
|
||||
library, which can be installed via pip.
|
||||
|
||||
```bash
|
||||
pip install sentence-transformers
|
||||
```
|
||||
|
||||
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
|
||||
for a given document.
|
||||
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
name="paraphrase-albert-small-v2"
|
||||
model = SentenceTransformer(name)
|
||||
|
||||
# used for both training and querying
|
||||
def embed_func(batch):
|
||||
return [model.encode(sentence) for sentence in batch]
|
||||
```
|
||||
|
||||
|
||||
### OpenAI
|
||||
|
||||
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
|
||||
|
||||
```python
|
||||
import openai
|
||||
import os
|
||||
|
||||
# Configuring the environment variable OPENAI_API_KEY
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
|
||||
client = openai.OpenAI()
|
||||
|
||||
def embed_func(c):
|
||||
rs = client.embeddings.create(input=c, model="text-embedding-ada-002")
|
||||
return [record.embedding for record in rs["data"]]
|
||||
```
|
||||
|
||||
|
||||
## Applying an embedding function to data
|
||||
|
||||
Using an embedding function, you can apply it to raw data
|
||||
to generate embeddings for each record.
|
||||
|
||||
Say you have a pandas DataFrame with a `text` column that you want embedded,
|
||||
you can use the `with_embeddings` function to generate embeddings and add them to
|
||||
an existing table.
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
from lancedb.embeddings import with_embeddings
|
||||
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{"text": "pepperoni"},
|
||||
{"text": "pineapple"}
|
||||
]
|
||||
)
|
||||
data = with_embeddings(embed_func, df)
|
||||
|
||||
# The output is used to create / append to a table
|
||||
tbl = db.create_table("my_table", data=data)
|
||||
```
|
||||
|
||||
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
|
||||
|
||||
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
|
||||
using the `batch_size` parameter to `with_embeddings`.
|
||||
|
||||
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
|
||||
API call is reliable.
|
||||
|
||||
## Querying using an embedding function
|
||||
|
||||
!!! warning
|
||||
At query time, you **must** use the same embedding function you used to vectorize your data.
|
||||
If you use a different embedding function, the embeddings will not reside in the same vector
|
||||
space and the results will be nonsensical.
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
query = "What's the best pizza topping?"
|
||||
query_vector = embed_func([query])[0]
|
||||
results = (
|
||||
tbl.search(query_vector)
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||
@@ -1,5 +1,6 @@
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
window.addEventListener('load', function() {
|
||||
var meta = document.createElement('meta');
|
||||
meta.setAttribute('property', 'og:image');
|
||||
meta.setAttribute('content', '/assets/lancedb_and_lance.png');
|
||||
document.head.appendChild(meta);
|
||||
});
|
||||
@@ -290,7 +290,7 @@
|
||||
"from lancedb.pydantic import LanceModel, Vector\n",
|
||||
"\n",
|
||||
"class Pets(LanceModel):\n",
|
||||
" vector: Vector(clip.ndims()) = clip.VectorField()\n",
|
||||
" vector: Vector(clip.ndims) = clip.VectorField()\n",
|
||||
" image_uri: str = clip.SourceField()\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
@@ -360,7 +360,7 @@
|
||||
" table = db.create_table(\"pets\", schema=Pets)\n",
|
||||
" # use a sampling of 1000 images\n",
|
||||
" p = Path(\"~/Downloads/images\").expanduser()\n",
|
||||
" uris = [str(f) for f in p.glob(\"*.jpg\")]\n",
|
||||
" uris = [str(f) for f in p.iterdir()]\n",
|
||||
" uris = sample(uris, 1000)\n",
|
||||
" table.add(pd.DataFrame({\"image_uri\": uris}))"
|
||||
]
|
||||
@@ -543,7 +543,7 @@
|
||||
],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"p = Path(\"~/Downloads/images/samoyed_100.jpg\").expanduser()\n",
|
||||
"p = Path(\"/Users/changshe/Downloads/images/samoyed_100.jpg\")\n",
|
||||
"query_image = Image.open(p)\n",
|
||||
"query_image"
|
||||
]
|
||||
|
||||
@@ -23,8 +23,10 @@ from multiprocessing import Pool
|
||||
import lance
|
||||
import pyarrow as pa
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
|
||||
|
||||
import lancedb
|
||||
|
||||
MODEL_ID = "openai/clip-vit-base-patch32"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.5.6
|
||||
current_version = 0.5.5
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -26,7 +26,7 @@ import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
from retry import retry
|
||||
|
||||
from ..util import deprecated, safe_import_pandas
|
||||
from ..util import safe_import_pandas
|
||||
from ..utils.general import LOGGER
|
||||
|
||||
pd = safe_import_pandas()
|
||||
@@ -38,7 +38,6 @@ IMAGES = Union[
|
||||
]
|
||||
|
||||
|
||||
@deprecated
|
||||
def with_embeddings(
|
||||
func: Callable,
|
||||
data: DATA,
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing import (
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias,
|
||||
@@ -37,6 +38,11 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
|
||||
from lancedb.util import safe_import_tf, safe_import_torch
|
||||
|
||||
torch = safe_import_torch()
|
||||
tf = safe_import_tf()
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
@@ -79,9 +85,6 @@ def Vector(
|
||||
) -> Type[FixedSizeListMixin]:
|
||||
"""Pydantic Vector Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dim : int
|
||||
@@ -155,6 +158,142 @@ def Vector(
|
||||
return FixedSizeList
|
||||
|
||||
|
||||
class FixedShapeTensorMixin(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def shape() -> Tuple[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def Tensor(
|
||||
shape: Tuple[int], value_type: pa.DataType = pa.float32()
|
||||
) -> Type[FixedShapeTensorMixin]:
|
||||
"""Pydantic Tensor Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape : tuple of int
|
||||
The shape of the tensor
|
||||
value_type : pyarrow.DataType, optional
|
||||
The value type of the vector, by default pa.float32()
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import LanceModel, Tensor, Vector
|
||||
...
|
||||
>>> class MyModel(LanceModel):
|
||||
... id: int
|
||||
... url: str
|
||||
... tensor: Tensor((3, 3))
|
||||
... embedding: Vector(768)
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
... pa.field("url", pa.utf8(), False),
|
||||
... pa.field("tensor", pa.fixed_shape_tensor(pa.float32(), (3, 3)), False),
|
||||
... pa.field("embeddings", pa.list_(pa.float32(), 768), False)
|
||||
... ])
|
||||
"""
|
||||
|
||||
# TODO: make a public parameterized type.
|
||||
class FixedShapeTensor(FixedShapeTensorMixin):
|
||||
def __repr__(self):
|
||||
return f"FixedShapeTensor(shape={shape})"
|
||||
|
||||
@staticmethod
|
||||
def shape() -> Tuple[int]:
|
||||
return shape
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return value_type
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
np.asarray,
|
||||
nested_schema(shape, core_schema.float_schema()),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if isinstance(v, list):
|
||||
v = cls._validate_list(v, shape)
|
||||
elif isinstance(v, np.ndarray):
|
||||
v = cls._validate_ndarray(v, shape)
|
||||
elif torch is not None and isinstance(v, torch.Tensor):
|
||||
v = cls._validate_torch(v, shape)
|
||||
elif tf is not None and isinstance(v, tf.Tensor):
|
||||
v = cls._validate_tf(v, shape)
|
||||
else:
|
||||
raise TypeError(
|
||||
"A list of numbers, numpy.ndarray, torch.Tensor, "
|
||||
f"or tf.Tensor is needed but got {type(v)} instead."
|
||||
)
|
||||
return np.asarray(v)
|
||||
|
||||
@classmethod
|
||||
def _validate_list(cls, v, shape):
|
||||
v = np.asarray(v)
|
||||
return cls._validate_ndarray(v, shape)
|
||||
|
||||
@classmethod
|
||||
def _validate_ndarray(cls, v, shape):
|
||||
if v.shape != shape:
|
||||
raise ValueError(f"Invalid shape {v.shape}, expected {shape}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def _validate_torch(cls, v, shape):
|
||||
v = v.detach().cpu().numpy()
|
||||
return cls._validate_ndarray(v, shape)
|
||||
|
||||
@classmethod
|
||||
def _validate_tf(cls, v, shape):
|
||||
v = v.numpy()
|
||||
return cls._validate_ndarray(v, shape)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any], field):
|
||||
if field and field.sub_fields:
|
||||
type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]"
|
||||
else:
|
||||
type_with_potential_subtype = "np.ndarray"
|
||||
field_schema.update({"type": type_with_potential_subtype})
|
||||
|
||||
return FixedShapeTensor
|
||||
|
||||
|
||||
def nested_schema(shape, items_schema):
|
||||
if len(shape) == 0:
|
||||
return items_schema
|
||||
else:
|
||||
return core_schema.list_schema(
|
||||
min_length=shape[0],
|
||||
max_length=shape[0],
|
||||
items_schema=nested_schema(shape[1:], items_schema),
|
||||
)
|
||||
|
||||
|
||||
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
"""Convert a field with native Python type to Arrow data type.
|
||||
|
||||
@@ -230,6 +369,10 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
return pa.struct(fields)
|
||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||
elif issubclass(field.annotation, FixedShapeTensorMixin):
|
||||
return pa.fixed_shape_tensor(
|
||||
field.annotation.value_arrow_type(), field.annotation.shape()
|
||||
)
|
||||
return _py_type_to_arrow_type(field.annotation, field)
|
||||
|
||||
|
||||
|
||||
@@ -1568,7 +1568,7 @@ def _sanitize_schema(
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_float32(field.type.value_type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and field.type.list_size >= 10
|
||||
)
|
||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
||||
@@ -1581,6 +1581,11 @@ def _sanitize_schema(
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
is_tensor_type = isinstance(field.type, pa.FixedShapeTensorType)
|
||||
if is_tensor_type and field.name in data.column_names:
|
||||
data = _sanitize_tensor_column(data, column_name=field.name)
|
||||
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
@@ -1649,6 +1654,31 @@ def _sanitize_vector_column(
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_tensor_column(data: pa.Table, column_name: str) -> pa.Table:
|
||||
"""
|
||||
Ensure that the tensor column exists and has type tensor(float32)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: pa.Table
|
||||
The table to sanitize.
|
||||
column_name: str
|
||||
The name of the tensor column.
|
||||
"""
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
tensor_arr = data[column_name].combine_chunks()
|
||||
typ = data[column_name].type
|
||||
if not isinstance(typ, pa.FixedShapeTensorType):
|
||||
raise TypeError(f"Unsupported tensor column type: {tensor_arr.type}")
|
||||
|
||||
tensor_arr = ensure_tensor(tensor_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(column_name), column_name, tensor_arr
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
values = vec_arr.values
|
||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
||||
@@ -1661,6 +1691,11 @@ def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
return vec_arr
|
||||
|
||||
|
||||
def ensure_tensor(tensor_arr) -> pa.TensorArray:
|
||||
assert 0 == 1
|
||||
return tensor_arr
|
||||
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "error":
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
import warnings
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple, Union
|
||||
@@ -155,6 +153,24 @@ def safe_import_polars():
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_torch():
|
||||
try:
|
||||
import torch
|
||||
|
||||
return torch
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_tf():
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
return tf
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
@@ -241,25 +257,3 @@ def _(value: list):
|
||||
@value_to_sql.register(np.ndarray)
|
||||
def _(value: np.ndarray):
|
||||
return value_to_sql(value.tolist())
|
||||
|
||||
|
||||
def deprecated(func):
|
||||
"""This is a decorator which can be used to mark functions
|
||||
as deprecated. It will result in a warning being emitted
|
||||
when the function is used."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def new_func(*args, **kwargs):
|
||||
warnings.simplefilter("always", DeprecationWarning) # turn off filter
|
||||
warnings.warn(
|
||||
(
|
||||
f"Function {func.__name__} is deprecated and will be "
|
||||
"removed in a future version"
|
||||
),
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
warnings.simplefilter("default", DeprecationWarning) # reset filter
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.5.6"
|
||||
version = "0.5.5"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.9.16",
|
||||
|
||||
@@ -22,7 +22,13 @@ import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Tensor,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -244,3 +250,37 @@ def test_lance_model():
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_tensor():
|
||||
class TestModel(LanceModel):
|
||||
tensor: Tensor((3, 3))
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["tensor"]
|
||||
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
json_schema = TestModel.model_json_schema()
|
||||
else:
|
||||
json_schema = TestModel.schema()
|
||||
|
||||
assert json_schema == {
|
||||
"properties": {
|
||||
"tensor": {
|
||||
"items": {
|
||||
"items": {"type": "number"},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"type": "array",
|
||||
},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"title": "Tensor",
|
||||
"type": "array",
|
||||
}
|
||||
},
|
||||
"required": ["tensor"],
|
||||
"title": "TestModel",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Tensor, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -898,3 +898,18 @@ def test_restore_consistency(tmp_path):
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
|
||||
def test_tensor_type(tmp_path):
|
||||
# create a model with a tensor column
|
||||
class MyTable(LanceModel):
|
||||
tensor: Tensor((256, 256, 3))
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
tensor = np.random.rand(256, 256, 3)
|
||||
table.add([{"tensor": tensor}, {"tensor": tensor.tolist()}])
|
||||
|
||||
result = table.search().limit(2).to_pandas()
|
||||
assert np.allclose(result.tensor[0], result.tensor[1])
|
||||
|
||||
Reference in New Issue
Block a user