mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
feat(python): Embedding fn support for gte-mlx/gte-large (#873)
have added testing and an example in the docstring, will be pushing a separate PR in recipe repo for rag example --------- Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@@ -79,7 +79,10 @@ def qanda_langchain(query):
|
||||
download_docs()
|
||||
docs = store_docs()
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200,)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000,
|
||||
chunk_overlap=200,
|
||||
)
|
||||
documents = text_splitter.split_documents(docs)
|
||||
embeddings = OpenAIEmbeddings()
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
|
||||
if not skip_test:
|
||||
yield line[strip_length:]
|
||||
|
||||
|
||||
for file in filter(lambda file: file not in excluded_files, files):
|
||||
with open(file, "r") as f:
|
||||
lines = list(yield_lines(iter(f), "```", "```"))
|
||||
|
||||
130
python/lancedb/embeddings/gte.py
Normal file
130
python/lancedb/embeddings/gte.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
|
||||
|
||||
@register("gte-text")
|
||||
class GteEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses GTE-LARGE MLX format(for Apple silicon devices only)
|
||||
as well as the standard cpu/gpu version from: https://huggingface.co/thenlper/gte-large.
|
||||
|
||||
For Apple users, you will need the mlx package insalled, which can be done with:
|
||||
pip install mlx
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "thenlper/gte-large"
|
||||
The name of the model to use.
|
||||
device: str, default "cpu"
|
||||
Sets the device type for the model.
|
||||
normalize: str, default "True"
|
||||
Controls normalize param in encode function for the transformer.
|
||||
mlx: bool, default False
|
||||
Controls which model to use. False for gte-large,True for the mlx version.
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
import lancedb.embeddings.gte
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
import pandas as pd
|
||||
|
||||
model = get_registry().get("gte-text").create() # mlx=True for Apple silicon
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
rs = tbl.search("hello").limit(1).to_pandas()
|
||||
|
||||
"""
|
||||
|
||||
name: str = "thenlper/gte-large"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
mlx: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
if kwargs:
|
||||
self.mlx = kwargs.get("mlx", False)
|
||||
if self.mlx is True:
|
||||
self.name == "gte-mlx"
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the embedding model specified by the flag,
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.get_embedding_model()
|
||||
|
||||
def ndims(self):
|
||||
if self.mlx is True:
|
||||
self._ndims = self.embedding_model.dims
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
if self.mlx is True:
|
||||
return self.embedding_model.run(list(texts)).tolist()
|
||||
|
||||
return self.embedding_model.encode(
|
||||
list(texts),
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def get_embedding_model(self):
|
||||
"""
|
||||
Get the embedding model specified by the flag,
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
if self.mlx is True:
|
||||
from .gte_mlx_model import Model
|
||||
|
||||
return Model()
|
||||
else:
|
||||
sentence_transformers = self.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(
|
||||
self.name, device=self.device
|
||||
)
|
||||
154
python/lancedb/embeddings/gte_mlx_model.py
Normal file
154
python/lancedb/embeddings/gte_mlx_model.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from pydantic import BaseModel
|
||||
from transformers import BertTokenizer
|
||||
|
||||
try:
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
except ImportError:
|
||||
raise ImportError("You need to install MLX to use this model use - pip install mlx")
|
||||
|
||||
|
||||
def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
|
||||
last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
|
||||
return last_hidden.sum(axis=1) / attention_mask.sum(axis=1)[..., None]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
dim: int = 1024
|
||||
num_attention_heads: int = 16
|
||||
num_hidden_layers: int = 24
|
||||
vocab_size: int = 30522
|
||||
attention_probs_dropout_prob: float = 0.1
|
||||
hidden_dropout_prob: float = 0.1
|
||||
layer_norm_eps: float = 1e-12
|
||||
max_position_embeddings: int = 512
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
A transformer encoder layer with (the original BERT) post-normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
num_heads: int,
|
||||
mlp_dims: Optional[int] = None,
|
||||
layer_norm_eps: float = 1e-12,
|
||||
):
|
||||
super().__init__()
|
||||
mlp_dims = mlp_dims or dims * 4
|
||||
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
|
||||
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||
self.linear1 = nn.Linear(dims, mlp_dims)
|
||||
self.linear2 = nn.Linear(mlp_dims, dims)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def __call__(self, x, mask):
|
||||
attention_out = self.attention(x, x, x, mask)
|
||||
add_and_norm = self.ln1(x + attention_out)
|
||||
|
||||
ff = self.linear1(add_and_norm)
|
||||
ff_gelu = self.gelu(ff)
|
||||
ff_out = self.linear2(ff_gelu)
|
||||
x = self.ln2(ff_out + add_and_norm)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(dims, num_heads, mlp_dims)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
def __call__(self, x, mask):
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||
self.token_type_embeddings = nn.Embedding(2, config.dim)
|
||||
self.position_embeddings = nn.Embedding(
|
||||
config.max_position_embeddings, config.dim
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
|
||||
words = self.word_embeddings(input_ids)
|
||||
position = self.position_embeddings(
|
||||
mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
|
||||
)
|
||||
token_types = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = position + words + token_types
|
||||
return self.norm(embeddings)
|
||||
|
||||
|
||||
class Bert(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = TransformerEncoder(
|
||||
num_layers=config.num_hidden_layers,
|
||||
dims=config.dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
)
|
||||
self.pooler = nn.Linear(config.dim, config.dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: mx.array,
|
||||
token_type_ids: mx.array,
|
||||
attention_mask: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
x = self.embeddings(input_ids, token_type_ids)
|
||||
|
||||
if attention_mask is not None:
|
||||
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
||||
attention_mask = mx.log(attention_mask)
|
||||
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
||||
|
||||
y = self.encoder(x, attention_mask)
|
||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self) -> None:
|
||||
# get converted embedding model
|
||||
model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
|
||||
with open(f"{model_path}/config.json") as f:
|
||||
model_config = ModelConfig(**json.load(f))
|
||||
self.dims = model_config.dim
|
||||
self.model = Bert(model_config)
|
||||
self.model.load_weights(f"{model_path}/model.npz")
|
||||
self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
|
||||
self.embeddings = []
|
||||
|
||||
def run(self, input_text: List[str]) -> mx.array:
|
||||
tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
|
||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||
|
||||
last_hidden_state, _ = self.model(**tokens)
|
||||
|
||||
embeddings = average_pool(
|
||||
last_hidden_state, tokens["attention_mask"].astype(mx.float32)
|
||||
)
|
||||
self.embeddings = (
|
||||
embeddings / mx.linalg.norm(embeddings, ord=2, axis=1)[..., None]
|
||||
)
|
||||
|
||||
return np.array(embeddings.astype(mx.float32))
|
||||
@@ -52,8 +52,8 @@ tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "d
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere",
|
||||
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57" ]
|
||||
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
|
||||
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
|
||||
|
||||
[project.scripts]
|
||||
lancedb = "lancedb.cli.cli:cli"
|
||||
@@ -66,7 +66,8 @@ build-backend = "setuptools.build_meta"
|
||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers"
|
||||
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio"
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
|
||||
@@ -22,6 +23,11 @@ import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
except ImportError:
|
||||
_mlx = None
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
# or connection to external api
|
||||
@@ -204,6 +210,29 @@ def test_gemini_embedding(tmp_path):
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_mlx is None,
|
||||
reason="mlx tests only required for apple users.",
|
||||
)
|
||||
@pytest.mark.slow
|
||||
def test_gte_embedding(tmp_path):
|
||||
import lancedb.embeddings.gte
|
||||
|
||||
model = get_registry().get("gte-text").create()
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
def aws_setup():
|
||||
try:
|
||||
import boto3
|
||||
|
||||
Reference in New Issue
Block a user