add new options and better logging

This commit is contained in:
Isadora White 2025-03-19 23:29:11 -05:00
parent 94faf8f82a
commit bb50486e81

View file

@ -75,7 +75,6 @@ def extract_result(folder_path):
# assert len(json_files) == 2, f"Expected 2 json files in {folder_name}, found {len(json_files)}"
if not json_files:
print(f"No JSON files found in {folder_name}")
return None
else:
outcome = False
@ -193,9 +192,10 @@ def launch_parallel_experiments(task_path,
s3=False,
bucket_name="mindcraft-experiments",
template_profile="profiles/tasks/collab_profile.json",
world_name="Forest",
insecure_coding=False,
url="http://127.0.0.1:8000/v1"):
url="http://127.0.0.1:8000/v1",
max_messages=15,
num_examples=2):
with open(task_path, 'r', encoding='utf-8') as file:
content = file.read()
@ -210,6 +210,13 @@ def launch_parallel_experiments(task_path,
task_ids = list(task_ids)
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
if task_type == "cooking":
world_name = "Superflat"
elif task_type == "techtree":
world_name = "Forest"
elif task_type == "construction":
world_name = "Superflat"
servers = create_server_files("./server_data/", num_parallel, world_name=world_name)
date_time = datetime.now().strftime("%m-%d_%H-%M")
experiments_folder = f"experiments/{exp_name}_{date_time}"
@ -221,7 +228,7 @@ def launch_parallel_experiments(task_path,
else:
task_path_name = "tasks"
s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}/"
s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}"
# start wandb
os.makedirs(experiments_folder, exist_ok=True)
@ -241,7 +248,9 @@ def launch_parallel_experiments(task_path,
num_agents=num_agents,
url=url,
task_type=task_type,
s3_path=s3_path)
s3_path=s3_path,
max_messages=max_messages,
num_examples=num_examples)
time.sleep(5)
total_num_tasks = len(task_ids)
@ -255,10 +264,11 @@ def launch_parallel_experiments(task_path,
with open(f"{experiments_folder}/results.txt", "w") as file:
file.write(str(results))
if s3:
s3 = boto3.client('s3')
s3.upload_file(f"{experiments_folder}/results.txt", bucket_name, s3_path)
cmd = f"aws s3 cp {experiments_folder}/results.txt s3://{s3_path}/results.txt"
print(cmd)
subprocess.run(cmd.split())
time.sleep(15)
time.sleep(60)
def launch_server_experiment(task_path,
task_ids,
@ -275,7 +285,10 @@ def launch_server_experiment(task_path,
insecure_coding=False,
url="http://127.0.0.1:8000/v1",
task_type="techtree",
s3_path=""):
s3_path="",
max_messages=15,
num_examples=2):
"""
Launch a Minecraft server and run experiments on it.
@param task_path: Path to the task file
@ -333,6 +346,8 @@ def launch_server_experiment(task_path,
set_environment_variable_tmux_session(session_name, "MINECRAFT_PORT", server_port)
set_environment_variable_tmux_session(session_name, "MINDSERVER_PORT", mindserver_port)
set_environment_variable_tmux_session(session_name, "PROFILES", agent_profiles_str)
set_environment_variable_tmux_session(session_name, "MAX_MESSAGES", str(max_messages))
set_environment_variable_tmux_session(session_name, "NUM_EXAMPLES", str(num_examples))
if insecure_coding:
set_environment_variable_tmux_session(session_name, "INSECURE_CODING", "true")
@ -569,8 +584,6 @@ def test_server_running(port=55916):
print("Server is not running on port 55916")
return False
def kill_world(session_name="server"):
"""Kill the Minecraft world."""
subprocess.run(["tmux", "send-keys", "-t", session_name, "stop", "C-m"])
@ -615,7 +628,6 @@ def main():
parser = argparse.ArgumentParser(description='Run Minecraft AI agent experiments')
parser.add_argument('--task_path', default="multiagent_crafting_tasks.json", help='Path to the task file')
parser.add_argument('--task_id', default=None, help='ID of the task to run')
parser.add_argument('--num_agents', default=2, type=int, help='Number of agents to run')
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
@ -626,9 +638,11 @@ def main():
parser.add_argument('--template_profile', default="profiles/tasks/collab_profile.json", help='Model to use for the agents')
parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents')
parser.add_argument('--api', default="openai", help='API to use for the agents')
parser.add_argument('--world_name', default="Forest", help='Name of the world')
# parser.add_argument('--world_name', default="Forest", help='Name of the world')
parser.add_argument('--insecure_coding', action='store_true', help='Enable insecure coding')
parser.add_argument('--url', default="http://127.0.0.1:8000/v1")
parser.add_argument('--max_messages', default=15, type=int, help='Maximum number of messages before summarizing')
parser.add_argument('--num_examples', default=2, type=int, help='Maximum number of turns before summarizing')
args = parser.parse_args()
print(args)
@ -642,21 +656,21 @@ def main():
clean_up_server_files(args.num_parallel)
if args.add_keys:
update_keys_json()
if args.task_id is None:
launch_parallel_experiments(args.task_path,
num_exp=args.num_exp,
exp_name=args.exp_name,
num_parallel=args.num_parallel,
s3=args.s3,
bucket_name=args.bucket_name,
template_profile=args.template_profile,
model=args.model,
api=args.api,
world_name=args.world_name,
insecure_coding=args.insecure_coding,
num_agents=args.num_agents,
url=args.url)
cmd = "aws s3"
launch_parallel_experiments(args.task_path,
num_exp=args.num_exp,
exp_name=args.exp_name,
num_parallel=args.num_parallel,
s3=args.s3,
bucket_name=args.bucket_name,
template_profile=args.template_profile,
model=args.model,
api=args.api,
insecure_coding=args.insecure_coding,
num_agents=args.num_agents,
url=args.url,
max_messages=args.max_messages,
num_examples=args.num_examples)
if __name__ == "__main__":
main()