diff --git a/docs/src/examples/modal_langchain.py b/docs/src/examples/modal_langchain.py
index 3e2416fe..20c9960a 100644
--- a/docs/src/examples/modal_langchain.py
+++ b/docs/src/examples/modal_langchain.py
@@ -79,7 +79,10 @@ def qanda_langchain(query):
     download_docs()
     docs = store_docs()
 
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200,)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200,
+    )
     documents = text_splitter.split_documents(docs)
 
     embeddings = OpenAIEmbeddings()
diff --git a/docs/test/md_testing.py b/docs/test/md_testing.py
index b1a629e8..2a4012c5 100644
--- a/docs/test/md_testing.py
+++ b/docs/test/md_testing.py
@@ -48,6 +48,7 @@ def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
             if not skip_test:
                 yield line[strip_length:]
 
+
 for file in filter(lambda file: file not in excluded_files, files):
     with open(file, "r") as f:
         lines = list(yield_lines(iter(f), "```", "```"))
diff --git a/python/lancedb/embeddings/gte.py b/python/lancedb/embeddings/gte.py
new file mode 100644
index 00000000..bdff1ffc
--- /dev/null
+++ b/python/lancedb/embeddings/gte.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2023. LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Union
+
+import numpy as np
+
+from .base import TextEmbeddingFunction
+from .registry import register
+from .utils import weak_lru
+
+
+@register("gte-text")
+class GteEmbeddings(TextEmbeddingFunction):
+    """
+    An embedding function that uses the GTE-large model, either in MLX
+    format (for Apple silicon devices only) or as the standard CPU/GPU
+    version from https://huggingface.co/thenlper/gte-large.
+
+    For the MLX version, you will need the mlx package installed, which
+    can be done with: pip install mlx
+
+    Parameters
+    ----------
+    name: str, default "thenlper/gte-large"
+        The name of the model to use.
+    device: str, default "cpu"
+        Sets the device type for the model.
+    normalize: bool, default True
+        Whether to normalize embeddings; passed as the normalize_embeddings
+        argument to the sentence-transformers encode function.
+    mlx: bool, default False
+        Controls which model to use: False for gte-large, True for the MLX
+        version.
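+        The MLX version loads a converted checkpoint from the
+        vegaluisjose/mlx-rag Hugging Face repo (see gte_mlx_model.py) and
+        produces the same 1024-dimensional embeddings as gte-large.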
+
+    Examples
+    --------
+    import lancedb
+    import lancedb.embeddings.gte
+    from lancedb.embeddings import get_registry
+    from lancedb.pydantic import LanceModel, Vector
+    import pandas as pd
+
+    model = get_registry().get("gte-text").create()  # mlx=True for Apple silicon
+
+    class TextModel(LanceModel):
+        text: str = model.SourceField()
+        vector: Vector(model.ndims()) = model.VectorField()
+
+    df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
+    db = lancedb.connect("~/.lancedb")
+    tbl = db.create_table("test", schema=TextModel, mode="overwrite")
+
+    tbl.add(df)
+    rs = tbl.search("hello").limit(1).to_pandas()
+
+    """
+
+    name: str = "thenlper/gte-large"
+    device: str = "cpu"
+    normalize: bool = True
+    mlx: bool = False
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._ndims = None
+        # pydantic has already populated self.mlx from kwargs; the MLX
+        # port is identified by a fixed name rather than an HF repo id
+        if self.mlx:
+            self.name = "gte-mlx"
+
+    @property
+    def embedding_model(self):
+        """
+        Cached access to the embedding model (see get_embedding_model).
+        """
+        return self.get_embedding_model()
+
+    def ndims(self):
+        if self._ndims is None:
+            if self.mlx:
+                self._ndims = self.embedding_model.dims
+            else:
+                self._ndims = len(self.generate_embeddings(["foo"])[0])
+        return self._ndims
+
+    def generate_embeddings(
+        self, texts: Union[List[str], np.ndarray]
+    ) -> List[np.ndarray]:
+        """
+        Get the embeddings for the given texts.
+
+        Parameters
+        ----------
+        texts: list[str] or np.ndarray (of str)
+            The texts to embed
+        """
+        if self.mlx:
+            return self.embedding_model.run(list(texts)).tolist()
+
+        return self.embedding_model.encode(
+            list(texts),
+            convert_to_numpy=True,
+            normalize_embeddings=self.normalize,
+        ).tolist()
+
+    @weak_lru(maxsize=1)
+    def get_embedding_model(self):
+        """
+        Load the embedding model specified by the mlx flag, name and device.
+        This is cached so that the model is only loaded once per process.
+        """
+        if self.mlx:
+            from .gte_mlx_model import Model
+
+            return Model()
+        else:
+            sentence_transformers = self.safe_import(
+                "sentence_transformers", "sentence-transformers"
+            )
+            return sentence_transformers.SentenceTransformer(
+                self.name, device=self.device
+            )
diff --git a/python/lancedb/embeddings/gte_mlx_model.py b/python/lancedb/embeddings/gte_mlx_model.py
new file mode 100644
index 00000000..89c509f6
--- /dev/null
+++ b/python/lancedb/embeddings/gte_mlx_model.py
@@ -0,0 +1,154 @@
+import json
+from typing import List, Optional, Tuple
+
+import numpy as np
+from huggingface_hub import snapshot_download
+from pydantic import BaseModel
+from transformers import BertTokenizer
+
+try:
+    import mlx.core as mx
+    import mlx.nn as nn
+except ImportError:
+    raise ImportError("You need to install MLX to use this model: pip install mlx")
+
+
+def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
+    # mask out padding positions, then mean-pool over the sequence dimension
+    last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
+    return last_hidden.sum(axis=1) / attention_mask.sum(axis=1)[..., None]
+
+
+class ModelConfig(BaseModel):
+    dim: int = 1024
+    num_attention_heads: int = 16
+    num_hidden_layers: int = 24
+    vocab_size: int = 30522
+    attention_probs_dropout_prob: float = 0.1
+    hidden_dropout_prob: float = 0.1
+    layer_norm_eps: float = 1e-12
+    max_position_embeddings: int = 512
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    A transformer encoder layer with (the original BERT) post-normalization.
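+    Post-norm means each LayerNorm is applied after its residual addition
+    (once following attention, once following the feed-forward block),
+    rather than before the sublayer as in pre-norm transformers.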
+ """ + + def __init__( + self, + dims: int, + num_heads: int, + mlp_dims: Optional[int] = None, + layer_norm_eps: float = 1e-12, + ): + super().__init__() + mlp_dims = mlp_dims or dims * 4 + self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True) + self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) + self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) + self.linear1 = nn.Linear(dims, mlp_dims) + self.linear2 = nn.Linear(mlp_dims, dims) + self.gelu = nn.GELU() + + def __call__(self, x, mask): + attention_out = self.attention(x, x, x, mask) + add_and_norm = self.ln1(x + attention_out) + + ff = self.linear1(add_and_norm) + ff_gelu = self.gelu(ff) + ff_out = self.linear2(ff_gelu) + x = self.ln2(ff_out + add_and_norm) + + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + ): + super().__init__() + self.layers = [ + TransformerEncoderLayer(dims, num_heads, mlp_dims) + for i in range(num_layers) + ] + + def __call__(self, x, mask): + for layer in self.layers: + x = layer(x, mask) + + return x + + +class BertEmbeddings(nn.Module): + def __init__(self, config: ModelConfig): + self.word_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.token_type_embeddings = nn.Embedding(2, config.dim) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.dim + ) + self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps) + + def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array: + words = self.word_embeddings(input_ids) + position = self.position_embeddings( + mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape) + ) + token_types = self.token_type_embeddings(token_type_ids) + + embeddings = position + words + token_types + return self.norm(embeddings) + + +class Bert(nn.Module): + def __init__(self, config: ModelConfig): + self.embeddings = BertEmbeddings(config) + self.encoder = TransformerEncoder( + num_layers=config.num_hidden_layers, + dims=config.dim, + num_heads=config.num_attention_heads, + ) + self.pooler = nn.Linear(config.dim, config.dim) + + def __call__( + self, + input_ids: mx.array, + token_type_ids: mx.array, + attention_mask: mx.array = None, + ) -> tuple[mx.array, mx.array]: + x = self.embeddings(input_ids, token_type_ids) + + if attention_mask is not None: + # convert 0's to -infs, 1's to 0's, and make it broadcastable + attention_mask = mx.log(attention_mask) + attention_mask = mx.expand_dims(attention_mask, (1, 2)) + + y = self.encoder(x, attention_mask) + return y, mx.tanh(self.pooler(y[:, 0])) + + +class Model: + def __init__(self) -> None: + # get converted embedding model + model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag") + with open(f"{model_path}/config.json") as f: + model_config = ModelConfig(**json.load(f)) + self.dims = model_config.dim + self.model = Bert(model_config) + self.model.load_weights(f"{model_path}/model.npz") + self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large") + self.embeddings = [] + + def run(self, input_text: List[str]) -> mx.array: + tokens = self.tokenizer(input_text, return_tensors="np", padding=True) + tokens = {key: mx.array(v) for key, v in tokens.items()} + + last_hidden_state, _ = self.model(**tokens) + + embeddings = average_pool( + last_hidden_state, tokens["attention_mask"].astype(mx.float32) + ) + self.embeddings = ( + embeddings / mx.linalg.norm(embeddings, ord=2, axis=1)[..., None] + ) + + return 
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 3a42dc5c..e43b635f 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -52,8 +52,8 @@ tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "d
 dev = ["ruff", "pre-commit"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
-embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere",
-    "InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57" ]
+embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
+    "InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
 
 [project.scripts]
 lancedb = "lancedb.cli.cli:cli"
@@ -66,7 +66,8 @@ build-backend = "setuptools.build_meta"
 select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
 
 [tool.pytest.ini_options]
-addopts = "--strict-markers"
+addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
+
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "asyncio"
diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py
index f280482b..e3724b03 100644
--- a/python/tests/test_embeddings_slow.py
+++ b/python/tests/test_embeddings_slow.py
@@ -10,6 +10,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib.util
 import io
 import os
 
@@ -22,6 +23,11 @@ import lancedb
 from lancedb.embeddings import get_registry
 from lancedb.pydantic import LanceModel, Vector
 
+try:
+    # ModuleSpec when mlx is installed, None otherwise; ImportError if absent
+    _mlx = importlib.util.find_spec("mlx.core")
+except ImportError:
+    _mlx = None
+
 # These are integration tests for embedding functions.
 # They are slow because they require downloading models
 # or connection to external api
@@ -204,6 +210,29 @@ def test_gemini_embedding(tmp_path):
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
 
 
+@pytest.mark.skipif(
+    _mlx is None,
+    reason="mlx tests are only relevant on Apple silicon.",
+)
+@pytest.mark.slow
+def test_gte_embedding(tmp_path):
+    import lancedb.embeddings.gte  # noqa: F401 (imported to register "gte-text")
+
+    model = get_registry().get("gte-text").create()
+
+    class TextModel(LanceModel):
+        text: str = model.SourceField()
+        vector: Vector(model.ndims()) = model.VectorField()
+
+    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+    db = lancedb.connect(tmp_path)
+    tbl = db.create_table("test", schema=TextModel, mode="overwrite")
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
+
+
 def aws_setup():
     try:
         import boto3