diff --git a/docs/src/eval/bench_fine_tuned_hybrid.py b/docs/src/eval/bench_fine_tuned_hybrid.py new file mode 100644 index 00000000..60fe92ff --- /dev/null +++ b/docs/src/eval/bench_fine_tuned_hybrid.py @@ -0,0 +1,150 @@ +import json +from tqdm import tqdm +import pandas as pd +import os +import requests +from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext +from llama_index.core.schema import TextNode +from llama_index.vector_stores.lancedb import LanceDBVectorStore +from lancedb.rerankers import CrossEncoderReranker, ColbertReranker, CohereReranker, LinearCombinationReranker +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.embeddings.openai import OpenAIEmbedding +from lancedb.pydantic import LanceModel, Vector +from lancedb.embeddings import get_registry +from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk, DEFAULT_PROMPT_TMPL +from lancedb.pydantic import LanceModel, Vector +from llama_index.core import SimpleDirectoryReader +from llama_index.core.node_parser import SentenceSplitter +from lancedb.embeddings.fine_tuner.llm import Openai + +import time +import lancedb +import wandb +from pydantic import BaseModel, root_validator +from typing import Optional + +TRAIN_DATASET_FPATH = './data/train_dataset.json' +VAL_DATASET_FPATH = './data/val_dataset.json' + +with open(TRAIN_DATASET_FPATH, 'r+') as f: + train_dataset = json.load(f) + +with open(VAL_DATASET_FPATH, 'r+') as f: + val_dataset = json.load(f) + +def train_embedding_model(epoch): + def download_test_files(url): + # download to cwd + files = [] + filename = os.path.basename(url) + if not os.path.exists(filename): + print(f"Downloading {url} to {filename}") + r = requests.get(url) + with open(filename, 'wb') as f: + f.write(r.content) + files.append(filename) + return files + + def get_dataset(url, name): + reader = SimpleDirectoryReader(input_files=download_test_files(url)) + docs = reader.load_data() + + parser = SentenceSplitter() + nodes = parser.get_nodes_from_documents(docs) + + if os.path.exists(name): + ds = QADataset.load(name) + else: + llm = Openai() + + # convert Llama-index TextNode to TextChunk + chunks = [TextChunk.from_llama_index_node(node) for node in nodes] + + ds = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2) + ds.save(name) + return ds + train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' + ds = get_dataset(train_url, "qa_dataset_uber") + + + model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5") + model.finetune(trainset=ds, valset=None, path="model_airbnb", epochs=epoch, log_wandb=True, run_name="lyft_finetune") + + +def evaluate( + dataset, + embed_model, + reranker=None, + top_k=5, + verbose=False, +): + corpus = dataset['corpus'] + queries = dataset['queries'] + relevant_docs = dataset['relevant_docs'] + + vector_store = LanceDBVectorStore(uri="/tmp/lancedb") + storage_context = StorageContext.from_defaults(vector_store=vector_store) + service_context = ServiceContext.from_defaults(embed_model=embed_model) + nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()] + index = VectorStoreIndex( + nodes, + service_context=service_context, + show_progress=True, + storage_context=storage_context, + ) + tbl = vector_store.connection.open_table(vector_store.table_name) + tbl.create_fts_index("text", replace=True) + + eval_results = [] + for query_id, query in tqdm(queries.items()): + query_vector = embed_model.get_query_embedding(query) + try: + if reranker is None: + rs = tbl.search(query_vector).limit(top_k).to_pandas() + else: + rs = tbl.search((query_vector, query)).rerank(reranker=reranker).limit(top_k).to_pandas() + except Exception as e: + print(f'Error with query: {query_id} {e}') + continue + retrieved_ids = rs['id'].tolist()[:top_k] + expected_id = relevant_docs[query_id][0] + is_hit = expected_id in retrieved_ids # assume 1 relevant doc + if len(eval_results) == 0: + print(f"Query: {query}") + print(f"Expected: {expected_id}") + print(f"Retrieved: {retrieved_ids}") + eval_result = { + 'is_hit': is_hit, + 'retrieved': retrieved_ids, + 'expected': expected_id, + 'query': query_id, + } + eval_results.append(eval_result) + return eval_results + +if __name__ == '__main__': + train_embedding_model(4) + #embed_model = OpenAIEmbedding() # model="text-embedding-3-small" + rerankers = { + "Vector Search": None, + "Cohere": CohereReranker(), + "Cross Encoder": CrossEncoderReranker(), + "Colbert": ColbertReranker(), + "linear": LinearCombinationReranker(), + } + top_ks = [3] + for top_k in top_ks: + #for epoch in epochs: + for name, reranker in rerankers.items(): + #embed_model = HuggingFaceEmbedding("./model_airbnb") + embed_model = OpenAIEmbedding() + wandb.init(project=f"Reranker-based", name=name) + val_eval_results = evaluate(val_dataset, embed_model, reranker=reranker, top_k=top_k) + df = pd.DataFrame(val_eval_results) + + hit_rate = df['is_hit'].mean() + print(f'Hit rate: {hit_rate:.2f}') + wandb.log({f"openai_base_hit_rate_@{top_k}": hit_rate}) + wandb.finish() + + diff --git a/docs/src/eval/test_fine_tune_from_llm.py b/docs/src/eval/test_fine_tune_from_llm.py new file mode 100644 index 00000000..beea7b43 --- /dev/null +++ b/docs/src/eval/test_fine_tune_from_llm.py @@ -0,0 +1,71 @@ +import os +import json +import lancedb +import pandas as pd + +from lancedb.embeddings.fine_tuner.llm import Openai +from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk +from lancedb.pydantic import LanceModel, Vector +from llama_index.core import SimpleDirectoryReader +from llama_index.core.node_parser import SentenceSplitter +from llama_index.core.schema import MetadataMode +from lancedb.embeddings import get_registry + + +test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' +train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' +def download_test_files(url): + import os + import requests + + # download to cwd + files = [] + filename = os.path.basename(url) + if not os.path.exists(filename): + print(f"Downloading {url} to {filename}") + r = requests.get(url) + with open(filename, 'wb') as f: + f.write(r.content) + files.append(filename) + return files + +def get_dataset(url, name): + reader = SimpleDirectoryReader(input_files=download_test_files(url)) + docs = reader.load_data() + + parser = SentenceSplitter() + nodes = parser.get_nodes_from_documents(docs) + + if os.path.exists(name): + ds = QADataset.load(name) + else: + llm = Openai() + + # convert Llama-index TextNode to TextChunk + chunks = [TextChunk.from_llama_index_node(node) for node in nodes] + + ds = QADataset.from_llm(chunks, llm) + ds.save(name) + return ds + + + +trainset = get_dataset(test_url, "qa_dataset_1") +valset = get_dataset(train_url, "valset") + +model = get_registry().get("sentence-transformers").create() +model.finetune(trainset=trainset, valset=valset, path="model_finetuned_1", epochs=4) + +base = get_registry().get("sentence-transformers").create() +tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1") +openai = get_registry().get("openai").create(name="text-embedding-3-large") + + +rs1 = base.evaluate(valset, path="val_res") +rs2 = tuned.evaluate(valset, path="val_res") +rs3 = openai.evaluate(valset) + +print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean()) +print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean()) +print("Base model hite-rate - ", pd.DataFrame(rs1)["is_hit"].mean()) + diff --git a/docs/src/eval/test_fine_tune_from_responses.py b/docs/src/eval/test_fine_tune_from_responses.py new file mode 100644 index 00000000..ef45e5b8 --- /dev/null +++ b/docs/src/eval/test_fine_tune_from_responses.py @@ -0,0 +1,119 @@ +import os +import re +import json +import uuid +import lancedb +import pandas as pd + +from tqdm import tqdm +from lancedb.embeddings.fine_tuner.llm import Openai +from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk, DEFAULT_PROMPT_TMPL +from lancedb.pydantic import LanceModel, Vector +from llama_index.core import SimpleDirectoryReader +from llama_index.core.node_parser import SentenceSplitter +from llama_index.core.schema import MetadataMode +from lancedb.embeddings import get_registry + + + +test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' +train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' +def download_test_files(url): + import os + import requests + + + # download to cwd + files = [] + filename = os.path.basename(url) + if not os.path.exists(filename): + print(f"Downloading {url} to {filename}") + r = requests.get(url) + with open(filename, 'wb') as f: + f.write(r.content) + files.append(filename) + return files + + +def get_node(url): + reader = SimpleDirectoryReader(input_files=download_test_files(url)) + docs = reader.load_data() + + parser = SentenceSplitter() + nodes = parser.get_nodes_from_documents(docs) + + return nodes +def get_dataset(url, name): + reader = SimpleDirectoryReader(input_files=download_test_files(url)) + docs = reader.load_data() + + parser = SentenceSplitter() + nodes = parser.get_nodes_from_documents(docs) + + if os.path.exists(name): + ds = QADataset.load(name) + else: + llm = Openai() + + # convert Llama-index TextNode to TextChunk + chunks = [TextChunk.from_llama_index_node(node) for node in nodes] + + ds = QADataset.from_llm(chunks, llm) + ds.save(name) + return ds + +nodes = get_node(train_url) + +db = lancedb.connect("~/lancedb/fine-tuning") +model = get_registry().get("openai").create() +class Schema(LanceModel): + id: str + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + +retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite") +pylist = [{"id": str(node.node_id), "text": node.text} for node in nodes] +retriever.add(pylist) + + + +ds_name = "response_data" +if os.path.exists(ds_name): + ds = QADataset.load(ds_name) +else: + # Generate questions + llm = Openai() + text_chunks = [TextChunk.from_llama_index_node(node) for node in nodes] + + queries = {} + relevant_docs = {} + for chunk in tqdm(text_chunks): + text = chunk.text + questions = llm.get_questions(DEFAULT_PROMPT_TMPL.format(context_str=text, num_questions_per_chunk=2)) + + for question in questions: + question_id = str(uuid.uuid4()) + queries[question_id] = question + relevant_docs[question_id] = [retriever.search(question).to_pandas()["id"].tolist()[0]] + ds = QADataset.from_responses(text_chunks, queries, relevant_docs) + ds.save(ds_name) + + +# Fine-tune model +valset = get_dataset(train_url, "valset") + +model = get_registry().get("sentence-transformers").create() +res_base = model.evaluate(valset) + +model.finetune(trainset=ds, path="model_finetuned", epochs=4, log_wandb=True) +tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned") +res_tuned = tuned.evaluate(valset) + +openai_model = get_registry().get("openai").create() +#res_openai = openai_model.evaluate(valset) + +#print(f"openai model results: {pd.DataFrame(res_openai)['is_hit'].mean()}") +print(f"base model results: {pd.DataFrame(res_base)['is_hit'].mean()}") +print(f"tuned model results: {pd.DataFrame(res_tuned)['is_hit'].mean()}") + + diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py index 3d940810..1dd84b0e 100644 --- a/python/python/lancedb/embeddings/base.py +++ b/python/python/lancedb/embeddings/base.py @@ -10,13 +10,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from abc import ABC, abstractmethod from typing import List, Union import numpy as np import pyarrow as pa from pydantic import BaseModel, Field, PrivateAttr +from tqdm import tqdm +import lancedb + +from .fine_tuner import QADataset from .utils import TEXT, retry_with_exponential_backoff @@ -126,6 +131,22 @@ class EmbeddingFunction(BaseModel, ABC): def __hash__(self) -> int: return hash(frozenset(vars(self).items())) + def finetune(self, dataset: QADataset, *args, **kwargs): + """ + Finetune the embedding function on a dataset + """ + raise NotImplementedError( + "Finetuning is not supported for this embedding function" + ) + + def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs): + """ + Evaluate the embedding function on a dataset + """ + raise NotImplementedError( + "Evaluation is not supported for this embedding function" + ) + class EmbeddingFunctionConfig(BaseModel): """ @@ -159,3 +180,52 @@ class TextEmbeddingFunction(EmbeddingFunction): Generate the embeddings for the given texts """ pass + + def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs): + """ + Evaluate the embedding function on a dataset. This calculates the hit-rate for + the top-k retrieved documents for each query in the dataset. Assumes that the + first relevant document is the expected document. + Pro - Should work for any embedding model + Con - Returns every simple metric. + Parameters + ---------- + dataset: QADataset + The dataset to evaluate on + + Returns + ------- + dict + The evaluation results + """ + corpus = dataset.corpus + queries = dataset.queries + relevant_docs = dataset.relevant_docs + path = path or os.path.join(os.getcwd(), "eval") + db = lancedb.connect(path) + + class Schema(lancedb.pydantic.LanceModel): + id: str + text: str = self.SourceField() + vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField() + + retriever = db.create_table("eval", schema=Schema, mode="overwrite") + pylist = [{"id": str(k), "text": v} for k, v in corpus.items()] + retriever.add(pylist) + + eval_results = [] + for query_id, query in tqdm(queries.items()): + retrieved_nodes = retriever.search(query).limit(top_k).to_list() + retrieved_ids = [node["id"] for node in retrieved_nodes] + expected_id = relevant_docs[query_id][0] + is_hit = expected_id in retrieved_ids # assume 1 relevant doc + + eval_result = { + "is_hit": is_hit, + "retrieved": retrieved_ids, + "expected": expected_id, + "query": query_id, + } + eval_results.append(eval_result) + + return eval_results diff --git a/python/python/lancedb/embeddings/fine_tuner/README.md b/python/python/lancedb/embeddings/fine_tuner/README.md new file mode 100644 index 00000000..e06f2f3f --- /dev/null +++ b/python/python/lancedb/embeddings/fine_tuner/README.md @@ -0,0 +1,133 @@ +Fine-tuning workflow for embeddings consists for the following parts: + +### QADataset +This class is used for managing the data for fine-tuning. It contains the following builder methods: +``` +- from_llm( + nodes: 'List[TextChunk]' , + llm: BaseLLM, + qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL, + num_questions_per_chunk: int = 2, +) -> "QADataset" +``` +Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned. + +```python + +from_responses(docs: List['TextChunk'], queries: Dict[str, str], relevant_docs: Dict[str, List[str]])-> "QADataset" +``` +Create dataset from queries and responses based on a real-world scenario. Designed to be used for knowledge distillation from a larger LLM to a smaller one. + +It also contains the following data attributes: +``` + queries (Dict[str, str]): Dict id -> query. + corpus (Dict[str, str]): Dict id -> string. + relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids. +``` + +### TextChunk +This class is used for managing the data for fine-tuning. It is designed to allow working with and standardize various text splitting/pre-processing tools like llama-index and langchain. It contains the following attributes: +``` + text: str + id: str + metadata: Dict[str, Any] = {} +``` + +Builder Methods: + +```python +from_llama_index_node(node) -> "TextChunk" +``` +Create a text chunk from a llama index node. + +```python +from_langchain_node(node) -> "TextChunk" +``` +Create a text chunk from a langchain index node. + +```python +from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk" +``` +Create a text chunk from a string. + +### FineTuner +This class is used for fine-tuning embeddings. It is exposed to the user via a high-level function in the base embedding api. +```python +class BaseEmbeddingTuner(ABC): + """Base Embedding finetuning engine.""" + + @abstractmethod + def finetune(self) -> None: + """Goes off and does stuff.""" + + def helper(self) -> None: + """A helper method.""" + pass +``` + +### Embedding API finetuning implementation +Each embedding API needs to implement `finetune` method in order to support fine-tuning. A vanilla evaluation technique has been implemented in the `BaseEmbedding` class that calculates hit_rate @ `top_k`. + +### Fine-tuning workflow +The fine-tuning workflow is as follows: +1. Create a `QADataset` object. +2. Initialize any embedding function using LanceDB embedding API +3. Call `finetune` method on the embedding object with the `QADataset` object as an argument. +4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API. + +# End-to-End Examples +The following is an example of how to fine-tune an embedding model using the LanceDB embedding API. + +## Example 1: Fine-tuning from a synthetic dataset +```python +import pandas as pd + +from lancedb.embeddings.fine_tuner.llm import Openai +from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk +from lancedb.pydantic import LanceModel, Vector +from llama_index.core import SimpleDirectoryReader +from llama_index.core.node_parser import SentenceSplitter +from llama_index.core.schema import MetadataMode +from lancedb.embeddings import get_registry + +# 1. Create a QADataset object +url = "uber10k.pdf" +reader = SimpleDirectoryReader(input_files=url) +docs = reader.load_data() + +parser = SentenceSplitter() +nodes = parser.get_nodes_from_documents(docs) + +if os.path.exists(name): + ds = QADataset.load(name) +else: + llm = Openai() + + # convert Llama-index TextNode to TextChunk + chunks = [TextChunk.from_llama_index_node(node) for node in nodes] + + ds = QADataset.from_llm(chunks, llm) + ds.save(name) + +# 2. Initialize the embedding model +model = get_registry().get("sentence-transformers").create() + +# 3. Fine-tune the model +model.finetune(trainset=ds, path="model_finetuned", epochs=4) + +# 4. Evaluate the fine-tuned model +base = get_registry().get("sentence-transformers").create() +tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1") +openai = get_registry().get("openai").create(name="text-embedding-3-large") + + +rs1 = base.evaluate(trainset, path="val_res") +rs2 = tuned.evaluate(trainset, path="val_res") +rs3 = openai.evaluate(trainset) + +print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean()) +print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean()) +print("Base model hite-rate - ", pd.DataFrame(rs1)["is_hit"].mean()) +``` + + diff --git a/python/python/lancedb/embeddings/fine_tuner/__init__.py b/python/python/lancedb/embeddings/fine_tuner/__init__.py new file mode 100644 index 00000000..74df8066 --- /dev/null +++ b/python/python/lancedb/embeddings/fine_tuner/__init__.py @@ -0,0 +1,4 @@ +from .dataset import QADataset, TextChunk +from .llm import Gemini, Openai + +__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"] diff --git a/python/python/lancedb/embeddings/fine_tuner/basetuner.py b/python/python/lancedb/embeddings/fine_tuner/basetuner.py new file mode 100644 index 00000000..a90be05a --- /dev/null +++ b/python/python/lancedb/embeddings/fine_tuner/basetuner.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + + +class BaseEmbeddingTuner(ABC): + """Base Embedding finetuning engine.""" + + @abstractmethod + def finetune(self) -> None: + """Goes off and does stuff.""" + + def helper(self) -> None: + """A helper method.""" + pass diff --git a/python/python/lancedb/embeddings/fine_tuner/dataset.py b/python/python/lancedb/embeddings/fine_tuner/dataset.py new file mode 100644 index 00000000..6d01abff --- /dev/null +++ b/python/python/lancedb/embeddings/fine_tuner/dataset.py @@ -0,0 +1,179 @@ +import re +import uuid +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import lance +import pyarrow as pa +from pydantic import BaseModel +from tqdm import tqdm + +from .llm import BaseLLM + +DEFAULT_PROMPT_TMPL = """\ +Context information is below. + +--------------------- +{context_str} +--------------------- + +Given the context information and no prior knowledge. +generate only questions based on the below query. + +You are a Teacher/ Professor. Your task is to setup \ +{num_questions_per_chunk} questions for an upcoming \ +quiz/examination. The questions should be diverse in nature \ +across the document. Restrict the questions to the \ +context information provided." +""" + + +class QADataset(BaseModel): + """Embedding QA Finetuning Dataset. + + Args: + queries (Dict[str, str]): Dict id -> query. + corpus (Dict[str, str]): Dict id -> string. + relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids. + + """ + + queries: Dict[str, str] # id -> query + corpus: Dict[str, str] # id -> text + relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids + mode: str = "text" + + @property + def query_docid_pairs(self) -> List[Tuple[str, List[str]]]: + """Get query, relevant doc ids.""" + return [ + (query, self.relevant_docs[query_id]) + for query_id, query in self.queries.items() + ] + + def save(self, path: str, mode: str = "overwrite") -> None: + """Save to lance dataset""" + save_dir = Path(path) + save_dir.mkdir(parents=True, exist_ok=True) + + # convert to pydict {"id": []} + queries = { + "id": list(self.queries.keys()), + "query": list(self.queries.values()), + } + corpus = { + "id": list(self.corpus.keys()), + "text": [ + val or " " for val in self.corpus.values() + ], # lance saves empty strings as null + } + relevant_docs = { + "query_id": list(self.relevant_docs.keys()), + "doc_id": list(self.relevant_docs.values()), + } + + # write to lance + lance.write_dataset( + pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode + ) + lance.write_dataset( + pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode + ) + lance.write_dataset( + pa.Table.from_pydict(relevant_docs), + save_dir / "relevant_docs.lance", + mode=mode, + ) + + @classmethod + def load(cls, path: str) -> "QADataset": + """Load from .lance data""" + load_dir = Path(path) + queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict() + corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict() + relevant_docs = ( + lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict() + ) + return cls( + queries=dict(zip(queries["id"], queries["query"])), + corpus=dict(zip(corpus["id"], corpus["text"])), + relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])), + ) + + # generate queries as a convenience function + @classmethod + def from_llm( + cls, + nodes: "List[TextChunk]", + llm: BaseLLM, + qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL, + num_questions_per_chunk: int = 2, + ) -> "QADataset": + """Generate examples given a set of nodes.""" + node_dict = {node.id: node.text for node in nodes} + + queries = {} + relevant_docs = {} + for node_id, text in tqdm(node_dict.items()): + query = qa_generate_prompt_tmpl.format( + context_str=text, num_questions_per_chunk=num_questions_per_chunk + ) + response = llm.chat_completion(query) + + result = str(response).strip().split("\n") + questions = [ + re.sub(r"^\d+[\).\s]", "", question).strip() for question in result + ] + questions = [question for question in questions if len(question) > 0] + for question in questions: + question_id = str(uuid.uuid4()) + queries[question_id] = question + relevant_docs[question_id] = [node_id] + + return QADataset(queries=queries, corpus=node_dict, relevant_docs=relevant_docs) + + @classmethod + def from_responses( + cls, + docs: List["TextChunk"], + queries: Dict[str, str], + relevant_docs: Dict[str, List[str]], + ) -> "QADataset": + """Create a QADataset from a list of TextChunks and a list of questions.""" + node_dict = {node.id: node.text for node in docs} + return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs) + + +class TextChunk(BaseModel): + """Simple text chunk for generating questions.""" + + text: str + id: str + metadata: Dict[str, Any] = {} + + @classmethod + def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk": + """Create a SimpleTextChunk from a chunk.""" + # generate a unique id + return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata) + + @classmethod + def from_llama_index_node(cls, node): + """Convert a llama index node to a text chunk.""" + return cls(text=node.text, id=node.node_id, metadata=node.metadata) + + @classmethod + def from_langchain_node(cls, node): + """Convert a langchaain node to a text chunk.""" + raise NotImplementedError("Not implemented yet.") + + def to_dict(self) -> Dict[str, Any]: + """Convert to a dictionary.""" + return self.dict() + + def __str__(self) -> str: + return self.text + + def __repr__(self) -> str: + return f"SimpleTextChunk(text={self.text}, id={self.id}, \ + metadata={self.metadata})" diff --git a/python/python/lancedb/embeddings/fine_tuner/llm.py b/python/python/lancedb/embeddings/fine_tuner/llm.py new file mode 100644 index 00000000..1bdf2d97 --- /dev/null +++ b/python/python/lancedb/embeddings/fine_tuner/llm.py @@ -0,0 +1,85 @@ +import os +import re +from functools import cached_property +from typing import Optional + +from pydantic import BaseModel + +from ...util import attempt_import_or_raise +from ..utils import api_key_not_found_help + + +class BaseLLM(BaseModel): + """ + TODO: + Base class for Language Model based Embedding Functions. This class is + loosely desined rn, and will be updated as the usage gets clearer. + """ + + model_name: str + model_kwargs: dict = {} + + @cached_property + def _client(): + """ + Get the client for the language model + """ + raise NotImplementedError + + def chat_completion(self, prompt: str, **kwargs): + """ + Get the chat completion for the given prompt + """ + raise NotImplementedError + + +class Openai(BaseLLM): + model_name: str = "gpt-3.5-turbo" + kwargs: dict = {} + api_key: Optional[str] = None + + @cached_property + def _client(self): + """ + Get the client for the language model + """ + openai = attempt_import_or_raise("openai") + + if not os.environ.get("OPENAI_API_KEY"): + api_key_not_found_help("openai") + return openai.OpenAI() + + def chat_completion(self, prompt: str) -> str: + """ + Get the chat completion for the given prompt + """ + + # TODO: this is legacy openai api replace with completions + completion = self._client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + **self.kwargs, + ) + + text = completion.choices[0].message.content + + return text + + def get_questions(self, prompt: str) -> str: + """ + Get the chat completion for the given prompt + """ + response = self.chat_completion(prompt) + result = str(response).strip().split("\n") + questions = [ + re.sub(r"^\d+[\).\s]", "", question).strip() for question in result + ] + questions = [question for question in questions if len(question) > 0] + return questions + + +class Gemini(BaseLLM): + pass diff --git a/python/python/lancedb/embeddings/instructor.py b/python/python/lancedb/embeddings/instructor.py index 98206bc5..e6481e19 100644 --- a/python/python/lancedb/embeddings/instructor.py +++ b/python/python/lancedb/embeddings/instructor.py @@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly source_instruction: str = "represent the document for retrieval" - query_instruction: str = ( - "represent the document for retrieving the most similar documents" - ) + query_instruction: ( + str + ) = "represent the document for retrieving the most similar documents" @weak_lru(maxsize=1) def ndims(self): diff --git a/python/python/lancedb/embeddings/sentence_transformers.py b/python/python/lancedb/embeddings/sentence_transformers.py index 97fe1318..bf0a3ac4 100644 --- a/python/python/lancedb/embeddings/sentence_transformers.py +++ b/python/python/lancedb/embeddings/sentence_transformers.py @@ -10,12 +10,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Any, List, Optional, Union import numpy as np +from lancedb.embeddings.fine_tuner import QADataset +from lancedb.utils.general import LOGGER + from ..util import attempt_import_or_raise from .base import TextEmbeddingFunction +from .fine_tuner.basetuner import BaseEmbeddingTuner from .registry import register from .utils import weak_lru @@ -80,3 +84,151 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): "sentence_transformers", "sentence-transformers" ) return sentence_transformers.SentenceTransformer(self.name, device=self.device) + + def finetune(self, trainset: QADataset, *args, **kwargs): + """ + Finetune the Sentence Transformers model + + Parameters + ---------- + dataset: QADataset + The dataset to use for finetuning + """ + tuner = SentenceTransformersTuner( + model=self.embedding_model, + trainset=trainset, + **kwargs, + ) + tuner.finetune() + + +class SentenceTransformersTuner(BaseEmbeddingTuner): + """Sentence Transformers Embedding Finetuning Engine.""" + + def __init__( + self, + model: Any, + trainset: QADataset, + valset: Optional[QADataset] = None, + path: Optional[str] = "~/.lancedb/embeddings/models", + batch_size: int = 8, + epochs: int = 1, + show_progress: bool = True, + eval_steps: int = 50, + max_input_per_doc: int = -1, + loss: Optional[Any] = None, + evaluator: Optional[Any] = None, + run_name: Optional[str] = None, + log_wandb: bool = False, + ) -> None: + """ + Parameters + ---------- + model: str + The model to use for finetuning. + trainset: QADataset + The training dataset. + valset: Optional[QADataset] + The validation dataset. + path: Optional[str] + The path to save the model. + batch_size: int, default=8 + The batch size. + epochs: int, default=1 + The number of epochs. + show_progress: bool, default=True + Whether to show progress. + eval_steps: int, default=50 + The number of steps to evaluate. + max_input_per_doc: int, default=-1 + The number of input per document. + if -1, use all documents. + """ + from sentence_transformers import InputExample, losses + from sentence_transformers.evaluation import InformationRetrievalEvaluator + from torch.utils.data import DataLoader + + self.model = model + self.trainset = trainset + self.valset = valset + self.path = path + self.batch_size = batch_size + self.epochs = epochs + self.show_progress = show_progress + self.eval_steps = eval_steps + self.max_input_per_doc = max_input_per_doc + self.evaluator = None + self.epochs = epochs + self.show_progress = show_progress + self.eval_steps = eval_steps + self.run_name = run_name + self.log_wandb = log_wandb + + if self.max_input_per_doc < -1: + raise ValueError("max_input_per_doc must be -1 or greater than 0.") + + examples: Any = [] + for query_id, query in self.trainset.queries.items(): + if max_input_per_doc == -1: + for node_id in self.trainset.relevant_docs[query_id]: + text = self.trainset.corpus[node_id] + example = InputExample(texts=[query, text]) + examples.append(example) + else: + node_id = self.trainset.relevant_docs[query_id][ + min(max_input_per_doc, len(self.trainset.relevant_docs[query_id])) + ] + text = self.trainset.corpus[node_id] + example = InputExample(texts=[query, text]) + examples.append(example) + + self.examples = examples + + self.loader: DataLoader = DataLoader(examples, batch_size=batch_size) + + if self.valset is not None: + eval_engine = evaluator or InformationRetrievalEvaluator + self.evaluator = eval_engine( + valset.queries, valset.corpus, valset.relevant_docs + ) + self.evaluator = evaluator + + # define loss + self.loss = loss or losses.MultipleNegativesRankingLoss(self.model) + self.warmup_steps = int(len(self.loader) * epochs * 0.1) + + def finetune(self) -> None: + """Finetune the Sentence Transformers model.""" + self.model.fit( + train_objectives=[(self.loader, self.loss)], + epochs=self.epochs, + warmup_steps=self.warmup_steps, + output_path=self.path, + show_progress_bar=self.show_progress, + evaluator=self.evaluator, + evaluation_steps=self.eval_steps, + callback=self._wandb_callback if self.log_wandb else None, + ) + + self.helper() + + def helper(self) -> None: + """A helper method.""" + LOGGER.info("Finetuning complete.") + LOGGER.info(f"Model saved to {self.path}.") + LOGGER.info("You can now use the model as follows:") + LOGGER.info( + f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')" # noqa + ) + + def _wandb_callback(self, score, epoch, steps): + try: + import wandb + except ImportError: + raise ImportError( + "wandb is not installed. Please install it using `pip install wandb`" + ) + run = wandb.run or wandb.init( + project="sbert_lancedb_finetune", name=self.run_name + ) + run.log({"epoch": epoch, "steps": steps, "score": score}) diff --git a/python/python/tests/test_embedding_fine_tuning.py b/python/python/tests/test_embedding_fine_tuning.py new file mode 100644 index 00000000..da5a13cc --- /dev/null +++ b/python/python/tests/test_embedding_fine_tuning.py @@ -0,0 +1,45 @@ +import uuid + +import pytest +from lancedb.embeddings import get_registry +from lancedb.embeddings.fine_tuner import QADataset, TextChunk +from tqdm import tqdm + + +@pytest.mark.slow +def test_finetuning_sentence_transformers(tmp_path): + queries = {} + relevant_docs = {} + chunks = [ + "This is a chunk related to legal docs", + "This is another chunk related financial docs", + "This is a chunk related to sports docs", + "This is another chunk related to fashion docs", + ] + text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks] + for chunk in tqdm(text_chunks): + questions = [ + "What is this chunk about?", + "What is the main topic of this chunk?", + ] + for question in questions: + question_id = str(uuid.uuid4()) + queries[question_id] = question + relevant_docs[question_id] = [chunk.id] + ds = QADataset.from_responses(text_chunks, queries, relevant_docs) + + assert len(ds.queries) == 8 + assert len(ds.corpus) == 4 + + model = get_registry().get("sentence-transformers").create() + model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1) + model = ( + get_registry().get("sentence-transformers").create(name=str(tmp_path / "model")) + ) + res = model.evaluate(ds) + assert res is not None + + +def test_text_chunk(): + # TODO + pass