From c75bb65609f03c4da750fecc45c90c9cea4a8905 Mon Sep 17 00:00:00 2001
From: ayush chaurasia
Date: Mon, 15 Apr 2024 05:59:26 +0530
Subject: [PATCH] update
---
 .../lancedb/embeddings/fine_tuner/README.md   | 134 +++++++++++++
 .../lancedb/embeddings/fine_tuner/__init__.py |   4 +
 .../embeddings/fine_tuner/basetuner.py        |  13 ++
 .../lancedb/embeddings/fine_tuner/dataset.py  | 179 ++++++++++++++++++
 .../lancedb/embeddings/fine_tuner/llm.py      |  85 +++++++++
 test.py                                       | 121 ++++++++++++
 6 files changed, 536 insertions(+)
 create mode 100644 python/python/lancedb/embeddings/fine_tuner/README.md
 create mode 100644 python/python/lancedb/embeddings/fine_tuner/__init__.py
 create mode 100644 python/python/lancedb/embeddings/fine_tuner/basetuner.py
 create mode 100644 python/python/lancedb/embeddings/fine_tuner/dataset.py
 create mode 100644 python/python/lancedb/embeddings/fine_tuner/llm.py
 create mode 100644 test.py

diff --git a/python/python/lancedb/embeddings/fine_tuner/README.md b/python/python/lancedb/embeddings/fine_tuner/README.md
new file mode 100644
index 00000000..7cee0cf9
--- /dev/null
+++ b/python/python/lancedb/embeddings/fine_tuner/README.md
@@ -0,0 +1,134 @@
The fine-tuning workflow for embeddings consists of the following parts:

### QADataset
This class manages the data for fine-tuning. It provides the following builder methods:
```python
from_llm(
    nodes: List[TextChunk],
    llm: BaseLLM,
    qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> "QADataset"
```
Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.

```python
from_responses(
    docs: List[TextChunk],
    queries: Dict[str, str],
    relevant_docs: Dict[str, List[str]],
) -> "QADataset"
```
Create a dataset from queries and responses gathered in a real-world scenario. Designed for knowledge distillation from a larger LLM to a smaller one.

It also contains the following data attributes:
```
    queries (Dict[str, str]): Dict id -> query.
    corpus (Dict[str, str]): Dict id -> document text.
    relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
```

### TextChunk
This class represents a single chunk of text used for fine-tuning. It standardizes the output of various text splitting/pre-processing tools, such as llama-index and langchain, so they can be used interchangeably. It contains the following attributes:
```
    text: str
    id: str
    metadata: Dict[str, Any] = {}
```

Builder methods:

```python
from_llama_index_node(node) -> "TextChunk"
```
Create a text chunk from a llama-index node.

```python
from_langchain_node(node) -> "TextChunk"
```
Create a text chunk from a langchain node.

```python
from_chunk(chunk: str, metadata: dict = {}) -> "TextChunk"
```
Create a text chunk from a raw string.
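For illustration, a minimal sketch of building chunks by hand with `from_chunk` (the chunk text and metadata here are made up):

```python
from lancedb.embeddings.fine_tuner import TextChunk

# hypothetical document snippets, for illustration only
chunks = [
    TextChunk.from_chunk(
        "Some text from the first section of the report.",
        metadata={"source": "uber10k.pdf", "page": 1},
    ),
    TextChunk.from_chunk("Some text from a later section."),
]
print(chunks[0].id)  # a generated uuid4 string
```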
### FineTuner
This class drives the fine-tuning of embeddings. It is exposed to the user via a high-level `finetune` method in the base embedding API.
```python
class BaseEmbeddingTuner(ABC):
    """Base embedding fine-tuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Run the fine-tuning job."""

    def helper(self) -> None:
        """A placeholder for helper logic shared across tuners."""
        pass
```

### Embedding API finetuning implementation
Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique is implemented in the `BaseEmbedding` class that calculates hit rate @ `top_k`: for each query, the `top_k` nearest chunks are retrieved, and the query counts as a hit if one of its relevant chunks is among them.
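As a rough sketch, a sentence-transformers based `finetune` implementation might internally do something like the following. This is an illustrative assumption, not the actual LanceDB implementation:

```python
from sentence_transformers import InputExample, SentenceTransformer, losses
from torch.utils.data import DataLoader


def finetune_sentence_transformer(trainset, path="model_finetuned", epochs=4):
    # pair each query with its first relevant corpus chunk
    examples = [
        InputExample(texts=[query, trainset.corpus[doc_ids[0]]])
        for query, doc_ids in trainset.query_docid_pairs
    ]
    model = SentenceTransformer("all-MiniLM-L6-v2")
    loader = DataLoader(examples, batch_size=16, shuffle=True)
    # in-batch negatives suit (query, positive passage) pairs well
    loss = losses.MultipleNegativesRankingLoss(model)
    model.fit(train_objectives=[(loader, loss)], epochs=epochs, output_path=path)
```

The hit-rate metric itself reduces to a sketch like this (the function name and argument layout are illustrative, not the actual `BaseEmbedding` internals):

```python
from typing import Dict, List


def hit_rate_at_k(
    retrieved: Dict[str, List[str]],
    relevant_docs: Dict[str, List[str]],
) -> float:
    # a query is a "hit" if any of its relevant docs appears
    # among its top_k retrieved doc ids
    hits = [
        any(doc_id in retrieved[query_id] for doc_id in doc_ids)
        for query_id, doc_ids in relevant_docs.items()
    ]
    return sum(hits) / len(hits)
```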
### Fine-tuning workflow
The fine-tuning workflow is as follows:
1. Create a `QADataset` object.
2. Initialize any embedding function using the LanceDB embedding API.
3. Call the `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method of the embedding API.

# End-to-End Examples
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.

## Example 1: Fine-tuning from a synthetic dataset

```python
import os

import pandas as pd

from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

# 1. Create a QADataset object
name = "qa_dataset"
url = "uber10k.pdf"
reader = SimpleDirectoryReader(input_files=[url])
docs = reader.load_data()

parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)

if os.path.exists(name):
    ds = QADataset.load(name)
else:
    llm = Openai()

    # convert llama-index TextNodes to TextChunks
    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

    ds = QADataset.from_llm(chunks, llm)
    ds.save(name)

# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create()

# 3. Fine-tune the model
model.finetune(trainset=ds, path="model_finetuned", epochs=4)

# 4. Evaluate the fine-tuned model against the base and OpenAI embeddings
base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(ds, path="val_res")
rs2 = tuned.evaluate(ds, path="val_res")
rs3 = openai.evaluate(ds)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("Base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
```
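## Example 2: Fine-tuning from collected query-response pairs

A minimal sketch of the `from_responses` path, using made-up queries and relevance labels in place of real user traffic:

```python
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner import QADataset, TextChunk

# hypothetical corpus and relevance labels, for illustration only
docs = [
    TextChunk.from_chunk("LanceDB is an open-source vector database."),
    TextChunk.from_chunk("Lance is a columnar data format for ML workloads."),
]
queries = {"q1": "What is LanceDB?", "q2": "What is the Lance format?"}
relevant_docs = {"q1": [docs[0].id], "q2": [docs[1].id]}

ds = QADataset.from_responses(docs, queries, relevant_docs)

model = get_registry().get("sentence-transformers").create()
model.finetune(trainset=ds, path="model_finetuned_responses")
```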
diff --git a/python/python/lancedb/embeddings/fine_tuner/__init__.py b/python/python/lancedb/embeddings/fine_tuner/__init__.py
new file mode 100644
index 00000000..74df8066
--- /dev/null
+++ b/python/python/lancedb/embeddings/fine_tuner/__init__.py
@@ -0,0 +1,4 @@
from .dataset import QADataset, TextChunk
from .llm import Gemini, Openai

__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]
diff --git a/python/python/lancedb/embeddings/fine_tuner/basetuner.py b/python/python/lancedb/embeddings/fine_tuner/basetuner.py
new file mode 100644
index 00000000..a90be05a
--- /dev/null
+++ b/python/python/lancedb/embeddings/fine_tuner/basetuner.py
@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod


class BaseEmbeddingTuner(ABC):
    """Base embedding fine-tuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Run the fine-tuning job."""

    def helper(self) -> None:
        """A placeholder for helper logic shared across tuners."""
        pass
diff --git a/python/python/lancedb/embeddings/fine_tuner/dataset.py b/python/python/lancedb/embeddings/fine_tuner/dataset.py
new file mode 100644
index 00000000..6d01abff
--- /dev/null
+++ b/python/python/lancedb/embeddings/fine_tuner/dataset.py
@@ -0,0 +1,179 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple

import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm

from .llm import BaseLLM

DEFAULT_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and no prior knowledge, \
generate only questions based on the query below.

You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""


class QADataset(BaseModel):
    """Embedding QA fine-tuning dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> document text.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.

    """

    queries: Dict[str, str]  # id -> query
    corpus: Dict[str, str]  # id -> text
    relevant_docs: Dict[str, List[str]]  # query id -> list of relevant doc ids
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Get (query, relevant doc ids) pairs."""
        return [
            (query, self.relevant_docs[query_id])
            for query_id, query in self.queries.items()
        ]

    def save(self, path: str, mode: str = "overwrite") -> None:
        """Save to lance datasets under the given directory."""
        save_dir = Path(path)
        save_dir.mkdir(parents=True, exist_ok=True)

        # convert each mapping to a pydict of columns, e.g. {"id": [...]}
        queries = {
            "id": list(self.queries.keys()),
            "query": list(self.queries.values()),
        }
        corpus = {
            "id": list(self.corpus.keys()),
            "text": [
                val or " " for val in self.corpus.values()
            ],  # lance saves empty strings as null
        }
        relevant_docs = {
            "query_id": list(self.relevant_docs.keys()),
            "doc_id": list(self.relevant_docs.values()),
        }

        # write to lance
        lance.write_dataset(
            pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(relevant_docs),
            save_dir / "relevant_docs.lance",
            mode=mode,
        )

    @classmethod
    def load(cls, path: str) -> "QADataset":
        """Load from .lance data."""
        load_dir = Path(path)
        queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
        corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
        relevant_docs = (
            lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
        )
        return cls(
            queries=dict(zip(queries["id"], queries["query"])),
            corpus=dict(zip(corpus["id"], corpus["text"])),
            relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
        )

    # generate queries as a convenience function
    @classmethod
    def from_llm(
        cls,
        nodes: "List[TextChunk]",
        llm: BaseLLM,
        qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
        num_questions_per_chunk: int = 2,
    ) -> "QADataset":
        """Generate QA examples from a set of text chunks using the given LLM."""
        node_dict = {node.id: node.text for node in nodes}

        queries = {}
        relevant_docs = {}
        for node_id, text in tqdm(node_dict.items()):
            query = qa_generate_prompt_tmpl.format(
                context_str=text, num_questions_per_chunk=num_questions_per_chunk
            )
            response = llm.chat_completion(query)

            # parse the response into one question per line, stripping numbering
            result = str(response).strip().split("\n")
            questions = [
                re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
            ]
            questions = [question for question in questions if len(question) > 0]
            for question in questions:
                question_id = str(uuid.uuid4())
                queries[question_id] = question
                relevant_docs[question_id] = [node_id]

        return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)

    @classmethod
    def from_responses(
        cls,
        docs: List["TextChunk"],
        queries: Dict[str, str],
        relevant_docs: Dict[str, List[str]],
    ) -> "QADataset":
        """Create a QADataset from a list of TextChunks and real query-response pairs."""
        node_dict = {node.id: node.text for node in docs}
        return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)


class TextChunk(BaseModel):
    """Simple text chunk for generating questions."""

    text: str
    id: str
    metadata: Dict[str, Any] = {}

    @classmethod
    def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
        """Create a TextChunk from a raw string, generating a unique id."""
        return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)

    @classmethod
    def from_llama_index_node(cls, node):
        """Convert a llama-index node to a text chunk."""
        return cls(text=node.text, id=node.node_id, metadata=node.metadata)

    @classmethod
    def from_langchain_node(cls, node):
        """Convert a langchain node to a text chunk."""
        raise NotImplementedError("Not implemented yet.")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return self.dict()

    def __str__(self) -> str:
        return self.text

    def __repr__(self) -> str:
        return f"TextChunk(text={self.text}, id={self.id}, \
            metadata={self.metadata})"
diff --git a/python/python/lancedb/embeddings/fine_tuner/llm.py b/python/python/lancedb/embeddings/fine_tuner/llm.py
new file mode 100644
index 00000000..1bdf2d97
--- /dev/null
+++ b/python/python/lancedb/embeddings/fine_tuner/llm.py
@@ -0,0 +1,85 @@
import os
import re
from functools import cached_property
from typing import List, Optional

from pydantic import BaseModel

from ...util import attempt_import_or_raise
from ..utils import api_key_not_found_help


class BaseLLM(BaseModel):
    """
    TODO:
    Base class for the language models used to generate fine-tuning data.
    This class is loosely designed for now, and will be updated as the usage
    gets clearer.
+ """ + + model_name: str + model_kwargs: dict = {} + + @cached_property + def _client(): + """ + Get the client for the language model + """ + raise NotImplementedError + + def chat_completion(self, prompt: str, **kwargs): + """ + Get the chat completion for the given prompt + """ + raise NotImplementedError + + +class Openai(BaseLLM): + model_name: str = "gpt-3.5-turbo" + kwargs: dict = {} + api_key: Optional[str] = None + + @cached_property + def _client(self): + """ + Get the client for the language model + """ + openai = attempt_import_or_raise("openai") + + if not os.environ.get("OPENAI_API_KEY"): + api_key_not_found_help("openai") + return openai.OpenAI() + + def chat_completion(self, prompt: str) -> str: + """ + Get the chat completion for the given prompt + """ + + # TODO: this is legacy openai api replace with completions + completion = self._client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + **self.kwargs, + ) + + text = completion.choices[0].message.content + + return text + + def get_questions(self, prompt: str) -> str: + """ + Get the chat completion for the given prompt + """ + response = self.chat_completion(prompt) + result = str(response).strip().split("\n") + questions = [ + re.sub(r"^\d+[\).\s]", "", question).strip() for question in result + ] + questions = [question for question in questions if len(question) > 0] + return questions + + +class Gemini(BaseLLM): + pass diff --git a/test.py b/test.py new file mode 100644 index 00000000..09510ab9 --- /dev/null +++ b/test.py @@ -0,0 +1,121 @@ +import json +from tqdm import tqdm +import pandas as pd +import time +from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext +from llama_index.core.schema import TextNode +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.vector_stores.lancedb import LanceDBVectorStore +from lancedb.rerankers import Reranker, CohereReranker +from llama_index.embeddings.huggingface import HuggingFaceEmbedding + +import wandb + + +TRAIN_DATASET_FPATH = './finetune-embedding/data/train_dataset.json' +VAL_DATASET_FPATH = './finetune-embedding/data/val_dataset.json' + +with open(TRAIN_DATASET_FPATH, 'r+') as f: + train_dataset = json.load(f) + +with open(VAL_DATASET_FPATH, 'r+') as f: + val_dataset = json.load(f) + +def run_query(tbl, query, vector_query, fts_query=None, reranker=None, top_k=5): + if reranker is None: + return tbl.search(vector_query).limit(2*top_k) + elif fts_query is None: + results = tbl.search(vector_query).rerank(reranker=reranker, query_string=query).limit(2*top_k) + else: + results = tbl.search((vector_query, fts_query)).rerank(reranker=reranker).limit(2*top_k) + return results + + +def evaluate( + dataset, + embed_model, + top_k=5, + verbose=False, + reranker: Reranker=None, + query_type="vector" +): + """ + Evaluate the retrieval performance of the given dataset using the given embedding model. + + Args: + - dataset (dict): The dataset to evaluate. It should have the following keys: + - corpus (dict): A dictionary of document IDs and their corresponding text. + - queries (dict): A dictionary of query IDs and their corresponding text. + - relevant_docs (dict): A dictionary of query IDs and their corresponding relevant document IDs. + - embed_model (str): The embedding model to use. + - top_k (int): The number of documents to retrieve. + - verbose (bool): Whether to print the evaluation results. 
    - reranker (Reranker): The reranker to use.
    - query_type (str): The type of query to use. It should be either "vector"
      or "hybrid".

    Returns:
    - A list of per-query result dicts, each with an `is_hit` flag.
    """
    corpus = dataset['corpus']
    queries = dataset['queries']
    relevant_docs = dataset['relevant_docs']

    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]

    vector_store = LanceDBVectorStore(uri=f"/tmp/lancedb_hybrid_bench-{time.time()}")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(
        nodes,
        service_context=service_context,
        storage_context=storage_context,
        show_progress=True,
    )
    tbl = vector_store._connection.open_table(vector_store.table_name)
    # table schema:
    #   id: string
    #   doc_id: null
    #   vector: fixed_size_list[1536]
    #     child 0, item: float
    #   text: string
    tbl.create_fts_index("text", replace=True)
    eval_results = []
    for query_id, query in tqdm(queries.items()):
        vector_query = embed_model.get_text_embedding(query)
        if query_type == "vector":
            rs = run_query(tbl, query, vector_query, reranker=reranker, top_k=top_k)
        elif query_type == "hybrid":
            fts_query = query
            rs = run_query(
                tbl, query, vector_query, fts_query, reranker=reranker, top_k=top_k
            )
        else:
            raise ValueError(f"Invalid query_type: {query_type}")
        try:
            retrieved_ids = rs.to_pandas()["id"].tolist()[:top_k]
        except Exception as e:
            print(f"Error: {e}")
            continue
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  # assume 1 relevant doc per query

        eval_result = {
            'is_hit': is_hit,
            'retrieved': retrieved_ids,
            'expected': expected_id,
            'query': query_id,
        }
        eval_results.append(eval_result)
    return eval_results


rerankers = [None, CohereReranker(), CohereReranker(model_name="rerank-english-v3.0")]
query_types = ["vector", "hybrid"]
top_ks = [3]

for top_k in top_ks:
    for qt in query_types:
        for reranker in rerankers:
            wandb.init(project="cohere-v3-hf-embed", name=f"{reranker}_{qt}_top@{top_k}")
            embed = HuggingFaceEmbedding("sentence-transformers/all-MiniLM-L6-v2")  # or OpenAIEmbedding()
            results = evaluate(val_dataset, embed, reranker=reranker, query_type=qt)
            df = pd.DataFrame(results)
            hit_rate = df['is_hit'].mean()
            print(f"Reranker: {reranker}, Query Type: {qt}, Hit Rate: {hit_rate}")
            wandb.log({"hit_rate": hit_rate})
            wandb.finish()