Humanoid Robots Wiki β

Allen's Reinforcement Learning Notes

Revision as of 20:07, 24 May 2024 by 108.211.178.220 (talk)

Allen's reinforcement learning notes

Links

Motivation

Consider a problem where we have to train a robot to pick up some object. A traditional ML algorithm might try to learn some function f(x) = y, where given some position x observed via the camera we output some behavior y. The trouble is that in the real world, the correct grab location is some function of the object and the physical environment, which is hard to intuitively ascertain by observation.

The motivation behind reinforcement learning is to repeatedly take observations, then sample the effects of actions on those observations (reward and new observation/state). Ultimately, we hope to create a policy $\pi$ that maps states or observations to actions.

Learning

Learning involves the agent taking actions and the environment returning a new state and reward.

  • Input: $s_t$: States at each time step
  • Output: $a_t$: Actions at each time step
  • Data: $(s_1, a_1, r_1, \ldots, s_T, a_T, r_T)$
  • Learn $\pi_\theta : s_t \to a_t$ to maximize $\sum_t r_t$
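
The interaction between agent and environment described above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: `ToyEnv`, its 1-D position state, the reward of 1 for reaching the goal, and the random policy are all hypothetical stand-ins, not something these notes define.

```python
import random


class ToyEnv:
    """Hypothetical 1-D world: the agent starts at 0 and wants to reach `goal`."""

    def __init__(self, goal=3):
        self.goal = goal
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        # action is -1 (move left) or +1 (move right)
        self.state += action
        done = self.state == self.goal
        reward = 1.0 if done else 0.0
        return self.state, reward, done


def random_policy(state, rng):
    # Placeholder policy: ignores the state and picks an action at random.
    return rng.choice([-1, 1])


def rollout(env, policy, rng, max_steps=20):
    """Run one episode and return a list of (state, action, reward) tuples."""
    trajectory = []
    state = env.reset()
    for _ in range(max_steps):
        action = policy(state, rng)
        next_state, reward, done = env.step(action)
        trajectory.append((state, action, reward))
        state = next_state
        if done:
            break
    return trajectory


rng = random.Random(0)
trajectory = rollout(ToyEnv(), random_policy, rng)
total_reward = sum(r for _, _, r in trajectory)
print(len(trajectory), total_reward)
```

A learning algorithm would replace `random_policy` with a parameterized policy and adjust its parameters so that `total_reward` grows over many rollouts.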

State vs. Observation

A state $s_t$ is a complete representation of the physical world, while the observation $o_t$ is some subset or representation of $s_t$. They are not necessarily the same: we can't always infer $s_t$ from $o_t$, but $o_t$ is inferable from $s_t$. Viewed as a chain of conditional probabilities, we have

  • $s_1 \to o_1 \xrightarrow{\pi_\theta(a_t \mid o_t)} a_1$ (policy)
  • $s_1, a_1 \xrightarrow{p(s_{t+1} \mid s_t, a_t)} s_2$ (dynamics)

Note that $\theta$ represents the parameters of the policy (for example, the parameters of a neural network). Assumption (Markov property): future states are independent of past states given the present state. This is the fundamental difference between states and observations: states satisfy the Markov property, while observations in general do not.

Problem Representation

States and actions are typically continuous, so we often model the policy as a probability density function over actions given the current state.
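
As one illustration of a policy expressed as a density, the sketch below uses a Gaussian over a 1-D continuous action. The Gaussian form, the linear mean `w * state + b`, and all parameter values are assumptions chosen for the example; the notes do not prescribe a particular distribution.

```python
import math
import random


def policy_mean(state, w=0.5, b=0.0):
    # Hypothetical parameterization: the mean action is linear in the state.
    return w * state + b


def policy_density(action, state, std=1.0):
    """Density pi(a | s) of a Gaussian policy evaluated at `action`."""
    mean = policy_mean(state)
    z = (action - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def sample_action(state, rng, std=1.0):
    """Draw an action a ~ pi(. | s)."""
    return rng.gauss(policy_mean(state), std)


rng = random.Random(0)
state = 2.0
action = sample_action(state, rng)
print(action, policy_density(action, state))
```

Here `w`, `b`, and `std` play the role of the policy parameters that training would adjust.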

The reward is a function of the state and action, $r(s, a) \to \mathbb{R}$, which tells us which states and actions are better. We often tune the hyperparameters of the reward function to make model training faster.

Markov Chain & Decision Process

Markov Chain: $M = \{S, T\}$, where $S$ is the state space and $T$ is the transition operator. The state space is the set of all states, and can be discrete or continuous. For a discrete state space, the transition probabilities are represented as a matrix whose $(i, j)$-th entry is the probability of moving into state $i$ from state $j$, that is, $T_{i,j} = p(s_{t+1} = i \mid s_t = j)$. If $\mu_t$ denotes the vector of state probabilities at time step $t$, we obtain the next time step by multiplying with the transition operator: $\mu_{t+1} = T \mu_t$.
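
The matrix form of a Markov chain step can be sketched as follows. The two-state chain and its probabilities are made up for illustration, and `mu` is a name chosen here for the vector of state probabilities.

```python
def step_distribution(T, mu):
    """Return T @ mu, where T[i][j] = p(next state = i | current state = j)."""
    n = len(mu)
    return [sum(T[i][j] * mu[j] for j in range(n)) for i in range(n)]


# Hypothetical 2-state chain. Each column sums to 1 because column j holds
# the distribution over next states when the current state is j.
T = [
    [0.9, 0.5],
    [0.1, 0.5],
]

mu = [1.0, 0.0]  # start in state 0 with certainty
for t in range(3):
    mu = step_distribution(T, mu)
    print(t + 1, mu)
```

With this index convention the distribution is a column vector multiplied on the left by `T`; the opposite convention (rows summing to 1) would multiply a row vector on the right.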

Markov Decision Process: $M = \{S, A, T, r\}$, where $A$ is the action space. $T$ is now a tensor indexed by the next state, the current state, and the current action. We let $T_{i,j,k} = p(s_{t+1} = i \mid s_t = j, a_t = k)$. $r$ is the reward function.
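
A small sketch of the transition tensor, using nested lists indexed as (next state, current state, action). The two-state, two-action MDP and its numbers are hypothetical.

```python
import random

# Hypothetical MDP with 2 states and 2 actions.
# T[i][j][k] = p(next state = i | current state = j, action = k)
T = [
    [[0.9, 0.2], [0.5, 0.0]],  # probabilities of landing in state 0
    [[0.1, 0.8], [0.5, 1.0]],  # probabilities of landing in state 1
]


def next_state_distribution(T, state, action):
    """Slice the tensor to get the distribution over next states."""
    return [T[i][state][action] for i in range(len(T))]


def sample_next_state(T, state, action, rng):
    """Sample a next state index from the sliced distribution."""
    probs = next_state_distribution(T, state, action)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
print(next_state_distribution(T, 0, 1))
print(sample_next_state(T, 1, 1, rng))
```

Fixing the action index `k` recovers an ordinary Markov chain transition matrix, which is one way to see that an MDP under a fixed policy reduces to a Markov chain.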

Reinforcement Learning Algorithms - High-level

  1. Generate Samples (run policy)
  2. Fit a model/estimate something about how well policy is performing
  3. Improve policy
  4. Repeat

Policy Gradients: Directly differentiate the objective with respect to the policy parameters theta, then perform gradient ascent on it (equivalently, gradient descent on the negated objective).

Value-based: Estimate value function or q-function of optimal policy (policy is often represented implicitly)

Actor-Critic: Estimate the value function or q-function of the current policy, and use it to improve the policy (for example, to obtain a better policy gradient estimate)

Model-based: Estimate some transition model, and then use it to improve a policy
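
To make the first of these families concrete, the policy gradient idea can be sketched on a deliberately tiny problem. Everything here is an assumption for illustration: there is a single state with three actions (a bandit), the policy is a softmax over parameters `theta`, and the gradient is computed exactly rather than estimated from samples. Because the objective is expected reward, which we want to maximize, the update steps along the gradient.

```python
import math


def softmax(theta):
    m = max(theta)
    exps = [math.exp(x - m) for x in theta]
    total = sum(exps)
    return [e / total for e in exps]


def expected_reward(theta, rewards):
    """Objective J(theta) = sum_a pi_theta(a) * r(a)."""
    return sum(p * r for p, r in zip(softmax(theta), rewards))


def objective_gradient(theta, rewards):
    """Exact gradient for a softmax policy: dJ/dtheta_b = pi(b) * (r(b) - J)."""
    probs = softmax(theta)
    j = sum(p * r for p, r in zip(probs, rewards))
    return [p * (r - j) for p, r in zip(probs, rewards)]


rewards = [1.0, 0.0, 0.5]  # hypothetical reward for each of three actions
theta = [0.0, 0.0, 0.0]    # start from a uniform policy
learning_rate = 0.5

before = expected_reward(theta, rewards)
for _ in range(200):
    grad = objective_gradient(theta, rewards)
    # Step in the direction that increases J.
    theta = [t + learning_rate * g for t, g in zip(theta, grad)]
after = expected_reward(theta, rewards)
print(before, after, softmax(theta))
```

In a real problem the expectation cannot be enumerated like this, so the gradient is estimated from sampled trajectories instead.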

REINFORCE

-

Temporal Difference Learning

Temporal Difference (TD) learning is a method for estimating the utility of states given some state-action-outcome information. Suppose we have some initial value estimate $V_0(s)$, and we observe a transition $(s, a, s', r(s, a))$. We can then use the update equation $V_{i+1}(s) = (1 - \alpha) V_i(s) + \alpha (r(s, a) + \gamma V_i(s'))$. Here $\alpha$ represents the learning rate, which is how much new information is weighted relative to old information, while $\gamma$ represents the discount factor, which can be thought of as how much a reward received in the future counts toward the current value.
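
The TD update can be written as a one-line function over a value table. The state names, the reward, and the values of `alpha` and `gamma` below are hypothetical.

```python
def td_update(V, s, r, s_next, alpha=0.5, gamma=0.9):
    """Apply V(s) <- (1 - alpha) * V(s) + alpha * (r + gamma * V(s_next)) in place."""
    V[s] = (1 - alpha) * V[s] + alpha * (r + gamma * V[s_next])
    return V[s]


# Hypothetical value table and one observed transition from "A" to "B".
V = {"A": 0.0, "B": 1.0}
new_value = td_update(V, s="A", r=1.0, s_next="B")
print(new_value)
```

With `alpha = 0.5` the new estimate lands halfway between the old value of "A" and the one-step target `r + gamma * V["B"]`.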

Q Learning

Q Learning gives us a way to extract the optimal policy after learning. Instead of keeping track of the values of individual states, we keep track of Q values for state-action pairs, representing the utility of taking action a at state s.

How do we use this Q value? Two main ideas.

Idea 1: Policy iteration - if we have a policy $\pi$ and we know $Q^\pi(s, a)$, we can improve the policy by deterministically setting the action at each state to the argmax of $Q^\pi(s, a)$ over all possible actions at that state.

$Q_{i+1}(s, a) = (1 - \alpha) Q_i(s, a) + \alpha (r(s, a) + \gamma V_i(s'))$
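
A tabular sketch of this update, followed by the greedy improvement step. One assumption is made explicit in the code: the value of the next state is taken to be the largest Q value available there, as is standard in Q learning. State names, action names, and the numbers are hypothetical.

```python
from collections import defaultdict


def q_update(Q, s, a, r, s_next, actions, alpha=0.5, gamma=0.5):
    """Tabular update: Q(s,a) <- (1 - alpha) Q(s,a) + alpha (r + gamma V(s'))."""
    # Assumption: V(s') is the value of the best action available in s'.
    v_next = max(Q[(s_next, a_next)] for a_next in actions)
    Q[(s, a)] = (1 - alpha) * Q[(s, a)] + alpha * (r + gamma * v_next)
    return Q[(s, a)]


def greedy_action(Q, s, actions):
    """Policy improvement: pick the argmax over actions of Q(s, a)."""
    return max(actions, key=lambda a: Q[(s, a)])


actions = ["left", "right"]
Q = defaultdict(float)       # unseen state-action pairs default to 0
Q[("s1", "right")] = 2.0     # hypothetical existing estimate

q_update(Q, s="s0", a="right", r=1.0, s_next="s1", actions=actions)
print(Q[("s0", "right")], greedy_action(Q, "s0", actions))
```

Calling `greedy_action` at every state yields the deterministic improved policy.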

Idea 2: Gradient update - if $Q^\pi(s, a) > V^\pi(s)$, then $a$ is better than average. We will then modify the policy to increase the probability of $a$.
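
The "better than average" comparison can be illustrated with a softmax policy over two actions in a single state. This is a simplified sketch, not a specific published algorithm: the Q values are assumed to be known, V is computed as the policy-weighted average of Q, and each logit is nudged by the difference Q - V.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def advantages(q_values, probs):
    """Return Q(s, a) - V(s) per action, with V(s) = sum_a pi(a | s) * Q(s, a)."""
    v = sum(p * q for p, q in zip(probs, q_values))
    return [q - v for q in q_values]


# Hypothetical single state with two actions.
logits = [0.0, 0.0]    # policy parameters for this state
q_values = [1.0, 3.0]  # assumed-known Q values for each action
step_size = 0.1

old_probs = softmax(logits)
adv = advantages(q_values, old_probs)
# Raise the logit of better-than-average actions, lower the others.
logits = [logit + step_size * a for logit, a in zip(logits, adv)]
new_probs = softmax(logits)
print(adv, old_probs, new_probs)
```

After the step, the action whose Q value exceeds V has a higher probability than before, and the other action a lower one.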