Compare commits

...

1 Commits

Author SHA1 Message Date
ayush chaurasia
99d1a06a44 add dataset features 2024-04-05 16:34:21 +05:30

View File

@@ -1,13 +1,13 @@
import re import re
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple, Optional
import lance import lance
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel from pydantic import BaseModel
from tqdm import tqdm from tqdm import tqdm
from lancedb.utils.general import LOGGER
from .llm import BaseLLM from .llm import BaseLLM
DEFAULT_PROMPT_TMPL = """\ DEFAULT_PROMPT_TMPL = """\
@@ -37,7 +37,7 @@ class QADataset(BaseModel):
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids. relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
""" """
path: Optional[str] = None
queries: Dict[str, str] # id -> query queries: Dict[str, str] # id -> query
corpus: Dict[str, str] # id -> text corpus: Dict[str, str] # id -> text
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
@@ -53,6 +53,7 @@ class QADataset(BaseModel):
def save(self, path: str, mode: str = "overwrite") -> None: def save(self, path: str, mode: str = "overwrite") -> None:
"""Save to lance dataset""" """Save to lance dataset"""
self.path = path
save_dir = Path(path) save_dir = Path(path)
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
@@ -86,20 +87,28 @@ class QADataset(BaseModel):
) )
@classmethod @classmethod
def load(cls, path: str) -> "QADataset": def load(cls, path: str, version: Optional[int] = None) -> "QADataset":
"""Load from .lance data""" """Load from .lance data"""
load_dir = Path(path) load_dir = Path(path)
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict() queries = lance.dataset(load_dir / "queries.lance", version=version).to_table().to_pydict()
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict() corpus = lance.dataset(load_dir / "corpus.lance", version=version).to_table().to_pydict()
relevant_docs = ( relevant_docs = (
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict() lance.dataset(load_dir / "relevant_docs.lance", version=version).to_table().to_pydict()
) )
return cls( return cls(
path=str(path),
queries=dict(zip(queries["id"], queries["query"])), queries=dict(zip(queries["id"], queries["query"])),
corpus=dict(zip(corpus["id"], corpus["text"])), corpus=dict(zip(corpus["id"], corpus["text"])),
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])), relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
) )
@classmethod
def switch_version(cls, version: int) -> "QADataset":
"""Switch version of a dataset."""
if not cls.path:
raise ValueError("Path not set. You need to call save() first.")
return cls.load(cls.path, version=version)
# generate queries as a convenience function # generate queries as a convenience function
@classmethod @classmethod
def from_llm( def from_llm(
@@ -142,6 +151,23 @@ class QADataset(BaseModel):
"""Create a QADataset from a list of TextChunks and a list of questions.""" """Create a QADataset from a list of TextChunks and a list of questions."""
node_dict = {node.id: node.text for node in docs} node_dict = {node.id: node.text for node in docs}
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs) return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
def versions(self) -> List[int]:
"""Get the versions of the dataset."""
# TODO: tidy this up
data_paths = self._get_data_file_paths()
return lance.dataset(data_paths[0]).versions()
def _get_data_file_paths(self) -> str:
"""Get the absolute path of the dataset."""
queries = self.path / "queries.lance"
corpus = self.path / "corpus.lance"
relevant_docs = self.path / "relevant_docs.lance"
return queries, corpus, relevant_docs
class TextChunk(BaseModel): class TextChunk(BaseModel):