mirror of
https://github.com/kolbytn/mindcraft.git
synced 2025-04-29 19:44:53 +02:00
add new options and better logging
This commit is contained in:
parent
94faf8f82a
commit
bb50486e81
1 changed files with 42 additions and 28 deletions
|
@ -75,7 +75,6 @@ def extract_result(folder_path):
|
|||
# assert len(json_files) == 2, f"Expected 2 json files in {folder_name}, found {len(json_files)}"
|
||||
|
||||
if not json_files:
|
||||
print(f"No JSON files found in {folder_name}")
|
||||
return None
|
||||
else:
|
||||
outcome = False
|
||||
|
@ -193,9 +192,10 @@ def launch_parallel_experiments(task_path,
|
|||
s3=False,
|
||||
bucket_name="mindcraft-experiments",
|
||||
template_profile="profiles/tasks/collab_profile.json",
|
||||
world_name="Forest",
|
||||
insecure_coding=False,
|
||||
url="http://127.0.0.1:8000/v1"):
|
||||
url="http://127.0.0.1:8000/v1",
|
||||
max_messages=15,
|
||||
num_examples=2):
|
||||
|
||||
with open(task_path, 'r', encoding='utf-8') as file:
|
||||
content = file.read()
|
||||
|
@ -210,6 +210,13 @@ def launch_parallel_experiments(task_path,
|
|||
task_ids = list(task_ids)
|
||||
task_ids_split = [task_ids[i::num_parallel] for i in range(num_parallel)]
|
||||
|
||||
if task_type == "cooking":
|
||||
world_name = "Superflat"
|
||||
elif task_type == "techtree":
|
||||
world_name = "Forest"
|
||||
elif task_type == "construction":
|
||||
world_name = "Superflat"
|
||||
|
||||
servers = create_server_files("./server_data/", num_parallel, world_name=world_name)
|
||||
date_time = datetime.now().strftime("%m-%d_%H-%M")
|
||||
experiments_folder = f"experiments/{exp_name}_{date_time}"
|
||||
|
@ -221,7 +228,7 @@ def launch_parallel_experiments(task_path,
|
|||
else:
|
||||
task_path_name = "tasks"
|
||||
|
||||
s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}/"
|
||||
s3_path = f"{bucket_name}/{task_type}/{model}/{task_path_name}/{exp_name}"
|
||||
|
||||
# start wandb
|
||||
os.makedirs(experiments_folder, exist_ok=True)
|
||||
|
@ -241,7 +248,9 @@ def launch_parallel_experiments(task_path,
|
|||
num_agents=num_agents,
|
||||
url=url,
|
||||
task_type=task_type,
|
||||
s3_path=s3_path)
|
||||
s3_path=s3_path,
|
||||
max_messages=max_messages,
|
||||
num_examples=num_examples)
|
||||
time.sleep(5)
|
||||
|
||||
total_num_tasks = len(task_ids)
|
||||
|
@ -255,10 +264,11 @@ def launch_parallel_experiments(task_path,
|
|||
with open(f"{experiments_folder}/results.txt", "w") as file:
|
||||
file.write(str(results))
|
||||
if s3:
|
||||
s3 = boto3.client('s3')
|
||||
s3.upload_file(f"{experiments_folder}/results.txt", bucket_name, s3_path)
|
||||
cmd = f"aws s3 cp {experiments_folder}/results.txt s3://{s3_path}/results.txt"
|
||||
print(cmd)
|
||||
subprocess.run(cmd.split())
|
||||
|
||||
time.sleep(15)
|
||||
time.sleep(60)
|
||||
|
||||
def launch_server_experiment(task_path,
|
||||
task_ids,
|
||||
|
@ -275,7 +285,10 @@ def launch_server_experiment(task_path,
|
|||
insecure_coding=False,
|
||||
url="http://127.0.0.1:8000/v1",
|
||||
task_type="techtree",
|
||||
s3_path=""):
|
||||
s3_path="",
|
||||
max_messages=15,
|
||||
num_examples=2):
|
||||
|
||||
"""
|
||||
Launch a Minecraft server and run experiments on it.
|
||||
@param task_path: Path to the task file
|
||||
|
@ -333,6 +346,8 @@ def launch_server_experiment(task_path,
|
|||
set_environment_variable_tmux_session(session_name, "MINECRAFT_PORT", server_port)
|
||||
set_environment_variable_tmux_session(session_name, "MINDSERVER_PORT", mindserver_port)
|
||||
set_environment_variable_tmux_session(session_name, "PROFILES", agent_profiles_str)
|
||||
set_environment_variable_tmux_session(session_name, "MAX_MESSAGES", str(max_messages))
|
||||
set_environment_variable_tmux_session(session_name, "NUM_EXAMPLES", str(num_examples))
|
||||
if insecure_coding:
|
||||
set_environment_variable_tmux_session(session_name, "INSECURE_CODING", "true")
|
||||
|
||||
|
@ -569,8 +584,6 @@ def test_server_running(port=55916):
|
|||
print("Server is not running on port 55916")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def kill_world(session_name="server"):
|
||||
"""Kill the Minecraft world."""
|
||||
subprocess.run(["tmux", "send-keys", "-t", session_name, "stop", "C-m"])
|
||||
|
@ -615,7 +628,6 @@ def main():
|
|||
|
||||
parser = argparse.ArgumentParser(description='Run Minecraft AI agent experiments')
|
||||
parser.add_argument('--task_path', default="multiagent_crafting_tasks.json", help='Path to the task file')
|
||||
parser.add_argument('--task_id', default=None, help='ID of the task to run')
|
||||
parser.add_argument('--num_agents', default=2, type=int, help='Number of agents to run')
|
||||
parser.add_argument('--num_exp', default=1, type=int, help='Number of experiments to run')
|
||||
parser.add_argument('--num_parallel', default=1, type=int, help='Number of parallel servers to run')
|
||||
|
@ -626,9 +638,11 @@ def main():
|
|||
parser.add_argument('--template_profile', default="profiles/tasks/collab_profile.json", help='Model to use for the agents')
|
||||
parser.add_argument('--model', default="gpt-4o-mini", help='Model to use for the agents')
|
||||
parser.add_argument('--api', default="openai", help='API to use for the agents')
|
||||
parser.add_argument('--world_name', default="Forest", help='Name of the world')
|
||||
# parser.add_argument('--world_name', default="Forest", help='Name of the world')
|
||||
parser.add_argument('--insecure_coding', action='store_true', help='Enable insecure coding')
|
||||
parser.add_argument('--url', default="http://127.0.0.1:8000/v1")
|
||||
parser.add_argument('--max_messages', default=15, type=int, help='Maximum number of messages before summarizing')
|
||||
parser.add_argument('--num_examples', default=2, type=int, help='Maximum number of turns before summarizing')
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
@ -642,21 +656,21 @@ def main():
|
|||
clean_up_server_files(args.num_parallel)
|
||||
if args.add_keys:
|
||||
update_keys_json()
|
||||
if args.task_id is None:
|
||||
launch_parallel_experiments(args.task_path,
|
||||
num_exp=args.num_exp,
|
||||
exp_name=args.exp_name,
|
||||
num_parallel=args.num_parallel,
|
||||
s3=args.s3,
|
||||
bucket_name=args.bucket_name,
|
||||
template_profile=args.template_profile,
|
||||
model=args.model,
|
||||
api=args.api,
|
||||
world_name=args.world_name,
|
||||
insecure_coding=args.insecure_coding,
|
||||
num_agents=args.num_agents,
|
||||
url=args.url)
|
||||
cmd = "aws s3"
|
||||
|
||||
launch_parallel_experiments(args.task_path,
|
||||
num_exp=args.num_exp,
|
||||
exp_name=args.exp_name,
|
||||
num_parallel=args.num_parallel,
|
||||
s3=args.s3,
|
||||
bucket_name=args.bucket_name,
|
||||
template_profile=args.template_profile,
|
||||
model=args.model,
|
||||
api=args.api,
|
||||
insecure_coding=args.insecure_coding,
|
||||
num_agents=args.num_agents,
|
||||
url=args.url,
|
||||
max_messages=args.max_messages,
|
||||
num_examples=args.num_examples)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Reference in a new issue