ayush chaurasia
2024-04-15 05:59:26 +05:30
parent c7fbc4aaee
commit c75bb65609
6 changed files with 536 additions and 0 deletions


@@ -0,0 +1,134 @@
The fine-tuning workflow for embeddings consists of the following parts:
### QADataset
This class is used for managing the data for fine-tuning. It contains the following builder methods:
```python
from_llm(
    nodes: List["TextChunk"],
    llm: BaseLLM,
    qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> "QADataset"
```
Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.
```python
from_responses(
    docs: List["TextChunk"],
    queries: Dict[str, str],
    relevant_docs: Dict[str, List[str]],
) -> "QADataset"
```
Create a dataset from queries and responses collected in a real-world scenario. Designed to be used for knowledge distillation from a larger LLM to a smaller one.
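For illustration, a minimal sketch of assembling such a dataset from logged interactions (all ids and strings below are made up):
```python
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk

# Hypothetical logged data; ids and text are illustrative only
docs = [TextChunk(text="LanceDB is an embedded vector database.", id="doc-1")]
queries = {"q-1": "What kind of database is LanceDB?"}
relevant_docs = {"q-1": ["doc-1"]}  # query id -> ids of docs that answer it

ds = QADataset.from_responses(docs, queries, relevant_docs)
```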
It also contains the following data attributes:
```
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
```
### TextChunk
This class represents a single unit of text used for fine-tuning. It standardizes the output of various text splitting/pre-processing tools, such as llama-index and langchain, so they can be used interchangeably. It contains the following attributes:
```
text: str
id: str
metadata: Dict[str, Any] = {}
```
Builder Methods:
```python
from_llama_index_node(node) -> "TextChunk"
```
Create a text chunk from a llama index node.
```python
from_langchain_node(node) -> "TextChunk"
```
Create a text chunk from a langchain node (not implemented yet).
```python
from_chunk(chunk: str, metadata: dict = {}) -> "TextChunk"
```
Create a text chunk from a string.
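For example, raw strings can be wrapped directly (a minimal sketch; the metadata is arbitrary):
```python
from lancedb.embeddings.fine_tuner.dataset import TextChunk

chunk = TextChunk.from_chunk(
    "LanceDB supports both vector and full-text search.",
    metadata={"source": "docs"},  # arbitrary metadata, illustrative only
)
print(chunk.id)  # an auto-generated uuid4 string
```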
### FineTuner
This class is used for fine-tuning embeddings. It is exposed to the user via a high-level function in the base embedding API.
```python
from abc import ABC, abstractmethod

class BaseEmbeddingTuner(ABC):
    """Base embedding fine-tuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Fine-tune the embedding model."""

    def helper(self) -> None:
        """A helper method."""
        pass
```
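For illustration, a concrete tuner built on sentence-transformers might look like the sketch below, subclassing the `BaseEmbeddingTuner` above. The class name, constructor arguments, and training setup are assumptions for this example, not the shipped implementation:
```python
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

class SentenceTransformersTuner(BaseEmbeddingTuner):
    """Hypothetical tuner sketch: fine-tunes a sentence-transformers model
    on (query, relevant text) pairs drawn from a QADataset."""

    def __init__(self, dataset, model_id="BAAI/bge-small-en", path="model_finetuned", epochs=4):
        self.dataset = dataset
        self.model = SentenceTransformer(model_id)
        self.path = path
        self.epochs = epochs

    def finetune(self) -> None:
        # Pair each query with the text of its first relevant doc
        examples = [
            InputExample(texts=[query, self.dataset.corpus[doc_ids[0]]])
            for query, doc_ids in self.dataset.query_docid_pairs
        ]
        loader = DataLoader(examples, shuffle=True, batch_size=16)
        # In-batch negatives: only positive (query, doc) pairs are needed
        loss = losses.MultipleNegativesRankingLoss(self.model)
        self.model.fit(train_objectives=[(loader, loss)], epochs=self.epochs, show_progress_bar=True)
        self.model.save(self.path)
```
`MultipleNegativesRankingLoss` treats the other pairs in a batch as negatives, which suits retrieval-style fine-tuning where only positive (query, document) pairs are available.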
### Embedding API finetuning implementation
Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique is implemented in the `BaseEmbedding` class that calculates hit rate @ `top_k`, i.e. the fraction of queries whose relevant document appears among the top `k` retrieved results.
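Conceptually, the metric reduces to the following sketch (`retrieve_fn` stands in for whatever search the embedding API performs; the real method lives on `BaseEmbedding` and its signature may differ):
```python
def hit_rate_at_k(dataset, retrieve_fn, top_k=5):
    """Fraction of queries whose (single) relevant doc appears in the top_k results."""
    hits = 0
    for query_id, query in dataset.queries.items():
        retrieved_ids = retrieve_fn(query, top_k)  # ids of the top_k retrieved docs
        expected_id = dataset.relevant_docs[query_id][0]  # assume one relevant doc
        hits += expected_id in retrieved_ids
    return hits / len(dataset.queries)
```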
### Fine-tuning workflow
The fine-tuning workflow is as follows:
1. Create a `QADataset` object.
2. Initialize any embedding function using the LanceDB embedding API.
3. Call `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.
# End-to-End Examples
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.
## Example 1: Fine-tuning from a synthetic dataset
```python
import os

import pandas as pd
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

# 1. Create a QADataset object
name = "uber10k-qa"  # cache directory for the generated dataset (any name works)
url = "uber10k.pdf"
reader = SimpleDirectoryReader(input_files=[url])  # input_files expects a list
docs = reader.load_data()
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)

if os.path.exists(name):
    ds = QADataset.load(name)
else:
    llm = Openai()
    # convert llama-index TextNodes to TextChunks
    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
    ds = QADataset.from_llm(chunks, llm)
    ds.save(name)

# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create()

# 3. Fine-tune the model
model.finetune(trainset=ds, path="model_finetuned", epochs=4)

# 4. Evaluate the fine-tuned model
base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(ds, path="val_res")
rs2 = tuned.evaluate(ds, path="val_res")
rs3 = openai.evaluate(ds)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
```


@@ -0,0 +1,4 @@
from .dataset import QADataset, TextChunk
from .llm import Gemini, Openai
__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]


@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod


class BaseEmbeddingTuner(ABC):
    """Base embedding fine-tuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Fine-tune the embedding model."""

    def helper(self) -> None:
        """A helper method."""
        pass


@@ -0,0 +1,179 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple
import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm
from .llm import BaseLLM
DEFAULT_PROMPT_TMPL = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and no prior knowledge, \
generate only questions based on the below query.
You are a Teacher/ Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""
class QADataset(BaseModel):
"""Embedding QA Finetuning Dataset.
Args:
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
"""
queries: Dict[str, str] # id -> query
corpus: Dict[str, str] # id -> text
    relevant_docs: Dict[str, List[str]]  # query id -> list of relevant doc ids
mode: str = "text"
@property
def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
"""Get query, relevant doc ids."""
return [
(query, self.relevant_docs[query_id])
for query_id, query in self.queries.items()
]
def save(self, path: str, mode: str = "overwrite") -> None:
"""Save to lance dataset"""
save_dir = Path(path)
save_dir.mkdir(parents=True, exist_ok=True)
# convert to pydict {"id": []}
queries = {
"id": list(self.queries.keys()),
"query": list(self.queries.values()),
}
corpus = {
"id": list(self.corpus.keys()),
"text": [
val or " " for val in self.corpus.values()
], # lance saves empty strings as null
}
relevant_docs = {
"query_id": list(self.relevant_docs.keys()),
"doc_id": list(self.relevant_docs.values()),
}
# write to lance
lance.write_dataset(
pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
)
lance.write_dataset(
pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
)
lance.write_dataset(
pa.Table.from_pydict(relevant_docs),
save_dir / "relevant_docs.lance",
mode=mode,
)
@classmethod
def load(cls, path: str) -> "QADataset":
"""Load from .lance data"""
load_dir = Path(path)
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
relevant_docs = (
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
)
return cls(
queries=dict(zip(queries["id"], queries["query"])),
corpus=dict(zip(corpus["id"], corpus["text"])),
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
)
# generate queries as a convenience function
@classmethod
def from_llm(
cls,
nodes: "List[TextChunk]",
llm: BaseLLM,
qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
num_questions_per_chunk: int = 2,
) -> "QADataset":
"""Generate examples given a set of nodes."""
node_dict = {node.id: node.text for node in nodes}
queries = {}
relevant_docs = {}
for node_id, text in tqdm(node_dict.items()):
query = qa_generate_prompt_tmpl.format(
context_str=text, num_questions_per_chunk=num_questions_per_chunk
)
response = llm.chat_completion(query)
result = str(response).strip().split("\n")
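            # strip leading enumeration markers like "1." or "2)" from each question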
questions = [
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
]
questions = [question for question in questions if len(question) > 0]
for question in questions:
question_id = str(uuid.uuid4())
queries[question_id] = question
relevant_docs[question_id] = [node_id]
return QADataset(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
@classmethod
def from_responses(
cls,
docs: List["TextChunk"],
queries: Dict[str, str],
relevant_docs: Dict[str, List[str]],
) -> "QADataset":
"""Create a QADataset from a list of TextChunks and a list of questions."""
node_dict = {node.id: node.text for node in docs}
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
class TextChunk(BaseModel):
"""Simple text chunk for generating questions."""
text: str
id: str
metadata: Dict[str, Any] = {}
@classmethod
def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
"""Create a SimpleTextChunk from a chunk."""
# generate a unique id
return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)
@classmethod
def from_llama_index_node(cls, node):
"""Convert a llama index node to a text chunk."""
return cls(text=node.text, id=node.node_id, metadata=node.metadata)
@classmethod
def from_langchain_node(cls, node):
"""Convert a langchaain node to a text chunk."""
raise NotImplementedError("Not implemented yet.")
def to_dict(self) -> Dict[str, Any]:
"""Convert to a dictionary."""
return self.dict()
def __str__(self) -> str:
return self.text
def __repr__(self) -> str:
return f"SimpleTextChunk(text={self.text}, id={self.id}, \
metadata={self.metadata})"


@@ -0,0 +1,85 @@
import os
import re
from functools import cached_property
from typing import List, Optional
from pydantic import BaseModel
from ...util import attempt_import_or_raise
from ..utils import api_key_not_found_help
class BaseLLM(BaseModel):
"""
    TODO:
    Base class for language models used to generate fine-tuning data. This
    class is loosely designed right now and will be updated as the usage
    gets clearer.
"""
model_name: str
model_kwargs: dict = {}
    @cached_property
    def _client(self):
        """
        Get the client for the language model.
        """
        raise NotImplementedError
def chat_completion(self, prompt: str, **kwargs):
"""
Get the chat completion for the given prompt
"""
raise NotImplementedError
class Openai(BaseLLM):
model_name: str = "gpt-3.5-turbo"
kwargs: dict = {}
api_key: Optional[str] = None
@cached_property
def _client(self):
"""
Get the client for the language model
"""
openai = attempt_import_or_raise("openai")
if not os.environ.get("OPENAI_API_KEY"):
api_key_not_found_help("openai")
return openai.OpenAI()
def chat_completion(self, prompt: str) -> str:
"""
Get the chat completion for the given prompt
"""
        # TODO: this uses the legacy OpenAI API; replace with the completions API
completion = self._client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
**self.kwargs,
)
text = completion.choices[0].message.content
return text
    def get_questions(self, prompt: str) -> List[str]:
        """
        Generate questions for the given prompt and parse them into a list.
        """
response = self.chat_completion(prompt)
result = str(response).strip().split("\n")
questions = [
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
]
questions = [question for question in questions if len(question) > 0]
return questions
class Gemini(BaseLLM):
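    # Placeholder: Gemini support is not implemented yet.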
pass

test.py

@@ -0,0 +1,121 @@
import json
from tqdm import tqdm
import pandas as pd
import time
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.core.schema import TextNode
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from lancedb.rerankers import Reranker, CohereReranker
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import wandb
TRAIN_DATASET_FPATH = './finetune-embedding/data/train_dataset.json'
VAL_DATASET_FPATH = './finetune-embedding/data/val_dataset.json'
with open(TRAIN_DATASET_FPATH, 'r') as f:
    train_dataset = json.load(f)
with open(VAL_DATASET_FPATH, 'r') as f:
    val_dataset = json.load(f)
def run_query(tbl, query, vector_query, fts_query=None, reranker=None, top_k=5):
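    # Fetch 2 * top_k results (optionally reranked); the caller truncates to top_k.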
    if reranker is None:
        results = tbl.search(vector_query).limit(2 * top_k)
    elif fts_query is None:
        results = tbl.search(vector_query).rerank(reranker=reranker, query_string=query).limit(2 * top_k)
    else:
        results = tbl.search((vector_query, fts_query)).rerank(reranker=reranker).limit(2 * top_k)
    return results
def evaluate(
dataset,
embed_model,
top_k=5,
verbose=False,
reranker: Reranker=None,
query_type="vector"
):
"""
Evaluate the retrieval performance of the given dataset using the given embedding model.
Args:
- dataset (dict): The dataset to evaluate. It should have the following keys:
- corpus (dict): A dictionary of document IDs and their corresponding text.
- queries (dict): A dictionary of query IDs and their corresponding text.
- relevant_docs (dict): A dictionary of query IDs and their corresponding relevant document IDs.
- embed_model (str): The embedding model to use.
- top_k (int): The number of documents to retrieve.
- verbose (bool): Whether to print the evaluation results.
- reranker (Reranker): The reranker to use.
- query_type (str): The type of query to use. It should be either "vector" or "hybrid".
"""
corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']
service_context = ServiceContext.from_defaults(embed_model=embed_model)
nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
vector_store = LanceDBVectorStore(uri=f"/tmp/lancedb_hybrid_benc-{time.time()}")
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex(
nodes,
service_context=service_context,
storage_context=storage_context,
show_progress=True
)
tbl = vector_store._connection.open_table(vector_store.table_name)
# id: string
# doc_id: null
# vector: fixed_size_list<item: float>[1536]
# child 0, item: float
# text: string
tbl.create_fts_index("text", replace=True)
eval_results = []
for query_id, query in tqdm(queries.items()):
vector_query = embed_model.get_text_embedding(query)
if query_type == "vector":
rs = run_query(tbl, query, vector_query, reranker=reranker, top_k=top_k)
elif query_type == "hybrid":
fts_query = query
rs = run_query(tbl, query, vector_query, fts_query, reranker=reranker, top_k=top_k)
else:
raise ValueError(f"Invalid query_type: {query_type}")
try:
retrieved_ids = rs.to_pandas()["id"].tolist()[:top_k]
except Exception as e:
print(f"Error: {e}")
continue
expected_id = relevant_docs[query_id][0]
is_hit = expected_id in retrieved_ids # assume 1 relevant doc
eval_result = {
'is_hit': is_hit,
'retrieved': retrieved_ids,
'expected': expected_id,
'query': query_id,
}
eval_results.append(eval_result)
return eval_results
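# Sweep retrieval configurations: vector vs. hybrid search with different rerankers, logging hit rate to wandb.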
rerankers = [None, CohereReranker(), CohereReranker(model_name="rerank-english-v3.0")]
query_type = ["vector", "hybrid"]
top_ks = [3]
for top_k in top_ks:
for qt in query_type:
for reranker in rerankers:
wandb.init(project="cohere-v3-hf-embed", name=f"{reranker}_{qt}_top@{top_k}")
            embed = HuggingFaceEmbedding("sentence-transformers/all-MiniLM-L6-v2")  # or OpenAIEmbedding()
results = evaluate(val_dataset, embed, reranker=reranker, query_type=qt)
df = pd.DataFrame(results)
hit_rate = df['is_hit'].mean()
print(f"Reranker: {reranker}, Query Type: {qt}, Hit Rate: {hit_rate}")
wandb.log({"hit_rate": hit_rate})
wandb.finish()