feat(python): Embedding fn support for gte-mlx/gte-large (#873)

Added tests and an example in the docstring; a separate PR with a RAG example will follow in the recipes repo.
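
For reviewers, a minimal usage sketch mirroring the new docstring example (the database path, table name, and sample rows are just illustrative):

```python
import lancedb
import lancedb.embeddings.gte  # importing registers the "gte-text" function
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
import pandas as pd

# mlx=True would select the MLX backend on Apple silicon; the default is
# the sentence-transformers gte-large model on cpu/gpu.
model = get_registry().get("gte-text").create()

class TextModel(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()

df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
```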

---------

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Author: Raghav Dixit
Date: 2024-01-30 00:51:57 -05:00
Committed by: GitHub
Parent: 5c5e23bbb9
Commit: d1a7257810
6 changed files with 322 additions and 4 deletions

View File

@@ -79,7 +79,10 @@ def qanda_langchain(query):
     download_docs()
     docs = store_docs()
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200,)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200,
+    )
     documents = text_splitter.split_documents(docs)
     embeddings = OpenAIEmbeddings()

View File

@@ -48,6 +48,7 @@ def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
        if not skip_test:
            yield line[strip_length:]

for file in filter(lambda file: file not in excluded_files, files):
    with open(file, "r") as f:
        lines = list(yield_lines(iter(f), "```", "```"))

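For context, this harness pulls fenced code blocks out of files and runs them. A self-contained sketch of the extraction idea (simplified; not the repo's exact yield_lines implementation):

```python
from typing import Iterator

FENCE = "`" * 3  # the triple-backtick marker, built up to keep this snippet fence-safe

def yield_fenced(lines: Iterator[str]):
    # yield only the lines that sit inside fenced code blocks
    inside = False
    for line in lines:
        if line.strip().startswith(FENCE):
            inside = not inside  # toggle at every opening/closing fence
            continue
        if inside:
            yield line

text = "intro\n" + FENCE + "\nprint('hi')\n" + FENCE + "\noutro\n"
print(list(yield_fenced(iter(text.splitlines()))))  # ["print('hi')"]
```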
View File

@@ -0,0 +1,130 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru


@register("gte-text")
class GteEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that supports GTE-LARGE in MLX format (for Apple
    silicon devices only) as well as the standard cpu/gpu version from
    https://huggingface.co/thenlper/gte-large.

    Apple users need the mlx package installed, which can be done with:
        pip install mlx

    Parameters
    ----------
    name: str, default "thenlper/gte-large"
        The name of the model to use.
    device: str, default "cpu"
        Sets the device type for the model.
    normalize: bool, default True
        Controls the normalize_embeddings param of the transformer's
        encode function.
    mlx: bool, default False
        Controls which model to use: False for gte-large, True for the
        MLX version.

    Examples
    --------
    import lancedb
    import lancedb.embeddings.gte
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector
    import pandas as pd

    model = get_registry().get("gte-text").create()  # mlx=True for Apple silicon

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")
    tbl.add(df)
    rs = tbl.search("hello").limit(1).to_pandas()
    """

    name: str = "thenlper/gte-large"
    device: str = "cpu"
    normalize: bool = True
    mlx: bool = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None
        if kwargs:
            self.mlx = kwargs.get("mlx", False)
        if self.mlx is True:
            self.name = "gte-mlx"

    @property
    def embedding_model(self):
        """
        Get the embedding model specified by the flag, name and device.
        This is cached so that the model is only loaded once per process.
        """
        return self.get_embedding_model()

    def ndims(self):
        if self.mlx is True:
            self._ndims = self.embedding_model.dims
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings(["foo"])[0])
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.ndarray]:
        """
        Get the embeddings for the given texts.

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        if self.mlx is True:
            return self.embedding_model.run(list(texts)).tolist()
        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        ).tolist()

    @weak_lru(maxsize=1)
    def get_embedding_model(self):
        """
        Load the embedding model specified by the flag, name and device.
        This is cached so that the model is only loaded once per process.
        """
        if self.mlx is True:
            from .gte_mlx_model import Model

            return Model()
        else:
            sentence_transformers = self.safe_import(
                "sentence_transformers", "sentence-transformers"
            )
            return sentence_transformers.SentenceTransformer(
                self.name, device=self.device
            )

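Aside: switching backends is just a flag on create(). A minimal sketch, assuming mlx is installed on an Apple silicon machine (the 1024 comes from ModelConfig.dim in gte_mlx_model.py below):

```python
import lancedb.embeddings.gte  # noqa: F401 -- registers "gte-text"
from lancedb.embeddings import get_registry

# With mlx=True the function swaps in the converted MLX weights
# (downloaded on first use) and renames itself to "gte-mlx".
model = get_registry().get("gte-text").create(mlx=True)
assert model.ndims() == 1024  # ModelConfig.dim for gte-large
```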
View File

@@ -0,0 +1,154 @@
import json
from typing import List, Optional

import numpy as np
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import BertTokenizer

try:
    import mlx.core as mx
    import mlx.nn as nn
except ImportError:
    raise ImportError(
        "You need to install MLX to use this model: pip install mlx"
    )


def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
    last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
    return last_hidden.sum(axis=1) / attention_mask.sum(axis=1)[..., None]


class ModelConfig(BaseModel):
    dim: int = 1024
    num_attention_heads: int = 16
    num_hidden_layers: int = 24
    vocab_size: int = 30522
    attention_probs_dropout_prob: float = 0.1
    hidden_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512


class TransformerEncoderLayer(nn.Module):
    """
    A transformer encoder layer with (the original BERT) post-normalization.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        attention_out = self.attention(x, x, x, mask)
        add_and_norm = self.ln1(x + attention_out)
        ff = self.linear1(add_and_norm)
        ff_gelu = self.gelu(ff)
        ff_out = self.linear2(ff_gelu)
        x = self.ln2(ff_out + add_and_norm)
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class BertEmbeddings(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.token_type_embeddings = nn.Embedding(2, config.dim)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.dim
        )
        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)

    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
        words = self.word_embeddings(input_ids)
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )
        token_types = self.token_type_embeddings(token_type_ids)
        embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.dim,
            num_heads=config.num_attention_heads,
        )
        self.pooler = nn.Linear(config.dim, config.dim)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> tuple[mx.array, mx.array]:
        x = self.embeddings(input_ids, token_type_ids)
        if attention_mask is not None:
            # convert 0's to -infs, 1's to 0's, and make it broadcastable
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))
        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))


class Model:
    def __init__(self) -> None:
        # download the converted embedding model weights and config
        model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
        with open(f"{model_path}/config.json") as f:
            model_config = ModelConfig(**json.load(f))
        self.dims = model_config.dim
        self.model = Bert(model_config)
        self.model.load_weights(f"{model_path}/model.npz")
        self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
        self.embeddings = []

    def run(self, input_text: List[str]) -> np.ndarray:
        tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
        tokens = {key: mx.array(v) for key, v in tokens.items()}
        last_hidden_state, _ = self.model(**tokens)
        embeddings = average_pool(
            last_hidden_state, tokens["attention_mask"].astype(mx.float32)
        )
        # L2-normalize and return the normalized embeddings
        self.embeddings = (
            embeddings / mx.linalg.norm(embeddings, ord=2, axis=1)[..., None]
        )
        return np.array(self.embeddings.astype(mx.float32))

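The MLX path mean-pools the last hidden state under the attention mask and then L2-normalizes, which is what average_pool and Model.run do above. A small numpy sketch of the same math (made-up shapes, not real model output):

```python
import numpy as np

# fake (batch=2, seq=3, dim=4) hidden states; 0s in the mask mark padding
last_hidden_state = np.random.rand(2, 3, 4).astype(np.float32)
attention_mask = np.array([[1, 1, 0], [1, 1, 1]], dtype=np.float32)

# zero out padded positions, then divide by the count of real tokens
masked = last_hidden_state * attention_mask[..., None]
pooled = masked.sum(axis=1) / attention_mask.sum(axis=1)[..., None]

# L2-normalize each row, mirroring Model.run
normalized = pooled / np.linalg.norm(pooled, ord=2, axis=1)[..., None]
assert np.allclose(np.linalg.norm(normalized, axis=1), 1.0)
```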
View File

@@ -52,8 +52,8 @@ tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "d
 dev = ["ruff", "pre-commit"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
-embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere",
-              "InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57" ]
+embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
+              "InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]

 [project.scripts]
 lancedb = "lancedb.cli.cli:cli"

@@ -66,7 +66,8 @@ build-backend = "setuptools.build_meta"
 select = ["F", "E", "W", "I", "G", "TCH", "PERF"]

 [tool.pytest.ini_options]
-addopts = "--strict-markers"
+addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "asyncio"

View File

@@ -10,6 +10,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import io
 import os
@@ -22,6 +23,11 @@ import lancedb
 from lancedb.embeddings import get_registry
 from lancedb.pydantic import LanceModel, Vector

+try:
+    _mlx = importlib.util.find_spec("mlx.core")
+except ImportError:
+    # mlx is not installed, so the mlx-gated tests are skipped
+    _mlx = None

 # These are integration tests for embedding functions.
 # They are slow because they require downloading models
 # or connecting to an external API.
@@ -204,6 +210,29 @@ def test_gemini_embedding(tmp_path):
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

+@pytest.mark.skipif(
+    _mlx is None,
+    reason="mlx tests are only relevant for Apple silicon users",
+)
+@pytest.mark.slow
+def test_gte_embedding(tmp_path):
+    import lancedb.embeddings.gte
+
+    model = get_registry().get("gte-text").create()
+
+    class TextModel(LanceModel):
+        text: str = model.SourceField()
+        vector: Vector(model.ndims()) = model.VectorField()
+
+    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+    db = lancedb.connect(tmp_path)
+    tbl = db.create_table("test", schema=TextModel, mode="overwrite")
+    tbl.add(df)
+
+    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
+
 def aws_setup():
     try:
         import boto3