diff --git a/evaluation_script.py b/evaluation_script.py index 076b3da..6720055 100644 --- a/evaluation_script.py +++ b/evaluation_script.py @@ -75,7 +75,6 @@ def extract_result(folder_path): # assert len(json_files) == 2, f"Expected 2 json files in {folder_name}, found {len(json_files)}" if not json_files: - print(f"No JSON files found in {folder_name}") return None else: outcome = False @@ -193,9 +192,10 @@ def launch_parallel_experiments(task_path, s3=False, bucket_name="mindcraft-experiments", template_profile="profiles/tasks/collab_profile.json", - world_name="Forest", insecure_coding=False, - url="http://127.0.0.1:8000/v1"): + url="http://127.0.0.1:8000/v1", + max_messages=15, + num_examples=2): with open(task_path, 'r', encoding='utf-8') as file: content = file.read() @@ -210,6 +210,13 @@ def launch_parallel_experiments(task_path, task_ids = list(task_ids) task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)] + if task_type == "cooking": + world_name = "Superflat" + elif task_type == "techtree": + world_name = "Forest" + elif task_type == "construction": + world_name = "Superflat" + servers = create_server_files("./server_data/", num_parallel, world_name=world_name) date_time = datetime.now().strftime("%m-%d_%H-%M") experiments_folder = f"experiments/{exp_name}_{date_time}" @@ -221,7 +228,7 @@ def launch_parallel_experiments(task_path, else: task_path_name = "tasks" - s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}/" + s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}" # start wandb os.makedirs(experiments_folder, exist_ok=True) @@ -241,7 +248,9 @@ def launch_parallel_experiments(task_path, num_agents=num_agents, url=url, task_type=task_type, - s3_path=s3_path) + s3_path=s3_path, + max_messages=max_messages, + num_examples=num_examples) time.sleep(5) total_num_tasks = len(task_ids) @@ -255,10 +264,11 @@ def launch_parallel_experiments(task_path, with open(f"{experiments_folder}/results.txt", "w") as file: file.write(str(results)) if s3: - s3 = boto3.client('s3') - s3.upload_file(f"{experiments_folder}/results.txt", bucket_name, s3_path) + cmd = f"aws s3 cp {experiments_folder}/results.txt s3://{s3_path}/results.txt" + print(cmd) + subprocess.run(cmd.split()) - time.sleep(15) + time.sleep(60) def launch_server_experiment(task_path, task_ids, @@ -275,7 +285,10 @@ def launch_server_experiment(task_path, insecure_coding=False, url="http://127.0.0.1:8000/v1", task_type="techtree", - s3_path=""): + s3_path="", + max_messages=15, + num_examples=2): + """ Launch a Minecraft server and run experiments on it. @param task_path: Path to the task file @@ -333,6 +346,8 @@ def launch_server_experiment(task_path, set_environment_variable_tmux_session(session_name, "MINECRAFT_PORT", server_port) set_environment_variable_tmux_session(session_name, "MINDSERVER_PORT", mindserver_port) set_environment_variable_tmux_session(session_name, "PROFILES", agent_profiles_str) + set_environment_variable_tmux_session(session_name, "MAX_MESSAGES", str(max_messages)) + set_environment_variable_tmux_session(session_name, "NUM_EXAMPLES", str(num_examples)) if insecure_coding: set_environment_variable_tmux_session(session_name, "INSECURE_CODING", "true") @@ -569,8 +584,6 @@ def test_server_running(port=55916): print("Server is not running on port 55916") return False - - def kill_world(session_name="server"): """Kill the Minecraft world.""" subprocess.run(["tmux", "send-keys", "-t", session_name, "stop", "C-m"]) @@ -615,7 +628,6 @@ def main(): parser = argparse.ArgumentParser(description='Run Minecraft AI agent experiments') parser.add_argument('--task_path', default="multiagent_crafting_tasks.json", help='Path to the task file') - parser.add_argument('--task_id', default=None, help='ID of the task to run') parser.add_argument('--num_agents', default=2, type=int, help='Number of agents to run') parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run') parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run') @@ -626,9 +638,11 @@ def main(): parser.add_argument('--template_profile', default="profiles/tasks/collab_profile.json", help='Model to use for the agents') parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents') parser.add_argument('--api', default="openai", help='API to use for the agents') - parser.add_argument('--world_name', default="Forest", help='Name of the world') + # parser.add_argument('--world_name', default="Forest", help='Name of the world') parser.add_argument('--insecure_coding', action='store_true', help='Enable insecure coding') parser.add_argument('--url', default="http://127.0.0.1:8000/v1") + parser.add_argument('--max_messages', default=15, type=int, help='Maximum number of messages before summarizing') + parser.add_argument('--num_examples', default=2, type=int, help='Maximum number of turns before summarizing') args = parser.parse_args() print(args) @@ -642,21 +656,21 @@ def main(): clean_up_server_files(args.num_parallel) if args.add_keys: update_keys_json() - if args.task_id is None: - launch_parallel_experiments(args.task_path, - num_exp=args.num_exp, - exp_name=args.exp_name, - num_parallel=args.num_parallel, - s3=args.s3, - bucket_name=args.bucket_name, - template_profile=args.template_profile, - model=args.model, - api=args.api, - world_name=args.world_name, - insecure_coding=args.insecure_coding, - num_agents=args.num_agents, - url=args.url) - cmd = "aws s3" + + launch_parallel_experiments(args.task_path, + num_exp=args.num_exp, + exp_name=args.exp_name, + num_parallel=args.num_parallel, + s3=args.s3, + bucket_name=args.bucket_name, + template_profile=args.template_profile, + model=args.model, + api=args.api, + insecure_coding=args.insecure_coding, + num_agents=args.num_agents, + url=args.url, + max_messages=args.max_messages, + num_examples=args.num_examples) if __name__ == "__main__": main() \ No newline at end of file