Allen's REINFORCE notes
Links
Motivation
Recall that the objective of Reinforcement Learning is to find an optimal policy, which we encode in a neural network with parameters <math>\theta^*</math>. The policy <math>\pi_\theta</math> is a mapping from observations to actions. These optimal parameters are defined as <math>\theta^* = \text{argmax}_\theta E_{\tau \sim p_\theta(\tau)} \left[ \sum_t r(s_t, a_t) \right]</math>. To phrase it in English, the optimal policy is the one under which the expected total reward along a trajectory (<math>\tau</math>) generated by following that policy is highest over all policies.
Overview
Initialize a policy network whose input dimension equals the observation dimension and whose output dimension equals the action dimension
For each episode:
    While not terminated:
        Get an observation from the environment
        Use the policy network to map the observation to an action distribution
        Randomly sample one action from that distribution
        Compute the log probability of the sampled action
        Step the environment using the action and store the reward
    Calculate the loss over the entire trajectory as a function of the stored log probabilities and rewards
    Recall that the loss is differentiable with respect to each parameter, so backpropagation gives how changes in the parameters change the loss
    Use a gradient descent step on this loss to update the weights (see the sketch below)
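To make this outline concrete, here is a minimal sketch of the loop in Python, assuming PyTorch and a Gymnasium environment; CartPole-v1, the network size, the learning rate, and the episode count are illustrative choices rather than anything specified in these notes.

<syntaxhighlight lang="python">
# Minimal REINFORCE training loop (sketch; assumes gymnasium and torch are installed).
import gymnasium as gym
import torch
import torch.nn as nn

env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]   # input dims = observation dims
act_dim = env.action_space.n               # output dims = action dims

# Policy network mapping observations to action logits.
policy = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, act_dim))
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

for episode in range(500):
    obs, _ = env.reset()
    log_probs, rewards = [], []
    terminated = truncated = False
    while not (terminated or truncated):
        # Map the observation to an action distribution and sample one action.
        logits = policy(torch.as_tensor(obs, dtype=torch.float32))
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))   # log-probability of the sampled action
        # Step the environment with the sampled action and store the reward.
        obs, reward, terminated, truncated, _ = env.step(action.item())
        rewards.append(float(reward))

    # Loss over the entire trajectory: -R(tau) * sum_t log pi(a_t | s_t).
    loss = -sum(rewards) * torch.stack(log_probs).sum()

    optimizer.zero_grad()
    loss.backward()    # gradient of the loss w.r.t. every policy parameter
    optimizer.step()   # gradient descent step on the weights
</syntaxhighlight>

In practice, per-timestep returns (reward-to-go) and a baseline are usually used to reduce variance, but the plain total-reward weighting above matches the derivation in these notes.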
Objective Function
The goal of reinforcement learning is to maximize the expected reward over the entire episode. We use <math>R(\tau)</math> to denote the total reward over some trajectory <math>\tau</math> defined by our policy. Thus we want to maximize <math>J(\theta) = E_{\tau \sim \pi_\theta}[R(\tau)]</math>. We can use the definition of expected value to expand this as <math>\sum_\tau P(\tau | \theta) R(\tau)</math>, where the probability of a given trajectory occurring can further be expressed as <math>P(\tau | \theta) = P(s_0) \prod^T_{t=0} \pi_\theta(a_t | s_t) P(s_{t + 1} | s_t, a_t)</math>.
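It is worth noting that taking the log of this product turns it into a sum, and only the policy terms in that sum depend on <math>\theta</math>:

<math>\log P(\tau | \theta) = \log P(s_0) + \sum^T_{t=0} \left[ \log \pi_\theta(a_t | s_t) + \log P(s_{t+1} | s_t, a_t) \right]</math>

This observation is used below once the gradient has been moved onto <math>\log P(\tau | \theta)</math>.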
Now we want to find the gradient of <math>J(\theta)</math>, namely <math>\nabla_\theta \sum_\tau P(\tau | \theta) R(\tau)</math>. Since the reward function does not depend on the parameters, we can move the gradient inside the sum and past <math>R(\tau)</math>: <math>\nabla_\theta J(\theta) = \sum_\tau R(\tau) \nabla_\theta P(\tau | \theta)</math>. The next step uses what's called the Log Derivative Trick.
Log Derivative Trick
Suppose we'd like to find <math>\nabla_{x_1}\log(f(x_1, x_2, x_3, \ldots))</math>. By the chain rule this is equal to <math>\frac{\nabla_{x_1}f(x_1, x_2, x_3, \ldots)}{f(x_1, x_2, x_3, \ldots)}</math>. Thus, by rearranging, we can write the gradient of any (positive) function with respect to some variable as <math>\nabla_{x_1}f(x_1, x_2, x_3, \ldots) = f(x_1, x_2, x_3, \ldots)\nabla_{x_1}\log(f(x_1, x_2, x_3, \ldots))</math>.
Thus, applying this to <math>P(\tau | \theta)</math>, we get <math>\nabla_\theta J(\theta) = \sum_\tau R(\tau) P(\tau | \theta) \nabla_\theta \log P(\tau | \theta) = E_{\tau \sim \pi_\theta}\left[ R(\tau) \nabla_\theta \log P(\tau | \theta) \right]</math>.
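Expanding <math>\log P(\tau | \theta)</math> as in the Objective Function section, the initial-state and transition terms do not depend on <math>\theta</math>, so their gradients vanish and we are left with the standard REINFORCE policy gradient:

<math>\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta}\left[ R(\tau) \sum^T_{t=0} \nabla_\theta \log \pi_\theta(a_t | s_t) \right]</math>

In practice this expectation is estimated from trajectories sampled with the current policy.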
Loss Function
The goal of REINFORCE is to maximize the expected cumulative reward. We do so by performing gradient descent on the negative of this objective, so that minimizing the loss maximizes the expected reward.
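A minimal sketch of this loss in Python, assuming PyTorch and the per-step log probabilities and rewards collected as in the Overview sketch (the function name is just illustrative):

<syntaxhighlight lang="python">
import torch

def reinforce_loss(log_probs, rewards):
    """Per-trajectory surrogate loss: -R(tau) * sum_t log pi(a_t | s_t).

    Minimizing this loss with gradient descent performs gradient ascent
    on the expected cumulative reward derived above.
    """
    total_reward = sum(rewards)                         # R(tau)
    return -total_reward * torch.stack(log_probs).sum()
</syntaxhighlight>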