diff --git a/python/python/lancedb/embeddings/fine_tuner/dataset.py b/python/python/lancedb/embeddings/fine_tuner/dataset.py index 6d01abff..f1521feb 100644 --- a/python/python/lancedb/embeddings/fine_tuner/dataset.py +++ b/python/python/lancedb/embeddings/fine_tuner/dataset.py @@ -1,13 +1,13 @@ import re import uuid from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Optional import lance import pyarrow as pa from pydantic import BaseModel from tqdm import tqdm - +from lancedb.utils.general import LOGGER from .llm import BaseLLM DEFAULT_PROMPT_TMPL = """\ @@ -37,7 +37,7 @@ class QADataset(BaseModel): relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids. """ - + path: Optional[str] = None queries: Dict[str, str] # id -> query corpus: Dict[str, str] # id -> text relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids @@ -53,6 +53,7 @@ class QADataset(BaseModel): def save(self, path: str, mode: str = "overwrite") -> None: """Save to lance dataset""" + self.path = path save_dir = Path(path) save_dir.mkdir(parents=True, exist_ok=True) @@ -86,20 +87,28 @@ class QADataset(BaseModel): ) @classmethod - def load(cls, path: str) -> "QADataset": + def load(cls, path: str, version: Optional[int] = None) -> "QADataset": """Load from .lance data""" load_dir = Path(path) - queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict() - corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict() + queries = lance.dataset(load_dir / "queries.lance", version=version).to_table().to_pydict() + corpus = lance.dataset(load_dir / "corpus.lance", version=version).to_table().to_pydict() relevant_docs = ( - lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict() + lance.dataset(load_dir / "relevant_docs.lance", version=version).to_table().to_pydict() ) return cls( + path=str(path), queries=dict(zip(queries["id"], queries["query"])), corpus=dict(zip(corpus["id"], corpus["text"])), relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])), ) + @classmethod + def switch_version(cls, version: int) -> "QADataset": + """Switch version of a dataset.""" + if not cls.path: + raise ValueError("Path not set. You need to call save() first.") + return cls.load(cls.path, version=version) + # generate queries as a convenience function @classmethod def from_llm( @@ -142,6 +151,23 @@ class QADataset(BaseModel): """Create a QADataset from a list of TextChunks and a list of questions.""" node_dict = {node.id: node.text for node in docs} return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs) + + def versions(self) -> List[int]: + """Get the versions of the dataset.""" + # TODO: tidy this up + data_paths = self._get_data_file_paths() + return lance.dataset(data_paths[0]).versions() + + + def _get_data_file_paths(self) -> str: + """Get the absolute path of the dataset.""" + queries = self.path / "queries.lance" + corpus = self.path / "corpus.lance" + relevant_docs = self.path / "relevant_docs.lance" + + return queries, corpus, relevant_docs + + class TextChunk(BaseModel):