Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import os | |
| import IPython | |
| import random | |
| import json | |
| from gensim.utils import save_text | |
class Memory:
    """Buffer of generated tasks, their assets, and their code.

    On construction the offline memory (JSON files under
    ``cfg["prompt_data_path"]``) is loaded into in-memory ("online") buffers.
    New tasks can then be recorded online only (``save_task_to_online``) or
    persisted to disk (``save_task_to_offline``).

    Args:
        cfg: configuration mapping. Keys read by this class are
            ``prompt_folder``, ``prompt_data_path``, ``load_memory`` and,
            in ``save_run``, ``model_output_dir``.
    """

    def __init__(self, cfg):
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.data_path = cfg["prompt_data_path"]
        self.cfg = cfg

        # A chat history is a list of strings; save_run concatenates them.
        self.chat_log = []
        self.online_task_buffer = {}
        self.online_code_buffer = {}
        self.online_asset_buffer = {}

        # Directly load the current offline memory into online memory.
        base_tasks, base_assets, base_task_codes = self.load_offline_memory()
        self.online_task_buffer.update(base_tasks)
        self.online_asset_buffer.update(base_assets)

        # Load each code file. The original cliport task folder takes
        # precedence over the generated one; files found in neither are skipped.
        for task_file in base_task_codes:
            for task_dir in ("cliport/tasks/", "cliport/generated_tasks/"):
                code_path = task_dir + task_file
                if os.path.exists(code_path):
                    self.online_code_buffer[task_file] = self._read_text(code_path)
                    break
        print(f"load {len(self.online_code_buffer)} tasks for memory from offline to online:")

        cache_embedding_path = "outputs/task_cache_embedding.npz"
        if os.path.exists(cache_embedding_path):
            print("task code embeding:", cache_embedding_path)
            # NOTE(review): this attribute is only set when the cache file
            # exists; callers presumably guard access to it - confirm.
            self.task_code_embedding = np.load(cache_embedding_path)

    @staticmethod
    def _code_file_name(task_name):
        """Map a task name like ``"stack-blocks"`` to ``"stack_blocks.py"``."""
        return task_name.replace("-", "_") + ".py"

    @staticmethod
    def _read_text(path):
        """Return the UTF-8 text content of ``path``, closing the file."""
        with open(path, encoding="utf-8") as fhandle:
            return fhandle.read()

    @staticmethod
    def _read_json(path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, encoding="utf-8") as fhandle:
            return json.load(fhandle)

    @staticmethod
    def _write_json(path, data):
        """Write ``data`` to ``path`` as JSON indented by 4 spaces."""
        with open(path, "w", encoding="utf-8") as fhandle:
            json.dump(data, fhandle, indent=4)

    def save_run(self, new_task):
        """Save the concatenated chat history of this run.

        The text is written via ``save_text`` into ``cfg["model_output_dir"]``
        under the name ``"<task-name>_full_output"``.
        """
        print("save all interaction to :", f'{new_task["task-name"]}_full_output')
        unroll_chatlog = "".join(self.chat_log)
        save_text(
            self.cfg["model_output_dir"], f'{new_task["task-name"]}_full_output', unroll_chatlog
        )

    def save_task_to_online(self, new_task, code):
        """Record a task and its code in the online buffers only (nothing on disk).

        Args:
            new_task: task description dict; must contain ``"task-name"``.
            code: the task's Python source code as a string.
        """
        self.online_task_buffer[new_task["task-name"]] = new_task
        # Unlike the offline buffers, the online code buffer maps the file name
        # to the actual code rather than just listing file names.
        self.online_code_buffer[self._code_file_name(new_task["task-name"])] = code

    def save_task_to_offline(self, new_task, code):
        """Persist a task's description and code to the offline memory files.

        Intended for tasks that passed reflection and environment tests; this
        method performs no such check itself.

        The code is written to ``cliport/generated_tasks/<file>.py`` and the
        file name is appended to ``generated_task_codes.json``, only if that
        file name is not already listed (existing code is never overwritten).
        The description is always upserted into ``generated_tasks.json``.

        Args:
            new_task: task description dict; must contain ``"task-name"``.
            code: the task's Python source code as a string.
        """
        generated_task_code_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_task_codes.json"
        )
        generated_task_codes = self._read_json(generated_task_code_path)
        new_file_path = self._code_file_name(new_task["task-name"])

        if new_file_path not in generated_task_codes:
            generated_task_codes.append(new_file_path)
            python_file_path = "cliport/generated_tasks/" + new_file_path
            print(f"save {new_task['task-name']} to ", python_file_path)
            with open(python_file_path, "w", encoding="utf-8") as fhandle:
                fhandle.write(code)
            self._write_json(generated_task_code_path, generated_task_codes)
        else:
            # new_file_path already carries the ".py" suffix.
            print(f"{new_file_path} already exists.")

        # Save task descriptions.
        generated_task_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_tasks.json"
        )
        generated_tasks = self._read_json(generated_task_path)
        generated_tasks[new_task["task-name"]] = new_task
        self._write_json(generated_task_path, generated_tasks)

    def load_offline_memory(self):
        """Load the task descriptions, assets and code-file list from disk.

        Reads ``base_tasks.json``, ``base_assets.json`` and
        ``base_task_codes.json`` from ``self.data_path``. When
        ``cfg["load_memory"]`` is truthy, it also merges in
        ``generated_tasks.json`` and appends any not-yet-listed entries of
        ``generated_task_codes.json``, in order. Generated assets are
        currently not merged.

        Returns:
            ``(base_tasks, base_assets, base_task_codes)``: two dicts and a
            list of code file names.
        """
        base_tasks = self._read_json(os.path.join(self.data_path, "base_tasks.json"))
        base_assets = self._read_json(os.path.join(self.data_path, "base_assets.json"))
        base_task_codes = self._read_json(os.path.join(self.data_path, "base_task_codes.json"))

        if self.cfg["load_memory"]:
            generated_task_path = os.path.join(self.data_path, "generated_tasks.json")
            generated_task_code_path = os.path.join(self.data_path, "generated_task_codes.json")
            print("original base task num:", len(base_tasks))
            base_tasks.update(self._read_json(generated_task_path))

            # Set mirror of the list for O(1) membership; list order is preserved.
            known_codes = set(base_task_codes)
            for task in self._read_json(generated_task_code_path):
                if task not in known_codes:
                    base_task_codes.append(task)
                    known_codes.add(task)
            print("current base task num:", len(base_tasks))
        return base_tasks, base_assets, base_task_codes