Spaces:
Configuration error
Configuration error
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| from car_dqn import CarRacingDQN | |
| import os | |
| import tensorflow as tf | |
| import gym | |
| import _thread | |
| import re | |
| import sys | |
| import numpy as np | |
# Report how many GPUs TensorFlow can see, to confirm the run is on a GPU.
print(
    "Num GPUs Available: ",
    len(tf.config.experimental.list_physical_devices('GPU')),
)

# ---- Run configuration -------------------------------------------------
checkpoint_path = "data/checkpoints/train24"
load_checkpoint = True
render = False
finished = False

train_episodes = 15000
# True division (see the __future__ import), so this is the float 150.0.
save_freq_episodes = train_episodes / 100

# Per-episode results are logged to a text file named after the checkpoint
# path. NOTE: mode "w" truncates any existing file at this location.
opendir = checkpoint_path + '.txt'
text_results = open(opendir, "w")

# The same value is reused for frame stacking, frame skipping and the
# training frequency passed to the agent below.
frame_skip = 3

# Keyword arguments forwarded to CarRacingDQN.
model_config = {
    "min_epsilon": 0.05,
    "max_negative_rewards": 8,
    "min_experience_size": 100,
    "experience_capacity": 150000,
    "num_frame_stack": frame_skip,
    "frame_skip": frame_skip,
    "train_freq": frame_skip,
    "batchsize": 64,
    "epsilon_decay_steps": 100000,
    # Target network is refreshed from the prediction network every
    # 1000 global steps.
    "target_network_update_freq": 1000,
    "gamma": 0.95,
    "render": False,
}

# ---- Running statistics, appended to after every episode ---------------
dqn_scores = []
eps_history = []
avg_score_all = [0]
# Create the environment, build the agent's graph and attach a session.
env = gym.make('CarRacing-v0', verbose=False)

# Start from a clean default graph before the agent adds its ops.
# The function has to be *called*; a bare attribute reference does nothing.
tf.compat.v1.reset_default_graph()

dqn_agent = CarRacingDQN(env=env, **model_config)
dqn_agent.build_graph()
sess = tf.InteractiveSession()
dqn_agent.session = sess

# Saver used to write and restore checkpoints; keeps at most 1000 of them.
saver = tf.train.Saver(max_to_keep=1000)
# Either restore the latest checkpoint (play mode) or start from freshly
# initialised variables (training mode).
if load_checkpoint:
    # Play-mode overrides of the run configuration.
    train_episodes = 150
    # Never used as a modulus: one_episode skips the save logic entirely
    # when load_checkpoint is true.
    save_freq_episodes = 0
    print("loading the latest checkpoint from %s" % checkpoint_path)
    ckpt = tf.train.get_checkpoint_state(checkpoint_path)
    assert ckpt, "checkpoint path %s not found" % checkpoint_path
    # The checkpoint name ends in "-<digits>"; recover that number as the
    # global step. Raw string so "\d" reaches the regex engine without
    # triggering an invalid-escape warning.
    global_counter = int(re.findall(r"-(\d+)$", ckpt.model_checkpoint_path)[0])
    saver.restore(sess, ckpt.model_checkpoint_path)
    dqn_agent.global_counter = global_counter
    render = True
else:
    # Refuse to train into a directory that already holds a previous run.
    if checkpoint_path is not None:
        assert not os.path.exists(checkpoint_path), \
            "checkpoint path already exists but load_checkpoint is false"
    tf.global_variables_initializer().run()
def save_checkpoint():
    """Save the session's variables under ``checkpoint_path``.

    Creates the checkpoint directory when it does not exist yet, saves via
    the module-level ``saver`` with ``dqn_agent.global_counter`` as the
    global step, and prints where the checkpoint was written.
    """
    directory_missing = not os.path.exists(checkpoint_path)
    if directory_missing:
        os.makedirs(checkpoint_path)

    target = os.path.join(checkpoint_path, "m.ckpt")
    step = dqn_agent.global_counter
    saver.save(sess, target, step)
    print("saved to %s - %d" % (target, step))
