mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 13:52:58 +00:00
Compare commits
1 Commits
embedding_
...
tuning/dat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99d1a06a44 |
@@ -1,13 +1,13 @@
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple, Optional
|
||||||
|
|
||||||
import lance
|
import lance
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from lancedb.utils.general import LOGGER
|
||||||
from .llm import BaseLLM
|
from .llm import BaseLLM
|
||||||
|
|
||||||
DEFAULT_PROMPT_TMPL = """\
|
DEFAULT_PROMPT_TMPL = """\
|
||||||
@@ -37,7 +37,7 @@ class QADataset(BaseModel):
|
|||||||
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
|
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
path: Optional[str] = None
|
||||||
queries: Dict[str, str] # id -> query
|
queries: Dict[str, str] # id -> query
|
||||||
corpus: Dict[str, str] # id -> text
|
corpus: Dict[str, str] # id -> text
|
||||||
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
|
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
|
||||||
@@ -53,6 +53,7 @@ class QADataset(BaseModel):
|
|||||||
|
|
||||||
def save(self, path: str, mode: str = "overwrite") -> None:
|
def save(self, path: str, mode: str = "overwrite") -> None:
|
||||||
"""Save to lance dataset"""
|
"""Save to lance dataset"""
|
||||||
|
self.path = path
|
||||||
save_dir = Path(path)
|
save_dir = Path(path)
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -86,20 +87,28 @@ class QADataset(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str) -> "QADataset":
|
def load(cls, path: str, version: Optional[int] = None) -> "QADataset":
|
||||||
"""Load from .lance data"""
|
"""Load from .lance data"""
|
||||||
load_dir = Path(path)
|
load_dir = Path(path)
|
||||||
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
|
queries = lance.dataset(load_dir / "queries.lance", version=version).to_table().to_pydict()
|
||||||
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
|
corpus = lance.dataset(load_dir / "corpus.lance", version=version).to_table().to_pydict()
|
||||||
relevant_docs = (
|
relevant_docs = (
|
||||||
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
|
lance.dataset(load_dir / "relevant_docs.lance", version=version).to_table().to_pydict()
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
|
path=str(path),
|
||||||
queries=dict(zip(queries["id"], queries["query"])),
|
queries=dict(zip(queries["id"], queries["query"])),
|
||||||
corpus=dict(zip(corpus["id"], corpus["text"])),
|
corpus=dict(zip(corpus["id"], corpus["text"])),
|
||||||
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
|
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def switch_version(cls, version: int) -> "QADataset":
|
||||||
|
"""Switch version of a dataset."""
|
||||||
|
if not cls.path:
|
||||||
|
raise ValueError("Path not set. You need to call save() first.")
|
||||||
|
return cls.load(cls.path, version=version)
|
||||||
|
|
||||||
# generate queries as a convenience function
|
# generate queries as a convenience function
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
@@ -142,6 +151,23 @@ class QADataset(BaseModel):
|
|||||||
"""Create a QADataset from a list of TextChunks and a list of questions."""
|
"""Create a QADataset from a list of TextChunks and a list of questions."""
|
||||||
node_dict = {node.id: node.text for node in docs}
|
node_dict = {node.id: node.text for node in docs}
|
||||||
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
|
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
|
||||||
|
|
||||||
|
def versions(self) -> List[int]:
|
||||||
|
"""Get the versions of the dataset."""
|
||||||
|
# TODO: tidy this up
|
||||||
|
data_paths = self._get_data_file_paths()
|
||||||
|
return lance.dataset(data_paths[0]).versions()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data_file_paths(self) -> str:
|
||||||
|
"""Get the absolute path of the dataset."""
|
||||||
|
queries = self.path / "queries.lance"
|
||||||
|
corpus = self.path / "corpus.lance"
|
||||||
|
relevant_docs = self.path / "relevant_docs.lance"
|
||||||
|
|
||||||
|
return queries, corpus, relevant_docs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextChunk(BaseModel):
|
class TextChunk(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user