import os
import copy
import json
import random
import traceback

import numpy as np
import IPython
from gensim.utils import (
    save_text,
    add_to_txt,
    extract_dict,
    format_dict_prompt,
    generate_feedback,
)

class Critic:
    """
    A critic that reflects on and criticizes a newly generated task,
    deciding whether it should be kept for improvement.
    """

    def __init__(self, cfg, memory):
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = self.memory.chat_log
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]

    def error_review(self, new_task):
        """Review the new task against a book of commonly made errors."""
        if os.path.exists(f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt") and "task-name" in new_task:
            self.chat_log = add_to_txt(self.chat_log, "================= Error Book Preview!", with_print=True)
            with open(f"{self.prompt_folder}/cliport_prompt_common_errors_template.txt") as f:
                errorbook_prompt_text = f.read()
            errorbook_prompt_text = errorbook_prompt_text.replace("TASK_NAME_TEMPLATE", new_task["task-name"])
            # The returned feedback is not used further in this method.
            generate_feedback(errorbook_prompt_text, temperature=0., interaction_txt=self.chat_log)  # cfg['gpt_temperature']
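
    # The prompt files referenced in this class are plain-text templates with
    # placeholder tokens. A minimal sketch of the fill-in pattern, using a
    # hypothetical template string (the real files ship in prompts/):
    #
    #   template = "List common errors when implementing TASK_NAME_TEMPLATE."
    #   prompt = template.replace("TASK_NAME_TEMPLATE", "stack-blocks")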

    def reflection(self, new_task, new_code, current_tasks=None):
        """Reflect on whether the new task should be added to the task list."""
        all_add_to_the_task_list_flag = True

        if os.path.exists(f"{self.prompt_folder}/cliport_prompt_task_reflection.txt"):
            # Only consider successful tasks.
            self.chat_log = add_to_txt(self.chat_log, "================= Code Reflect!", with_print=True)
            total_tasks = copy.deepcopy(self.memory.online_task_buffer)
            if current_tasks is not None:
                # Add every task from the current run; the new task should at
                # least not overlap with these.
                for t in current_tasks:
                    total_tasks[t['task-name']] = t

            # Cap the number of history tasks shown to the critic.
            MAX_NUM = 40
            if len(total_tasks) > MAX_NUM:
                total_tasks = dict(random.sample(list(total_tasks.items()), MAX_NUM))

            print("reflection history task num:", len(total_tasks))
            task_descriptions_replacement_str = format_dict_prompt(total_tasks, -1)

            # Fill the reflection prompt with the task history, the new task,
            # and its generated code.
            with open(f"{self.prompt_folder}/cliport_prompt_task_reflection.txt") as f:
                code_reflection_prompt_text = f.read()
            code_reflection_prompt_text = code_reflection_prompt_text.replace("CURRENT_TASK_NAME_TEMPLATE", str(task_descriptions_replacement_str))
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_STRING_TEMPLATE", str(new_task))
            code_reflection_prompt_text = code_reflection_prompt_text.replace("TASK_CODE_TEMPLATE", str(new_code))
            if len(self.cfg['target_task_name']) > 0:
                code_reflection_prompt_text = code_reflection_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])

            # Record the new task in the online buffer no matter the outcome.
            self.memory.online_task_buffer[new_task["task-name"].replace("-", "_")] = str(new_task)

            # Sample several critic responses so agreement can be checked.
            res = generate_feedback(
                code_reflection_prompt_text,
                temperature=0.4,
                interaction_txt=self.chat_log,
                n=int(self.cfg['reflection_agreement_num']),
            )  # cfg['gpt_temperature']
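
            # Illustrative only: each response is expected to contain a Python
            # dict assignment that extract_dict() recovers, e.g. (hypothetical
            # field values, inferred from how the loop below reads them):
            #
            #   task_reflection = {
            #       "explanation": "...",
            #       "add_to_the_task_list": "False",
            #   }
            #
            # The veto check compares against the string 'False', not the
            # boolean, because the value is parsed from model output.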
            # Require agreement across all sampled critic responses.
            task_reflection = None
            for idx, r in enumerate(res):
                reflection_def_cmd = extract_dict(r, prefix='task_reflection')
                local_ns = {}
                exec(reflection_def_cmd, globals(), local_ns)
                try:
                    task_reflection = local_ns['task_reflection']
                    print(f"critic {idx}:", task_reflection)
                    if task_reflection["add_to_the_task_list"] == 'False':
                        all_add_to_the_task_list_flag = False
                        print(f"critic {idx} suggests not adding this task to the buffer!")
                except Exception:
                    # Drop into an interactive shell to debug malformed output.
                    traceback.print_exc()
                    IPython.embed()

            save_text(self.model_output_dir, new_task['task-name'] + "_reflection_output", str(task_reflection))

        return all_add_to_the_task_list_flag
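
# -----------------------------------------------------------------------------
# Usage sketch (hypothetical wiring). The cfg keys mirror how Critic reads them
# above, but the stub values and the Memory import path are illustrative
# assumptions, not the project's actual entry point:
#
#   from gensim.memory import Memory  # assumed module path
#
#   cfg = {
#       "prompt_folder": "bottomup_task_generation_prompt",  # illustrative
#       "model_output_dir": "output/",
#       "target_task_name": "",
#       "reflection_agreement_num": 3,
#   }
#   memory = Memory(cfg)
#   critic = Critic(cfg, memory)
#
#   new_task = {"task-name": "stack-blocks", "task-description": "..."}
#   new_code = "class StackBlocks(Task): ..."
#
#   critic.error_review(new_task)
#   if critic.reflection(new_task, new_code):
#       print("all critics agree; keep the task")
# -----------------------------------------------------------------------------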