mirror of
https://github.com/kolbytn/mindcraft.git
synced 2025-04-29 19:44:53 +02:00
add new options and better logging
This commit is contained in:
parent
94faf8f82a
commit
bb50486e81
1 changed file with 42 additions and 28 deletions
|
@ -75,7 +75,6 @@ def extract_result(folder_path):
|
||||||
# assert len(json_files) == 2, f"Expected 2 json files in {folder_name}, found {len(json_files)}"
|
# assert len(json_files) == 2, f"Expected 2 json files in {folder_name}, found {len(json_files)}"
|
||||||
|
|
||||||
if not json_files:
|
if not json_files:
|
||||||
print(f"No JSON files found in {folder_name}")
|
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
outcome = False
|
outcome = False
|
||||||
|
@ -193,9 +192,10 @@ def launch_parallel_experiments(task_path,
|
||||||
s3=False,
|
s3=False,
|
||||||
bucket_name="mindcraft-experiments",
|
bucket_name="mindcraft-experiments",
|
||||||
template_profile="profiles/tasks/collab_profile.json",
|
template_profile="profiles/tasks/collab_profile.json",
|
||||||
world_name="Forest",
|
|
||||||
insecure_coding=False,
|
insecure_coding=False,
|
||||||
url="http://127.0.0.1:8000/v1"):
|
url="http://127.0.0.1:8000/v1",
|
||||||
|
max_messages=15,
|
||||||
|
num_examples=2):
|
||||||
|
|
||||||
with open(task_path, 'r', encoding='utf-8') as file:
|
with open(task_path, 'r', encoding='utf-8') as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
|
@ -210,6 +210,13 @@ def launch_parallel_experiments(task_path,
|
||||||
task_ids = list(task_ids)
|
task_ids = list(task_ids)
|
||||||
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
|
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
|
||||||
|
|
||||||
|
if task_type == "cooking":
|
||||||
|
world_name = "Superflat"
|
||||||
|
elif task_type == "techtree":
|
||||||
|
world_name = "Forest"
|
||||||
|
elif task_type == "construction":
|
||||||
|
world_name = "Superflat"
|
||||||
|
|
||||||
servers = create_server_files("./server_data/", num_parallel, world_name=world_name)
|
servers = create_server_files("./server_data/", num_parallel, world_name=world_name)
|
||||||
date_time = datetime.now().strftime("%m-%d_%H-%M")
|
date_time = datetime.now().strftime("%m-%d_%H-%M")
|
||||||
experiments_folder = f"experiments/{exp_name}_{date_time}"
|
experiments_folder = f"experiments/{exp_name}_{date_time}"
|
||||||
|
@ -221,7 +228,7 @@ def launch_parallel_experiments(task_path,
|
||||||
else:
|
else:
|
||||||
task_path_name = "tasks"
|
task_path_name = "tasks"
|
||||||
|
|
||||||
s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}/"
|
s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}"
|
||||||
|
|
||||||
# start wandb
|
# start wandb
|
||||||
os.makedirs(experiments_folder, exist_ok=True)
|
os.makedirs(experiments_folder, exist_ok=True)
|
||||||
|
@ -241,7 +248,9 @@ def launch_parallel_experiments(task_path,
|
||||||
num_agents=num_agents,
|
num_agents=num_agents,
|
||||||
url=url,
|
url=url,
|
||||||
task_type=task_type,
|
task_type=task_type,
|
||||||
s3_path=s3_path)
|
s3_path=s3_path,
|
||||||
|
max_messages=max_messages,
|
||||||
|
num_examples=num_examples)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
total_num_tasks = len(task_ids)
|
total_num_tasks = len(task_ids)
|
||||||
|
@ -255,10 +264,11 @@ def launch_parallel_experiments(task_path,
|
||||||
with open(f"{experiments_folder}/results.txt", "w") as file:
|
with open(f"{experiments_folder}/results.txt", "w") as file:
|
||||||
file.write(str(results))
|
file.write(str(results))
|
||||||
if s3:
|
if s3:
|
||||||
s3 = boto3.client('s3')
|
cmd = f"aws s3 cp {experiments_folder}/results.txt s3://{s3_path}/results.txt"
|
||||||
s3.upload_file(f"{experiments_folder}/results.txt", bucket_name, s3_path)
|
print(cmd)
|
||||||
|
subprocess.run(cmd.split())
|
||||||
|
|
||||||
time.sleep(15)
|
time.sleep(60)
|
||||||
|
|
||||||
def launch_server_experiment(task_path,
|
def launch_server_experiment(task_path,
|
||||||
task_ids,
|
task_ids,
|
||||||
|
@ -275,7 +285,10 @@ def launch_server_experiment(task_path,
|
||||||
insecure_coding=False,
|
insecure_coding=False,
|
||||||
url="http://127.0.0.1:8000/v1",
|
url="http://127.0.0.1:8000/v1",
|
||||||
task_type="techtree",
|
task_type="techtree",
|
||||||
s3_path=""):
|
s3_path="",
|
||||||
|
max_messages=15,
|
||||||
|
num_examples=2):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Launch a Minecraft server and run experiments on it.
|
Launch a Minecraft server and run experiments on it.
|
||||||
@param task_path: Path to the task file
|
@param task_path: Path to the task file
|
||||||
|
@ -333,6 +346,8 @@ def launch_server_experiment(task_path,
|
||||||
set_environment_variable_tmux_session(session_name, "MINECRAFT_PORT", server_port)
|
set_environment_variable_tmux_session(session_name, "MINECRAFT_PORT", server_port)
|
||||||
set_environment_variable_tmux_session(session_name, "MINDSERVER_PORT", mindserver_port)
|
set_environment_variable_tmux_session(session_name, "MINDSERVER_PORT", mindserver_port)
|
||||||
set_environment_variable_tmux_session(session_name, "PROFILES", agent_profiles_str)
|
set_environment_variable_tmux_session(session_name, "PROFILES", agent_profiles_str)
|
||||||
|
set_environment_variable_tmux_session(session_name, "MAX_MESSAGES", str(max_messages))
|
||||||
|
set_environment_variable_tmux_session(session_name, "NUM_EXAMPLES", str(num_examples))
|
||||||
if insecure_coding:
|
if insecure_coding:
|
||||||
set_environment_variable_tmux_session(session_name, "INSECURE_CODING", "true")
|
set_environment_variable_tmux_session(session_name, "INSECURE_CODING", "true")
|
||||||
|
|
||||||
|
@ -569,8 +584,6 @@ def test_server_running(port=55916):
|
||||||
print("Server is not running on port 55916")
|
print("Server is not running on port 55916")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def kill_world(session_name="server"):
|
def kill_world(session_name="server"):
|
||||||
"""Kill the Minecraft world."""
|
"""Kill the Minecraft world."""
|
||||||
subprocess.run(["tmux", "send-keys", "-t", session_name, "stop", "C-m"])
|
subprocess.run(["tmux", "send-keys", "-t", session_name, "stop", "C-m"])
|
||||||
|
@ -615,7 +628,6 @@ def main():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run Minecraft AI agent experiments')
|
parser = argparse.ArgumentParser(description='Run Minecraft AI agent experiments')
|
||||||
parser.add_argument('--task_path', default="multiagent_crafting_tasks.json", help='Path to the task file')
|
parser.add_argument('--task_path', default="multiagent_crafting_tasks.json", help='Path to the task file')
|
||||||
parser.add_argument('--task_id', default=None, help='ID of the task to run')
|
|
||||||
parser.add_argument('--num_agents', default=2, type=int, help='Number of agents to run')
|
parser.add_argument('--num_agents', default=2, type=int, help='Number of agents to run')
|
||||||
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
|
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
|
||||||
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
|
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
|
||||||
|
@ -626,9 +638,11 @@ def main():
|
||||||
parser.add_argument('--template_profile', default="profiles/tasks/collab_profile.json", help='Model to use for the agents')
|
parser.add_argument('--template_profile', default="profiles/tasks/collab_profile.json", help='Model to use for the agents')
|
||||||
parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents')
|
parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents')
|
||||||
parser.add_argument('--api', default="openai", help='API to use for the agents')
|
parser.add_argument('--api', default="openai", help='API to use for the agents')
|
||||||
parser.add_argument('--world_name', default="Forest", help='Name of the world')
|
# parser.add_argument('--world_name', default="Forest", help='Name of the world')
|
||||||
parser.add_argument('--insecure_coding', action='store_true', help='Enable insecure coding')
|
parser.add_argument('--insecure_coding', action='store_true', help='Enable insecure coding')
|
||||||
parser.add_argument('--url', default="http://127.0.0.1:8000/v1")
|
parser.add_argument('--url', default="http://127.0.0.1:8000/v1")
|
||||||
|
parser.add_argument('--max_messages', default=15, type=int, help='Maximum number of messages before summarizing')
|
||||||
|
parser.add_argument('--num_examples', default=2, type=int, help='Maximum number of turns before summarizing')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
@ -642,21 +656,21 @@ def main():
|
||||||
clean_up_server_files(args.num_parallel)
|
clean_up_server_files(args.num_parallel)
|
||||||
if args.add_keys:
|
if args.add_keys:
|
||||||
update_keys_json()
|
update_keys_json()
|
||||||
if args.task_id is None:
|
|
||||||
launch_parallel_experiments(args.task_path,
|
launch_parallel_experiments(args.task_path,
|
||||||
num_exp=args.num_exp,
|
num_exp=args.num_exp,
|
||||||
exp_name=args.exp_name,
|
exp_name=args.exp_name,
|
||||||
num_parallel=args.num_parallel,
|
num_parallel=args.num_parallel,
|
||||||
s3=args.s3,
|
s3=args.s3,
|
||||||
bucket_name=args.bucket_name,
|
bucket_name=args.bucket_name,
|
||||||
template_profile=args.template_profile,
|
template_profile=args.template_profile,
|
||||||
model=args.model,
|
model=args.model,
|
||||||
api=args.api,
|
api=args.api,
|
||||||
world_name=args.world_name,
|
insecure_coding=args.insecure_coding,
|
||||||
insecure_coding=args.insecure_coding,
|
num_agents=args.num_agents,
|
||||||
num_agents=args.num_agents,
|
url=args.url,
|
||||||
url=args.url)
|
max_messages=args.max_messages,
|
||||||
cmd = "aws s3"
|
num_examples=args.num_examples)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
Loading…
Add table
Reference in a new issue