diff --git a/data/generate_data.py b/data/generate_data.py index ebfc11a..edd3d35 100644 --- a/data/generate_data.py +++ b/data/generate_data.py @@ -598,7 +598,7 @@ def generate_refusal_example(refusal_case: PileOfRefusalsType, persona: str, lan "assistant_turns": assistant_turns } -def format_example_sharegpt(example, persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format): +def format_example_sharegpt(example: Example, persona: str, language: str, use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, tool_response_format: str) -> DatasetEntry: piles = get_dataset_piles(language) sys_prompt = generate_system_prompt(example, persona, language, piles.pile_of_system_prompts) question = example["question"] @@ -698,7 +698,7 @@ def format_example_sharegpt(example, persona, language, use_system_role, append_ def generate_sft_file( filename: str, seed: int, - format_func: Callable, + format_func: Callable[[Example, str, str, bool, bool, bool, str], DatasetEntry], use_system_role: bool, append_user_instruction_prompt: bool, use_service_names: bool, @@ -717,7 +717,7 @@ def generate_sft_file( print("Generating...") - def run_factor_times(func: Callable[..., Example], examples: list[Example], data, persona: str, factor: int | float, language: str): + def run_factor_times(func: Callable[..., Example], examples: list[DatasetEntry], data, persona: str, factor: int | float, language: str): if factor >= 1: for i in range(int(factor)): examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) @@ -725,7 +725,7 @@ def generate_sft_file( if random.random() < factor: examples.append(format_func(func(data, persona, language, use_service_names=use_service_names), persona, language, use_system_role, append_user_instruction_prompt, use_service_names, tool_response_format)) - generated_examples: list[Example] = [] + generated_examples: list[DatasetEntry] = [] missing_responses = set()