Q-learning is a simple yet powerful off-policy, model-free reinforcement learning algorithm. It builds a table of state-action values (Q-values) that the agent uses to decide which action to take in each state.
Q-learning updates are done as shown above, i.e. Q(s, a) ← Q(s, a) + α [r + γ max_a' Q(s', a') - Q(s, a)], where α is the learning rate and γ the discount factor. Actions are chosen epsilon-greedily: with probability epsilon a random action is chosen, and otherwise the greedy action with the highest Q-value is picked. When the state space is very large, for example the raw screen images of Atari games, this Q-learning algorithm needs a boost. There are two issues here. First, the memory required to store and update the table grows with the number of states. Second, the time required to explore every state to fill in the Q-table becomes unrealistic. So, let's all hail Deep Q-learning!
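To make the update rule concrete, here is a minimal tabular sketch with NumPy. It assumes small, integer-indexed state and action spaces; the function name q_update and the default alpha and gamma values are illustrative, not taken from the original code.

```python
import numpy as np

def q_update(Q, s, a, r, s_next, done, alpha=0.1, gamma=0.99):
    """Apply one tabular Q-learning update in place and return the table.

    Q has shape (n_states, n_actions).
    """
    # TD target: reward plus the discounted best next-state value,
    # with no bootstrapping when the next state is terminal.
    target = r if done else r + gamma * np.max(Q[s_next])
    # Move Q(s, a) a fraction alpha of the way toward the target.
    Q[s, a] += alpha * (target - Q[s, a])
    return Q
```

For example, with Q = np.zeros((2, 2)), Q[1] = [1.0, 3.0], alpha = 0.5 and gamma = 0.9, the call q_update(Q, 0, 1, 1.0, 1, False, alpha=0.5, gamma=0.9) sets Q[0, 1] to 0.5 * (1 + 0.9 * 3) = 1.85.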
In Deep Q-learning, we use a neural network to approximate the Q-value function. The state is given as the input, and the network outputs the Q-values of all possible actions. The algorithm is as depicted below.
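A toy illustration of the "state in, all Q-values out" idea: a two-layer fully connected network in plain NumPy with random, untrained weights. The layer sizes and the 4-dimensional state are assumptions chosen for readability; the Pong network described later is convolutional and is trained with PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)
state_dim, hidden_dim, n_actions = 4, 16, 3      # illustrative sizes only

# Random, untrained weights for a two-layer fully connected Q-network.
W1 = rng.normal(scale=0.1, size=(state_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
W2 = rng.normal(scale=0.1, size=(hidden_dim, n_actions))
b2 = np.zeros(n_actions)

def q_network(states):
    # states: (batch, state_dim) -> Q-values: (batch, n_actions), one column per action
    hidden = np.maximum(0.0, states @ W1 + b1)   # ReLU hidden layer
    return hidden @ W2 + b2

q = q_network(np.ones((2, state_dim)))
greedy_actions = q.argmax(axis=1)                # best action for each state in the batch
```

One forward pass scores every action at once, which is why the greedy action is just an argmax over the output row.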
However, the target r + γ max_a' Q(s', a') is computed from the network's own estimates, so it keeps changing with every iteration. In supervised deep learning, the target labels do not change, which keeps training stable; this is just not true for RL.
Two changes tackle this issue: a Target Network and an Experience Replay Memory. Since the same network computes both the predicted value and the target value, every update also moves the target, and the two can chase each other and diverge. So, instead of using one neural network for learning, we can use two.
We can use a separate network to estimate the target. This target network has the same architecture as the function approximator but with frozen parameters. Every C iterations (C is a hyperparameter), the parameters of the prediction network are copied to the target network. This makes training more stable because it keeps the target function fixed for a while.
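The target-network bookkeeping can be sketched without any framework. Here the parameters are a dict of NumPy arrays (an assumption for illustration), and the class and method names are hypothetical; in PyTorch the copy step would usually be target_net.load_state_dict(online_net.state_dict()).

```python
import copy
import numpy as np

class OnlineAndTarget:
    """Hypothetical holder for online and target parameters (dicts of NumPy arrays)."""

    def __init__(self, params, replace_every):
        self.online = params                     # updated by gradient descent elsewhere
        self.target = copy.deepcopy(params)      # frozen copy, used only to compute targets
        self.replace_every = replace_every       # the hyperparameter C
        self.learn_steps = 0

    def after_learning_step(self):
        self.learn_steps += 1
        if self.learn_steps % self.replace_every == 0:
            # Copy the online weights into the target network.
            self.target = copy.deepcopy(self.online)

# Example: the target only changes on every 3rd learning step.
nets = OnlineAndTarget({"w": np.zeros(3)}, replace_every=3)
for step in range(1, 5):
    nets.online["w"] += 1.0                      # stand-in for a gradient update
    nets.after_learning_step()
# After 4 steps: online w == 4, target w == 3 (last copied at step 3).
```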
Also, instead of running Q-learning on state/action pairs as they occur during simulation or actual experience, the system stores each observed transition [state, action, reward, next_state] in a large table. This is called Experience Replay Memory.
During training, we can sample a random batch of 64 transitions from the last 20,000 frames to train our network. Random sampling gives a batch whose samples are only weakly correlated, and because each stored transition can be reused in many updates, it also improves sample efficiency.
Now, let me walk you through the implementation details, beginning with how to preprocess the OpenAI Gym Pong frames. Let's dive in.
Modify the OpenAI Gym Environment
First, we convert the images to grayscale and downscale them to 84x84 pixels. Some objects in the environment flicker, which we solve by taking the pixel-wise max of the two most recent frames. Each action is repeated 4 times; this is distinct from taking the max over the previous 2 frames and also distinct from stacking 4 frames. OpenAI Gym returns images with channels last, whereas PyTorch expects channels first, so we swap the axes of the numpy arrays. Then we stack the four most recent frames. Lastly, we normalize the frames.
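The preprocessing steps above can be sketched in NumPy. This is a simplified illustration, not the original wrapper code: env_step is a hypothetical callable following the older Gym step API (returning obs, reward, done, info), the grayscale weights are the common ITU-R BT.601 luma coefficients, the resize uses nearest-neighbour indexing where real implementations typically call an image library such as OpenCV, and dividing by 255 is assumed as the normalization. With a single grayscale channel, moving channels first amounts to adding a leading axis; for RGB frames one would use np.moveaxis(frame, -1, 0).

```python
from collections import deque
import numpy as np

def repeat_action(env_step, action, repeat=4):
    # Frame skipping: apply the same action `repeat` times and sum the rewards.
    total_reward, done, info = 0.0, False, {}
    last_two = deque(maxlen=2)
    for _ in range(repeat):
        obs, reward, done, info = env_step(action)
        total_reward += reward
        last_two.append(obs)
        if done:
            break
    # Pixel-wise max over the two most recent frames removes sprite flicker.
    frame = np.maximum(last_two[0], last_two[-1])
    return frame, total_reward, done, info

def preprocess(frame, size=84):
    # frame: (H, W, 3) uint8, channels last, as returned by Gym.
    luma = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = frame.astype(np.float32) @ luma           # grayscale, shape (H, W)
    h, w = gray.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    small = gray[rows[:, None], cols[None, :]]       # nearest-neighbour resize to (size, size)
    return small[np.newaxis, :, :] / 255.0           # channels first (1, size, size), scaled by 1/255

class FrameStack:
    """Keeps the k most recent preprocessed frames, stacked along the channel axis."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        # Fill the stack with copies of the first frame of the episode.
        for _ in range(self.k):
            self.frames.append(frame)
        return self.observation()

    def step(self, frame):
        self.frames.append(frame)                    # the oldest frame drops out automatically
        return self.observation()

    def observation(self):
        return np.concatenate(list(self.frames), axis=0)  # (k, size, size)
```

A stacked observation of shape (4, 84, 84) is what the convolutional network below consumes.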
Now, we need to keep track of the states the agent saw, the actions it took, the rewards it received and the states that resulted from taking those actions. This memory class hands a uniformly random batch of transitions to our Deep Q Network for each learning update. Random sampling reduces the correlation between updates that would arise within an episode if no memory buffer were used. States, actions, rewards, next states and terminal-state flags (dones) are stored in separate numpy arrays. In a nutshell, store_transition and sample_buffer are the two functions of this class.
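A minimal NumPy sketch of such a memory class, keeping the store_transition and sample_buffer names from the text. The constructor arguments, the dtypes and the circular overwrite of the oldest entries once the buffer is full are assumptions for illustration, not necessarily how the original repository does it.

```python
import numpy as np

class ReplayBuffer:
    def __init__(self, max_size, input_shape):
        self.mem_size = max_size
        self.mem_cntr = 0                            # total transitions ever stored
        self.state_memory = np.zeros((max_size, *input_shape), dtype=np.float32)
        self.new_state_memory = np.zeros((max_size, *input_shape), dtype=np.float32)
        self.action_memory = np.zeros(max_size, dtype=np.int64)
        self.reward_memory = np.zeros(max_size, dtype=np.float32)
        self.terminal_memory = np.zeros(max_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        index = self.mem_cntr % self.mem_size        # overwrite the oldest entry once full
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.terminal_memory[index] = done
        self.mem_cntr += 1

    def sample_buffer(self, batch_size, rng=None):
        if rng is None:
            rng = np.random.default_rng()
        max_mem = min(self.mem_cntr, self.mem_size)  # only the filled part of the buffer
        # Uniform sampling without replacement.
        batch = rng.choice(max_mem, batch_size, replace=False)
        return (self.state_memory[batch], self.action_memory[batch],
                self.reward_memory[batch], self.new_state_memory[batch],
                self.terminal_memory[batch])
```

With the numbers mentioned earlier, the buffer would be created with max_size=20000 and sampled with batch_size=64.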
Deep Q Network
As described in the paper, the Deep Q Network for this project has 3 convolutional layers and 2 fully connected layers. The first convolutional layer has 32 output channels, an 8x8 filter and a stride of 4. The second has 64 output channels, a 4x4 filter and a stride of 2. The third has 64 output channels, a 3x3 filter and a stride of 1. The first fully connected layer maps its input to 512 neurons, and the last fully connected layer maps those 512 neurons to a number of outputs equal to the number of actions. The input size of the first fully connected layer is computed inside the class so that it is set automatically: a tensor of zeros of shape (1, *input_shape) is forwarded through the convolutional layers, and the product of the output's dimensions is used as the input size of the fully connected layer. The RMSProp optimizer and MSELoss are set up within the class, and the Deep Q Network is sent to the GPU if one is available. Model checkpointing is also supported.
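The automatically computed input size can also be checked by hand. Assuming no padding (the default padding=0 of torch.nn.Conv2d) and an input of 4 stacked 84x84 frames, each convolution shrinks a spatial side to (size - kernel) // stride + 1, so 84 becomes 20, then 9, then 7, and the flattened size is 64 × 7 × 7 = 3136. The helper names below are hypothetical.

```python
def conv_out(size, kernel, stride):
    # Output length of an unpadded convolution along one spatial axis.
    return (size - kernel) // stride + 1

def fc_input_size(input_shape=(4, 84, 84),
                  layers=((32, 8, 4), (64, 4, 2), (64, 3, 1))):
    # layers: (out_channels, kernel, stride) for each convolutional layer
    channels, h, w = input_shape
    for out_channels, kernel, stride in layers:
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
        channels = out_channels
    return channels * h * w
```

The zeros-forward trick in the class arrives at the same number without hard-coding it, which keeps the network correct if the input shape or layer settings change.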
The main innovation of Deep Q-learning as described in the paper is the use of an online network that is updated with gradient descent, together with a target network that computes the target values during the learning update. The target network is updated only periodically with the weights of the online network. We also have a replay memory buffer that we sample the agent's history from to train the online network. The DQN agent also needs functions for epsilon-greedy action selection, copying the weights of the online network to the target network, decrementing epsilon over time and storing new memories. We also need to save the learned online model when appropriate. An eps_min value of 0.01 is used for evaluating the model performance; in the paper, this value is set to 0.1. The replace interval of 1000 is about 10 times smaller than the one used in the paper, because we will train for only a few hours rather than days, and on only about 250 games. This is enough to get good performance out of our agent.
Epsilon-greedy action selection is implemented in the choose_action function. If a randomly generated number is greater than epsilon, the action with the maximum state-action value according to the online model (q_eval in the code) is picked. Otherwise, a random action is picked.
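A minimal sketch of that rule, assuming the online network's Q-values for the current state are already available as a NumPy array (in the PyTorch agent they would come from q_eval). Passing a NumPy Generator as rng is an addition for illustration that makes the randomness reproducible.

```python
import numpy as np

def choose_action(q_values, epsilon, rng):
    # q_values: 1-D array of Q(s, a) for every action, as produced by the online network.
    if rng.random() > epsilon:
        return int(np.argmax(q_values))          # exploit: greedy action
    return int(rng.integers(len(q_values)))      # explore: uniformly random action
```

With epsilon = 0 the function always returns the greedy action, and with epsilon = 1 it always explores.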
The DQNAgent's store_transition function calls the memory's store_transition function to store the state, action, reward, next state and terminal-state flag. Next, we implement the sample_memory function, which calls the memory's sample_buffer function, converts the returned numpy arrays to PyTorch tensors and sends them to the GPU. The replace_target_network function copies the online network's weights to the target network every replace_target_cnt learning updates of the online network. Epsilon is decremented by eps_dec until it reaches eps_min, after which it stays at eps_min. Saving and loading the models is straightforward.
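The epsilon schedule described above is a linear decay clipped at eps_min. A small sketch follows; the class name and the default values are illustrative assumptions, not the original hyperparameters.

```python
class EpsilonSchedule:
    def __init__(self, epsilon=1.0, eps_min=0.01, eps_dec=1e-5):
        self.epsilon = epsilon
        self.eps_min = eps_min
        self.eps_dec = eps_dec

    def decrement_epsilon(self):
        # Linear decay that never goes below eps_min.
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)
        return self.epsilon
```

With the default eps_dec = 1e-5, going from 1.0 down to 0.01 takes (1.0 - 0.01) / 1e-5 = 99,000 decrements, so eps_dec effectively sets how long the agent keeps exploring heavily.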
Now comes the learning function. If the memory does not yet hold at least batch_size transitions, no learning is done. Otherwise, the sampled states are passed through the online network and the next states through the target network. From the online network's outputs, only the Q-values of the actions actually taken are selected. The max Q-value over actions is taken from the target network's outputs; if the next state is terminal, this max Q-value is set to zero. It is then discounted by gamma and added to the rewards to get the Q targets. The MSE loss between the selected Q-values and the Q targets is computed, and a backward call is made to get the gradients for the online network, which the optimizer then uses to update its weights.
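The target and loss computation for one batch can be sketched in NumPy, with the two networks' outputs passed in as arrays (q_eval_out from the online network, q_next_out from the target network). The function name is hypothetical; in the PyTorch version the same indexing works on tensors, and the targets must not be backpropagated through.

```python
import numpy as np

def dqn_loss(q_eval_out, q_next_out, actions, rewards, dones, gamma=0.99):
    """Hypothetical helper mirroring the learning step for one sampled batch.

    q_eval_out: (batch, n_actions) online-network outputs for the states
    q_next_out: (batch, n_actions) target-network outputs for the next states
    """
    batch_index = np.arange(len(actions))
    q_pred = q_eval_out[batch_index, actions]    # Q(s, a) for the actions actually taken
    q_next = q_next_out.max(axis=1)              # max over a' of Q_target(s', a')
    q_next = np.where(dones, 0.0, q_next)        # terminal next state: no bootstrapping
    q_target = rewards + gamma * q_next
    loss = np.mean((q_target - q_pred) ** 2)     # mean squared error, like nn.MSELoss
    return q_pred, q_target, loss
```

For a terminal transition the target reduces to the reward alone, which is exactly the zeroing step described in the paragraph above.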
Putting it all together
About 250 games (episodes) are played to get good enough performance. For each episode, the environment is reset, yielding the initial observation. The agent chooses an action with the epsilon-greedy policy from the online network, given that observation. The environment's step function is called, which returns the next state, reward, done flag and info. Then, the state, action, reward, next state and done flag are stored in the memory buffer, and the agent learns if the memory counter is larger than the batch size. The next state becomes the new observation and the loop continues. Whenever the score exceeds the previous best score, the models are saved in the checkpoint folder.
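The loop above can be sketched as follows. ToyEnv and RandomAgent are hypothetical stand-ins so that the sketch runs on its own: the agent exposes the method names used in the text (choose_action, store_transition, learn) plus a save_models method whose name is an assumption, and the environment follows the older Gym reset/step API that returns a 4-tuple.

```python
import random

class ToyEnv:
    """Hypothetical stand-in for the wrapped Pong environment."""
    def __init__(self, episode_len=5):
        self.episode_len = episode_len
        self.t = 0

    def reset(self):
        self.t = 0
        return 0

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else -1.0
        done = self.t >= self.episode_len
        return self.t, reward, done, {}

class RandomAgent:
    """Hypothetical agent with the same method names as the DQNAgent described above."""
    def __init__(self, n_actions=2, seed=0):
        self.rng = random.Random(seed)
        self.n_actions = n_actions
        self.memory = []
        self.saves = 0

    def choose_action(self, observation):
        return self.rng.randrange(self.n_actions)

    def store_transition(self, state, action, reward, state_, done):
        self.memory.append((state, action, reward, state_, done))

    def learn(self):
        # A real agent returns early until the buffer holds batch_size transitions,
        # then samples a batch and takes one gradient step.
        pass

    def save_models(self):
        self.saves += 1

def train(env, agent, n_games=250):
    best_score = float("-inf")
    scores = []
    for _ in range(n_games):
        observation = env.reset()
        done = False
        score = 0.0
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            score += reward
            agent.store_transition(observation, action, reward, observation_, done)
            agent.learn()
            observation = observation_           # the next state becomes the current one
        scores.append(score)
        if score > best_score:                   # checkpoint whenever a new best score is reached
            best_score = score
            agent.save_models()
    return scores
```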
As the number of episodes increases, the plot shows the score rising, which indicates that the agent has learned over time to beat the computer.
Thank you for your interest in this blog. Please leave comments, feedback and suggestions if you have any.
Full code on my GitHub repo here.