add dataset features

This commit is contained in:
ayush chaurasia
2024-04-05 16:34:21 +05:30
parent f23641d703
commit 99d1a06a44

View File

@@ -1,13 +1,13 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Optional
import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm
from lancedb.utils.general import LOGGER
from .llm import BaseLLM
DEFAULT_PROMPT_TMPL = """\
@@ -37,7 +37,7 @@ class QADataset(BaseModel):
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
"""
path: Optional[str] = None
queries: Dict[str, str] # id -> query
corpus: Dict[str, str] # id -> text
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
@@ -53,6 +53,7 @@ class QADataset(BaseModel):
def save(self, path: str, mode: str = "overwrite") -> None:
"""Save to lance dataset"""
self.path = path
save_dir = Path(path)
save_dir.mkdir(parents=True, exist_ok=True)
@@ -86,20 +87,28 @@ class QADataset(BaseModel):
)
@classmethod
def load(cls, path: str, version: Optional[int] = None) -> "QADataset":
    """Load from .lance data.

    Args:
        path: Directory holding ``queries.lance``, ``corpus.lance`` and
            ``relevant_docs.lance``.
        version: Version forwarded to ``lance.dataset`` for each of the
            three tables. With the default ``None`` no specific version
            is requested.

    Returns:
        A ``QADataset`` built from the three tables, with ``path`` set to
        ``str(path)``.
    """
    load_dir = Path(path)

    def read_columns(name: str) -> Dict[str, List[Any]]:
        # Every table is opened at the same requested version.
        # NOTE(review): assumes the three tables share version numbers
        # (i.e. they are always written together) -- confirm.
        dataset = lance.dataset(load_dir / name, version=version)
        return dataset.to_table().to_pydict()

    queries = read_columns("queries.lance")
    corpus = read_columns("corpus.lance")
    relevant_docs = read_columns("relevant_docs.lance")
    return cls(
        path=str(path),
        queries=dict(zip(queries["id"], queries["query"])),
        corpus=dict(zip(corpus["id"], corpus["text"])),
        relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
    )
def switch_version(self, version: int) -> "QADataset":
    """Switch version of a dataset.

    This is an instance method: ``path`` is per-instance state, so it
    cannot be read from the class itself.

    Args:
        version: Version to load; forwarded to ``load``.

    Returns:
        A new ``QADataset`` loaded from ``self.path`` at ``version``. The
        current instance is left untouched.

    Raises:
        ValueError: If ``self.path`` is not set.
    """
    if not self.path:
        raise ValueError("Path not set. You need to call save() first.")
    # Go through the concrete class so subclasses get their own type back.
    return type(self).load(self.path, version=version)
# generate queries as a convenience function
@classmethod
def from_llm(
@@ -142,6 +151,23 @@ class QADataset(BaseModel):
"""Create a QADataset from a list of TextChunks and a list of questions."""
node_dict = {node.id: node.text for node in docs}
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
def versions(self) -> List[Dict[str, Any]]:
    """Get the versions of the dataset.

    Only the ``queries.lance`` table (the first path returned by
    ``_get_data_file_paths``) is inspected.

    Returns:
        The result of lance's ``versions()`` for that table.
        NOTE(review): per the Lance docs this is one dict per version,
        not a bare int -- confirm against the installed lance release.

    Raises:
        ValueError: If ``self.path`` is not set.
    """
    # TODO: tidy this up
    queries_path = self._get_data_file_paths()[0]
    return lance.dataset(queries_path).versions()


def _get_data_file_paths(self) -> Tuple[Path, Path, Path]:
    """Get the paths of the three Lance tables under ``self.path``.

    Returns:
        Paths to ``queries.lance``, ``corpus.lance`` and
        ``relevant_docs.lance``, in that order.

    Raises:
        ValueError: If ``self.path`` is not set.
    """
    if not self.path:
        raise ValueError("Path not set. You need to call save() first.")
    # ``path`` is declared as a plain string; ``str / str`` raises
    # TypeError, so wrap it in ``Path`` before joining.
    root = Path(self.path)
    return (
        root / "queries.lance",
        root / "corpus.lance",
        root / "relevant_docs.lance",
    )
class TextChunk(BaseModel):