mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 21:02:58 +00:00
update
This commit is contained in:
134
python/python/lancedb/embeddings/fine_tuner/README.md
Normal file
134
python/python/lancedb/embeddings/fine_tuner/README.md
Normal file
@@ -0,0 +1,134 @@
|
||||
Fine-tuning workflow for embeddings consists for the following parts:
|
||||
|
||||
### QADataset
|
||||
This class is used for managing the data for fine-tuning. It contains the following builder methods:
|
||||
```
|
||||
- from_llm(
|
||||
nodes: 'List[TextChunk]' ,
|
||||
llm: BaseLLM,
|
||||
qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
|
||||
num_questions_per_chunk: int = 2,
|
||||
) -> "QADataset"
|
||||
```
|
||||
Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.
|
||||
|
||||
```python
|
||||
|
||||
from_responses(docs: List['TextChunk'], queries: Dict[str, str], relevant_docs: Dict[str, List[str]])-> "QADataset"
|
||||
```
|
||||
Create dataset from queries and responses based on a real-world scenario. Designed to be used for knowledge distillation from a larger LLM to a smaller one.
|
||||
|
||||
It also contains the following data attributes:
|
||||
```
|
||||
queries (Dict[str, str]): Dict id -> query.
|
||||
corpus (Dict[str, str]): Dict id -> string.
|
||||
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
|
||||
```
|
||||
|
||||
### TextChunk
|
||||
This class is used for managing the data for fine-tuning. It is designed to allow working with and standardize various text splitting/pre-processing tools like llama-index and langchain. It contains the following attributes:
|
||||
```
|
||||
text: str
|
||||
id: str
|
||||
metadata: Dict[str, Any] = {}
|
||||
```
|
||||
|
||||
Builder Methods:
|
||||
|
||||
```python
|
||||
from_llama_index_node(node) -> "TextChunk"
|
||||
```
|
||||
Create a text chunk from a llama index node.
|
||||
|
||||
```python
|
||||
from_langchain_node(node) -> "TextChunk"
|
||||
```
|
||||
Create a text chunk from a langchain index node.
|
||||
|
||||
```python
|
||||
from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk"
|
||||
```
|
||||
Create a text chunk from a string.
|
||||
|
||||
### FineTuner
|
||||
This class is used for fine-tuning embeddings. It is exposed to the user via a high-level function in the base embedding api.
|
||||
```python
|
||||
class BaseEmbeddingTuner(ABC):
|
||||
"""Base Embedding finetuning engine."""
|
||||
|
||||
@abstractmethod
|
||||
def finetune(self) -> None:
|
||||
"""Goes off and does stuff."""
|
||||
|
||||
def helper(self) -> None:
|
||||
"""A helper method."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Embedding API finetuning implementation
|
||||
Each embedding API needs to implement `finetune` method in order to support fine-tuning. A vanilla evaluation technique has been implemented in the `BaseEmbedding` class that calculates hit_rate @ `top_k`.
|
||||
|
||||
### Fine-tuning workflow
|
||||
The fine-tuning workflow is as follows:
|
||||
1. Create a `QADataset` object.
|
||||
2. Initialize any embedding function using LanceDB embedding API
|
||||
3. Call `finetune` method on the embedding object with the `QADataset` object as an argument.
|
||||
4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.
|
||||
|
||||
# End-to-End Examples
|
||||
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.
|
||||
|
||||
## Example 1: Fine-tuning from a synthetic dataset
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
|
||||
from lancedb.embeddings.fine_tuner.llm import Openai
|
||||
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.schema import MetadataMode
|
||||
from lancedb.embeddings import get_registry
|
||||
|
||||
# 1. Create a QADataset object
|
||||
url = "uber10k.pdf"
|
||||
reader = SimpleDirectoryReader(input_files=url)
|
||||
docs = reader.load_data()
|
||||
|
||||
parser = SentenceSplitter()
|
||||
nodes = parser.get_nodes_from_documents(docs)
|
||||
|
||||
if os.path.exists(name):
|
||||
ds = QADataset.load(name)
|
||||
else:
|
||||
llm = Openai()
|
||||
|
||||
# convert Llama-index TextNode to TextChunk
|
||||
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
|
||||
|
||||
ds = QADataset.from_llm(chunks, llm)
|
||||
ds.save(name)
|
||||
|
||||
# 2. Initialize the embedding model
|
||||
model = get_registry().get("sentence-transformers").create()
|
||||
|
||||
# 3. Fine-tune the model
|
||||
model.finetune(trainset=ds, path="model_finetuned", epochs=4)
|
||||
|
||||
# 4. Evaluate the fine-tuned model
|
||||
base = get_registry().get("sentence-transformers").create()
|
||||
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1")
|
||||
openai = get_registry().get("openai").create(name="text-embedding-3-large")
|
||||
|
||||
|
||||
rs1 = base.evaluate(trainset, path="val_res")
|
||||
rs2 = tuned.evaluate(trainset, path="val_res")
|
||||
rs3 = openai.evaluate(trainset)
|
||||
|
||||
print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
|
||||
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
|
||||
print("Base model hite-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
|
||||
```
|
||||
|
||||
|
||||
4
python/python/lancedb/embeddings/fine_tuner/__init__.py
Normal file
4
python/python/lancedb/embeddings/fine_tuner/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .dataset import QADataset, TextChunk
|
||||
from .llm import Gemini, Openai
|
||||
|
||||
__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]
|
||||
13
python/python/lancedb/embeddings/fine_tuner/basetuner.py
Normal file
13
python/python/lancedb/embeddings/fine_tuner/basetuner.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseEmbeddingTuner(ABC):
|
||||
"""Base Embedding finetuning engine."""
|
||||
|
||||
@abstractmethod
|
||||
def finetune(self) -> None:
|
||||
"""Goes off and does stuff."""
|
||||
|
||||
def helper(self) -> None:
|
||||
"""A helper method."""
|
||||
pass
|
||||
179
python/python/lancedb/embeddings/fine_tuner/dataset.py
Normal file
179
python/python/lancedb/embeddings/fine_tuner/dataset.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import lance
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from .llm import BaseLLM
|
||||
|
||||
DEFAULT_PROMPT_TMPL = """\
|
||||
Context information is below.
|
||||
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
|
||||
Given the context information and no prior knowledge.
|
||||
generate only questions based on the below query.
|
||||
|
||||
You are a Teacher/ Professor. Your task is to setup \
|
||||
{num_questions_per_chunk} questions for an upcoming \
|
||||
quiz/examination. The questions should be diverse in nature \
|
||||
across the document. Restrict the questions to the \
|
||||
context information provided."
|
||||
"""
|
||||
|
||||
|
||||
class QADataset(BaseModel):
|
||||
"""Embedding QA Finetuning Dataset.
|
||||
|
||||
Args:
|
||||
queries (Dict[str, str]): Dict id -> query.
|
||||
corpus (Dict[str, str]): Dict id -> string.
|
||||
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
|
||||
|
||||
"""
|
||||
|
||||
queries: Dict[str, str] # id -> query
|
||||
corpus: Dict[str, str] # id -> text
|
||||
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
|
||||
mode: str = "text"
|
||||
|
||||
@property
|
||||
def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
|
||||
"""Get query, relevant doc ids."""
|
||||
return [
|
||||
(query, self.relevant_docs[query_id])
|
||||
for query_id, query in self.queries.items()
|
||||
]
|
||||
|
||||
def save(self, path: str, mode: str = "overwrite") -> None:
|
||||
"""Save to lance dataset"""
|
||||
save_dir = Path(path)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# convert to pydict {"id": []}
|
||||
queries = {
|
||||
"id": list(self.queries.keys()),
|
||||
"query": list(self.queries.values()),
|
||||
}
|
||||
corpus = {
|
||||
"id": list(self.corpus.keys()),
|
||||
"text": [
|
||||
val or " " for val in self.corpus.values()
|
||||
], # lance saves empty strings as null
|
||||
}
|
||||
relevant_docs = {
|
||||
"query_id": list(self.relevant_docs.keys()),
|
||||
"doc_id": list(self.relevant_docs.values()),
|
||||
}
|
||||
|
||||
# write to lance
|
||||
lance.write_dataset(
|
||||
pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
|
||||
)
|
||||
lance.write_dataset(
|
||||
pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
|
||||
)
|
||||
lance.write_dataset(
|
||||
pa.Table.from_pydict(relevant_docs),
|
||||
save_dir / "relevant_docs.lance",
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> "QADataset":
|
||||
"""Load from .lance data"""
|
||||
load_dir = Path(path)
|
||||
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
|
||||
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
|
||||
relevant_docs = (
|
||||
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
|
||||
)
|
||||
return cls(
|
||||
queries=dict(zip(queries["id"], queries["query"])),
|
||||
corpus=dict(zip(corpus["id"], corpus["text"])),
|
||||
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
|
||||
)
|
||||
|
||||
# generate queries as a convenience function
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
nodes: "List[TextChunk]",
|
||||
llm: BaseLLM,
|
||||
qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
|
||||
num_questions_per_chunk: int = 2,
|
||||
) -> "QADataset":
|
||||
"""Generate examples given a set of nodes."""
|
||||
node_dict = {node.id: node.text for node in nodes}
|
||||
|
||||
queries = {}
|
||||
relevant_docs = {}
|
||||
for node_id, text in tqdm(node_dict.items()):
|
||||
query = qa_generate_prompt_tmpl.format(
|
||||
context_str=text, num_questions_per_chunk=num_questions_per_chunk
|
||||
)
|
||||
response = llm.chat_completion(query)
|
||||
|
||||
result = str(response).strip().split("\n")
|
||||
questions = [
|
||||
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
|
||||
]
|
||||
questions = [question for question in questions if len(question) > 0]
|
||||
for question in questions:
|
||||
question_id = str(uuid.uuid4())
|
||||
queries[question_id] = question
|
||||
relevant_docs[question_id] = [node_id]
|
||||
|
||||
return QADataset(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
|
||||
|
||||
@classmethod
|
||||
def from_responses(
|
||||
cls,
|
||||
docs: List["TextChunk"],
|
||||
queries: Dict[str, str],
|
||||
relevant_docs: Dict[str, List[str]],
|
||||
) -> "QADataset":
|
||||
"""Create a QADataset from a list of TextChunks and a list of questions."""
|
||||
node_dict = {node.id: node.text for node in docs}
|
||||
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
|
||||
|
||||
|
||||
class TextChunk(BaseModel):
|
||||
"""Simple text chunk for generating questions."""
|
||||
|
||||
text: str
|
||||
id: str
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
@classmethod
|
||||
def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
|
||||
"""Create a SimpleTextChunk from a chunk."""
|
||||
# generate a unique id
|
||||
return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def from_llama_index_node(cls, node):
|
||||
"""Convert a llama index node to a text chunk."""
|
||||
return cls(text=node.text, id=node.node_id, metadata=node.metadata)
|
||||
|
||||
@classmethod
|
||||
def from_langchain_node(cls, node):
|
||||
"""Convert a langchaain node to a text chunk."""
|
||||
raise NotImplementedError("Not implemented yet.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to a dictionary."""
|
||||
return self.dict()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SimpleTextChunk(text={self.text}, id={self.id}, \
|
||||
metadata={self.metadata})"
|
||||
85
python/python/lancedb/embeddings/fine_tuner/llm.py
Normal file
85
python/python/lancedb/embeddings/fine_tuner/llm.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import re
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...util import attempt_import_or_raise
|
||||
from ..utils import api_key_not_found_help
|
||||
|
||||
|
||||
class BaseLLM(BaseModel):
|
||||
"""
|
||||
TODO:
|
||||
Base class for Language Model based Embedding Functions. This class is
|
||||
loosely desined rn, and will be updated as the usage gets clearer.
|
||||
"""
|
||||
|
||||
model_name: str
|
||||
model_kwargs: dict = {}
|
||||
|
||||
@cached_property
|
||||
def _client():
|
||||
"""
|
||||
Get the client for the language model
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def chat_completion(self, prompt: str, **kwargs):
|
||||
"""
|
||||
Get the chat completion for the given prompt
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Openai(BaseLLM):
|
||||
model_name: str = "gpt-3.5-turbo"
|
||||
kwargs: dict = {}
|
||||
api_key: Optional[str] = None
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
"""
|
||||
Get the client for the language model
|
||||
"""
|
||||
openai = attempt_import_or_raise("openai")
|
||||
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
api_key_not_found_help("openai")
|
||||
return openai.OpenAI()
|
||||
|
||||
def chat_completion(self, prompt: str) -> str:
|
||||
"""
|
||||
Get the chat completion for the given prompt
|
||||
"""
|
||||
|
||||
# TODO: this is legacy openai api replace with completions
|
||||
completion = self._client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
**self.kwargs,
|
||||
)
|
||||
|
||||
text = completion.choices[0].message.content
|
||||
|
||||
return text
|
||||
|
||||
def get_questions(self, prompt: str) -> str:
|
||||
"""
|
||||
Get the chat completion for the given prompt
|
||||
"""
|
||||
response = self.chat_completion(prompt)
|
||||
result = str(response).strip().split("\n")
|
||||
questions = [
|
||||
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
|
||||
]
|
||||
questions = [question for question in questions if len(question) > 0]
|
||||
return questions
|
||||
|
||||
|
||||
class Gemini(BaseLLM):
|
||||
pass
|
||||
Reference in New Issue
Block a user