def one_episode(eps_history,dqn_scores,avg_score_all,render,load_checkpoint):
    """Play one episode, log its statistics and checkpoint when appropriate.

    Args:
        eps_history: list that receives the epsilon reported for the episode.
        dqn_scores: list that receives the episode score.
        avg_score_all: list that receives the running average score.
        render: forwarded to ``dqn_agent.play_episode``.
        load_checkpoint: forwarded to ``dqn_agent.play_episode``; when true,
            no checkpoint is written.

    Returns:
        The tuple ``(eps_history, dqn_scores, avg_score_all)``; these are the
        same list objects that were passed in, updated in place.
    """
    score, _reward, _frames, epsilon = dqn_agent.play_episode(render, load_checkpoint)
    eps_history.append(epsilon)
    dqn_scores.append(score)

    i = dqn_agent.episode_counter
    # Running average over a window of scores ending at the current episode.
    avg_score = np.mean(dqn_scores[max(0, i - 100):(i + 1)])
    avg_score_all.append(avg_score)

    # avg_score was just appended, so it equals the maximum exactly when it
    # is a new (or tied) best average.
    highscore = bool(avg_score >= max(avg_score_all))
    new_max = ' => New HighScore! <= ' if highscore else ''

    strm = ("#> episode: %i | score: %.2f | total steps: %i | epsilon: %.5f | average 100 score: %.2f" %
            (i, score, dqn_agent.global_counter, epsilon, avg_score))
    print(strm + new_max)

    # Append to the results log; the context manager closes the handle even
    # if the write raises.
    with open(opendir, "a") as results_file:
        results_file.write(strm + new_max + '\n')

    if not load_checkpoint:
        # Periodic save, only while the agent is actually training.
        save_cond = (
            dqn_agent.episode_counter % save_freq_episodes == 0
            and checkpoint_path is not None
            and dqn_agent.do_training
        )
        # Also save on a new best average once past the first 100 episodes.
        if save_cond or (highscore and dqn_agent.episode_counter > 100):
            save_checkpoint()

    return eps_history,dqn_scores,avg_score_all
def input_thread(list):
    """Wait for the user to press Enter, then flag a stop request.

    Args:
        list: shared list polled by the caller; ``"OK"`` is appended to it
            once a line has been read from standard input.
    """
    # The text typed is ignored; only the arrival of a line matters.
    input("...enter to stop after current episode\n")
    list.append("OK")
def main_loop(eps_history,dqn_scores,avg_score_all,render,load_checkpoint):
    """Run episodes until the user asks to stop or the episode budget is used.

    A background thread waits for Enter on stdin; once it fires, the loop
    ends before starting another episode. While the agent is training, the
    loop also ends when ``dqn_agent.episode_counter`` reaches
    ``train_episodes``.

    Args:
        eps_history: list of per-episode epsilon values, passed to
            ``one_episode``.
        dqn_scores: list of per-episode scores, passed to ``one_episode``.
        avg_score_all: list of running average scores, passed to
            ``one_episode``.
        render: forwarded to ``one_episode``.
        load_checkpoint: forwarded to ``one_episode``.

    Returns:
        The tuple ``(eps_history, dqn_scores, avg_score_all)`` as last
        returned by ``one_episode`` (or as passed in if no episode ran).
    """
    # Becomes non-empty when the input thread has seen Enter.
    stop_requested = []
    _thread.start_new_thread(input_thread, (stop_requested,))

    while not stop_requested:
        if dqn_agent.do_training and dqn_agent.episode_counter >= train_episodes:
            break
        eps_history, dqn_scores, avg_score_all = one_episode(
            eps_history, dqn_scores, avg_score_all, render, load_checkpoint
        )

    print("done")
    # Close the module-level results handle opened at start-up.
    text_results.close()
    # Return to the caller instead of terminating the interpreter here, so
    # the caller's post-training steps (final checkpoint save) actually run.
    return eps_history, dqn_scores, avg_score_all
# Decide between a training run and a play-only run, then start the loop.
training_run = (
    train_episodes > 0
    and dqn_agent.episode_counter < train_episodes
    and not load_checkpoint
)

if training_run:
    print("now training... you can early stop with enter...")
    print("##########")
else:
    print("now just playing...")

# Make sure the banner is visible before the loop starts producing output.
sys.stdout.flush()
main_loop(eps_history,dqn_scores,avg_score_all,render,load_checkpoint)

if training_run:
    # Persist the final state once training has finished.
    save_checkpoint()
    print("ok training done")