-
Notifications
You must be signed in to change notification settings - Fork 1
/
common.py
104 lines (87 loc) · 3.14 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (C) 2023 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
import logging
import os.path
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
import datasets
import rathe
from rathe import InstructPrompt, Prompt, PromptParser
from rathe.conversion import ConversionContext
# Extracts the numeric grade from a judge model's response: matches
# "score:" or "Score:" followed by a digit 0-5, optionally with a ".5"
# fraction (so 0, 0.5, ..., 5, 5.5 all match); group 1 is the score text.
SCORE_RE = re.compile(r"[sS]core: ([0-5](\.5)?)")
def alpaca_json(prompt: rathe.InstructPrompt) -> Dict:
    """Serialize an instruct prompt as an Alpaca-format record.

    Returns a plain dict with the three standard Alpaca keys
    ("instruction", "input", "output") copied from *prompt*.
    """
    record = dict(
        instruction=prompt.instruction,
        input=prompt.input,
        output=prompt.output,
    )
    return record
def load_data(
    dataset: str,
    parser: PromptParser,
    get_judge_prompt: Callable[[Prompt, ConversionContext], InstructPrompt],
    conversion_context: Optional[ConversionContext] = None,
    data_files: Optional[List[str]] = None,
    offset: int = -1,
    last_index: int = -1,
    shuffle: bool = False,
    seed: int = 4,
) -> datasets.Dataset:
    """Load a dataset and map every row to a judge instruct prompt.

    Args:
        dataset: Name or path understood by ``datasets.load_dataset``.
        parser: Parser that turns a raw row into a ``rathe`` prompt.
        get_judge_prompt: Builds the judging instruct prompt from a parsed
            prompt and a conversion context.
        conversion_context: Conversion context to use; defaults to
            ``ConversionContext.default()``.
        data_files: Optional explicit data files forwarded to
            ``datasets.load_dataset``.
        offset: First row index to keep; negative means the dataset start.
        last_index: One-past-last row index to keep; negative means the
            dataset end.
        shuffle: Whether to shuffle the resulting dataset.
        seed: Shuffle seed. Defaults to 4, the value previously hard-coded,
            so existing callers get identical ordering.

    Returns:
        A ``datasets.Dataset`` of Alpaca-style records, each carrying an
        ``"id"`` column.
    """
    logging.info("loading dataset...")
    data = datasets.load_dataset(dataset, data_files=data_files)
    if "train" in data:
        data = data["train"]

    # Guarantee a stable per-row identifier so scores can be joined back
    # to their source rows later.
    if "id" not in data.column_names:
        data = data.map(
            lambda e, idx: {"id": f"{dataset}.{idx}", **e}, with_indices=True
        )

    # Optionally trim to the half-open range [offset, last_index);
    # a negative value means "unset" and falls back to the full extent.
    if offset > 0 or last_index >= 0:
        if last_index < 0:
            last_index = len(data)
        if offset < 0:
            offset = 0
        # Lazy %-style args avoid formatting when the level is disabled.
        logging.info("selecting samples from %d to %d", offset, last_index)
        data = data.select(range(offset, last_index))

    if conversion_context is None:
        conversion_context = ConversionContext.default()

    logging.info("parsing and formatting prompts...")
    transform = PromptTransform(parser, get_judge_prompt, conversion_context)
    data: datasets.Dataset = data.map(
        transform,
        num_proc=os.cpu_count(),
    )
    if shuffle:
        data = data.shuffle(seed=seed)
    return data
@dataclass
class PromptTransform:
    """Callable row transform: parse a raw dataset row, build the judge
    prompt for it, and emit an Alpaca-style record keyed by the row id.

    Instances are passed to ``datasets.Dataset.map``.
    """

    # Parses a raw dataset row into a rathe prompt.
    parser: rathe.PromptParser
    # Builds the judging instruct prompt from the parsed prompt.
    get_judge_prompt: Callable[[Prompt, ConversionContext], InstructPrompt]
    # Context forwarded to get_judge_prompt.
    conversion_context: Optional[ConversionContext]

    def __call__(self, row: Dict) -> Dict:
        """Transform one dataset row into a judge-prompt record."""
        judge = self.get_judge_prompt(
            self.parser.parse(row), self.conversion_context
        )
        record = alpaca_json(judge)
        record["id"] = row["id"]
        return record
def parse_score(row_id: str, response_text: str):
    """Extract a judge score from a model response.

    Searches *response_text* with ``SCORE_RE``; on a hit, returns a dict
    with the row id, the score as a float, and the raw response text.
    Returns ``None`` when no score is found.
    """
    found = SCORE_RE.search(response_text)
    if found is None:
        return None
    return {
        "id": row_id,
        "score": float(found.group(1)),
        "response": response_text,
    }