mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
add dataset features
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
|
||||
import lance
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from lancedb.utils.general import LOGGER
|
||||
from .llm import BaseLLM
|
||||
|
||||
DEFAULT_PROMPT_TMPL = """\
|
||||
@@ -37,7 +37,7 @@ class QADataset(BaseModel):
|
||||
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
|
||||
|
||||
"""
|
||||
|
||||
path: Optional[str] = None
|
||||
queries: Dict[str, str] # id -> query
|
||||
corpus: Dict[str, str] # id -> text
|
||||
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
|
||||
@@ -53,6 +53,7 @@ class QADataset(BaseModel):
|
||||
|
||||
def save(self, path: str, mode: str = "overwrite") -> None:
|
||||
"""Save to lance dataset"""
|
||||
self.path = path
|
||||
save_dir = Path(path)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -86,20 +87,28 @@ class QADataset(BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> "QADataset":
|
||||
def load(cls, path: str, version: Optional[int] = None) -> "QADataset":
|
||||
"""Load from .lance data"""
|
||||
load_dir = Path(path)
|
||||
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
|
||||
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
|
||||
queries = lance.dataset(load_dir / "queries.lance", version=version).to_table().to_pydict()
|
||||
corpus = lance.dataset(load_dir / "corpus.lance", version=version).to_table().to_pydict()
|
||||
relevant_docs = (
|
||||
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
|
||||
lance.dataset(load_dir / "relevant_docs.lance", version=version).to_table().to_pydict()
|
||||
)
|
||||
return cls(
|
||||
path=str(path),
|
||||
queries=dict(zip(queries["id"], queries["query"])),
|
||||
corpus=dict(zip(corpus["id"], corpus["text"])),
|
||||
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def switch_version(cls, version: int) -> "QADataset":
|
||||
"""Switch version of a dataset."""
|
||||
if not cls.path:
|
||||
raise ValueError("Path not set. You need to call save() first.")
|
||||
return cls.load(cls.path, version=version)
|
||||
|
||||
# generate queries as a convenience function
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@@ -142,6 +151,23 @@ class QADataset(BaseModel):
|
||||
"""Create a QADataset from a list of TextChunks and a list of questions."""
|
||||
node_dict = {node.id: node.text for node in docs}
|
||||
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
|
||||
|
||||
def versions(self) -> List[int]:
|
||||
"""Get the versions of the dataset."""
|
||||
# TODO: tidy this up
|
||||
data_paths = self._get_data_file_paths()
|
||||
return lance.dataset(data_paths[0]).versions()
|
||||
|
||||
|
||||
def _get_data_file_paths(self) -> str:
|
||||
"""Get the absolute path of the dataset."""
|
||||
queries = self.path / "queries.lance"
|
||||
corpus = self.path / "corpus.lance"
|
||||
relevant_docs = self.path / "relevant_docs.lance"
|
||||
|
||||
return queries, corpus, relevant_docs
|
||||
|
||||
|
||||
|
||||
|
||||
class TextChunk(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user