
Humanoid Robots Wiki β

Dennis' Optimization Notes

Revision as of 04:57, 25 May 2024 by Dennisc (Initial notes on optimization)

Notes on various riffs on Gradient Descent from the perspective of neural networks.


A review of standard Gradient Descent

The goal of Gradient Descent is to minimize a loss function $L$. To be more specific, if $L$ is a differentiable multivariate function, we want to find the vector $w$ that minimizes $L(w)$.

Given an initial vector $w_0$, we want to “move” in the direction $\Delta w$ for which the decrease $L(w_0) - L(w_0 + \Delta w)$ is maximized (suppose the magnitude of $\Delta w$ is fixed and small). To first order, $L(w_0 + \Delta w) \approx L(w_0) + \nabla L(w_0) \cdot \Delta w$, and by the Cauchy-Schwarz inequality $\nabla L(w_0) \cdot \Delta w \ge -\lVert \nabla L(w_0) \rVert \, \lVert \Delta w \rVert$, with equality precisely when $\Delta w$ is in the direction of $-\nabla L(w_0)$.

So given some $w_n$, we want to update in the direction of $-\nabla L(w_n)$. This motivates setting $w_{n+1} = w_n - \alpha \nabla L(w_n)$, where $\alpha > 0$ is a scalar factor. We call $\alpha$ the “learning rate” because it affects how fast the sequence $w_n$ converges to the optimum. A central difficulty in machine learning is tuning $\alpha$ to what “works best” for convergence: too large a value can make the iterates overshoot or even diverge, while too small a value makes progress slow. Addressing this is one of the considerations behind the remaining algorithms.
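As a minimal sketch of the update rule above, here it is in Python on a hypothetical two-variable quadratic loss (the loss, starting point, learning rates, and step counts are illustrative assumptions, not from these notes). The second call uses a learning rate that is too large, to show the iterates blowing up instead of converging.

```python
from typing import Callable, List

Vector = List[float]


def gradient_descent(grad: Callable[[Vector], Vector], w0: Vector,
                     alpha: float, steps: int) -> Vector:
    """Apply w_{n+1} = w_n - alpha * grad(w_n) for a fixed number of steps."""
    w = list(w0)
    for _ in range(steps):
        g = grad(w)
        w = [wi - alpha * gi for wi, gi in zip(w, g)]
    return w


def grad_example(w: Vector) -> Vector:
    """Gradient of the hypothetical loss L(w) = (w0 - 3)^2 + 2 (w1 + 1)^2, minimized at (3, -1)."""
    return [2.0 * (w[0] - 3.0), 4.0 * (w[1] + 1.0)]


if __name__ == "__main__":
    print(gradient_descent(grad_example, [0.0, 0.0], alpha=0.1, steps=200))  # very close to [3.0, -1.0]
    print(gradient_descent(grad_example, [0.0, 0.0], alpha=0.6, steps=20))   # second coordinate diverges
```

With $\alpha = 0.1$ the error in each coordinate is multiplied by a fixed factor (0.8 and 0.6) every step. With $\alpha = 0.6$ the second coordinate's error is multiplied by $1 - 4 \cdot 0.6 = -1.4$ each step, so it oscillates with growing magnitude.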

Stochastic Gradient Descent

In practice we don’t actually know the “true gradient”. So instead we take some data, say datasets $1$ through $n$, and for dataset $i$ we compute the gradient $\nabla L_i$ of the loss on that dataset alone. Then we may estimate $\nabla L$ as the average

$$\nabla L \approx \frac{1}{n} \sum_{i=1}^{n} \nabla L_i.$$

If $\nabla L_i(w)$ is cheap to compute for every $i$, then we are golden: this average is the best estimate of $\nabla L$ we can get. But what if the $\nabla L_i$ are expensive to compute (for example, because $n$ is large)? Then there is a tradeoff between variance and computational cost when evaluating our estimate of $\nabla L$.
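To make the variance side of this tradeoff concrete, the sketch below computes, exactly, the variance of the average of $k$ gradients chosen at random without replacement, by enumerating every size-$k$ subset. The per-dataset gradient values are hypothetical (the numbers $-1, \ldots, -10$, as if each $\nabla L_i$ had been evaluated at one fixed $w$), and one estimate costs $k$ gradient evaluations.

```python
from itertools import combinations
from statistics import pvariance
from typing import Sequence


def minibatch_variance(grads_at_w: Sequence[float], k: int) -> float:
    """Exact variance of the mean of k values drawn without replacement from grads_at_w.

    Every size-k subset is equally likely, so we simply enumerate all of them.
    """
    subset_means = [sum(subset) / k for subset in combinations(grads_at_w, k)]
    return pvariance(subset_means)


if __name__ == "__main__":
    # Hypothetical per-dataset gradients nabla L_i(w) at a single fixed w.
    grads_at_w = [-float(i) for i in range(1, 11)]
    for k in (1, 2, 5, 10):
        var = minibatch_variance(grads_at_w, k)
        print(f"k={k:2d}  cost={k:2d} gradient evals  variance={var:.4f}")
```

The variance falls from 8.25 at $k = 1$ to 0 at $k = n = 10$, where the estimate is the exact average; this matches the standard formula $\frac{\sigma^2}{k} \cdot \frac{n-k}{n-1}$ for sampling without replacement.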

A very low-cost (but low-accuracy) way to estimate $\nabla L$ is to use just $\nabla L_1$ (or any other single $\nabla L_i$). But this is obviously problematic: we aren’t even using most of our data! A better balance can be struck as follows: at each step, to evaluate $\nabla L(w)$ at the current point, select $k$ of the functions $\nabla L_1, \ldots, \nabla L_n$ at random (a “minibatch”), and estimate $\nabla L$ as the average of those $k$ functions only. A fresh random selection is drawn at every step, so over many steps all of the data gets used.
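A minimal sketch of this minibatch scheme in Python, assuming a hypothetical scalar setup where dataset $i$ contributes the loss $L_i(w) = \tfrac12 (w - x_i)^2$, so that $\nabla L_i(w) = w - x_i$ and the full loss is minimized at the mean of the $x_i$. The function names and data are illustrative, not from any library.

```python
import random
from typing import Callable, Sequence

Grad = Callable[[float], float]


def minibatch_gradient(grads: Sequence[Grad], w: float, k: int,
                       rng: random.Random) -> float:
    """Estimate grad L(w) as the average of k per-dataset gradients picked at random."""
    chosen = rng.sample(range(len(grads)), k)  # k distinct indices, redrawn on every call
    return sum(grads[i](w) for i in chosen) / k


def sgd(grads: Sequence[Grad], w0: float, alpha: float, k: int,
        steps: int, seed: int = 0) -> float:
    """Stochastic gradient descent using a fresh minibatch of size k at each step."""
    rng = random.Random(seed)
    w = w0
    for _ in range(steps):
        w -= alpha * minibatch_gradient(grads, w, k, rng)
    return w


if __name__ == "__main__":
    data = [float(x) for x in range(1, 11)]       # hypothetical data, mean 5.5
    grads = [lambda w, x=x: w - x for x in data]  # grad of L_i(w) = 0.5 * (w - x_i)^2
    print(sgd(grads, w0=0.0, alpha=0.1, k=4, steps=500))  # hovers near 5.5
```

Because the minibatch changes every step, $w$ does not settle exactly at 5.5 with a constant $\alpha$; it keeps fluctuating around it, and the fluctuations shrink as $k$ grows or $\alpha$ shrinks.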

Riffs on stochastic gradient descent

Momentum

See also “Momentum” on Distill.

In typical stochastic gradient descent, the next step we take is based solely on the gradient at the current point, completely ignoring past gradients. However, it often makes sense to take the past gradients into account. Of course, if we are at $w_{100}$, then $\nabla L(w_{99})$ is far more relevant than $\nabla L(w_1)$, so recent gradients should be weighted much more heavily than old ones.

The simplest way to do this is to weight the gradients geometrically: the gradient from $j$ steps ago gets weight $\beta^j$ for some constant $0 \le \beta < 1$. So when we iterate, instead of taking $w_{n+1}$ to satisfy

$$w_{n+1} - w_n = -\alpha \nabla L(w_n)$$

like in standard gradient descent, we instead want to take $w_{n+1}$ to satisfy

$$w_{n+1} - w_n = -\alpha \nabla L(w_n) - \beta \alpha \nabla L(w_{n-1}) - \cdots - \beta^{n} \alpha \nabla L(w_0).$$

But this raises a concern: are we really going to store all of these terms, especially as $n$ grows? Fortunately, we do not need to. Applying the same formula one step earlier gives $w_n - w_{n-1} = -\alpha \nabla L(w_{n-1}) - \cdots - \beta^{n-1} \alpha \nabla L(w_0)$, so we may notice that

$$w_{n+1} - w_n = -\alpha \nabla L(w_n) - \beta \left( \alpha \nabla L(w_{n-1}) + \beta \alpha \nabla L(w_{n-2}) + \cdots + \beta^{n-1} \alpha \nabla L(w_0) \right) = -\alpha \nabla L(w_n) + \beta (w_n - w_{n-1}).$$

To put it another way, if we write $\Delta w_n = w_n - w_{n-1}$ for how much $w_n$ differs from $w_{n-1}$ (with $\Delta w_0 = 0$, so that the first step is a plain gradient step), we may rewrite this equation as

$$\Delta w_{n+1} = -\alpha \nabla L(w_n) + \beta \Delta w_n.$$
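The recursion can be checked numerically. The sketch below uses a hypothetical one-dimensional loss $L(w) = (w - 3)^2$ (the loss, names, and constants are illustrative). It runs momentum through the cheap update $\Delta w_{n+1} = -\alpha \nabla L(w_n) + \beta \Delta w_n$, keeping only the previous step, and compares each step against the explicit geometric sum over all past gradients.

```python
from typing import Callable, List


def momentum_path(grad: Callable[[float], float], w0: float, alpha: float,
                  beta: float, steps: int) -> List[float]:
    """Iterates w_0, ..., w_steps of momentum, using the recursion with Delta w_0 = 0."""
    ws = [w0]
    delta = 0.0  # Delta w_n: the previous step, the only extra state we keep
    for _ in range(steps):
        delta = -alpha * grad(ws[-1]) + beta * delta
        ws.append(ws[-1] + delta)
    return ws


def explicit_step(grad: Callable[[float], float], ws: List[float],
                  alpha: float, beta: float, n: int) -> float:
    """w_{n+1} - w_n from the full sum -sum_{j=0}^{n} beta^(n-j) * alpha * grad(w_j)."""
    return -sum(beta ** (n - j) * alpha * grad(ws[j]) for j in range(n + 1))


def grad_example(w: float) -> float:
    """Gradient of the hypothetical loss L(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)


if __name__ == "__main__":
    ws = momentum_path(grad_example, w0=0.0, alpha=0.05, beta=0.9, steps=50)
    worst = max(abs(explicit_step(grad_example, ws, 0.05, 0.9, n) - (ws[n + 1] - ws[n]))
                for n in range(50))
    print(f"largest mismatch between recursion and explicit sum: {worst:.2e}")
```

The two agree up to floating-point rounding, while the recursion stores one number instead of the whole gradient history.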

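Some of the benefits of using a momentum-based approach:

  • most importantly, it can dramatically speed up convergence to a local minimum, especially in long, narrow valleys where plain gradient descent zigzags across the valley while making slow progress along it.
  • gradient components that keep flipping sign partly cancel in the running sum, so it damps such oscillations, which tends to make convergence more reliable in general.
  • it may help with escaping local minima, saddle points, and plateaus (though its importance here is possibly contested; see this reddit thread).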
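A rough illustration of the first benefit, under assumed settings: on a hypothetical ill-conditioned quadratic $L(w) = \tfrac12 (w_1^2 + 100 w_2^2)$, the sketch below counts how many steps it takes to get within $10^{-6}$ of the minimum, with and without momentum, at the same learning rate. The loss, starting point, and constants are chosen for illustration only, and the size of the gap depends heavily on them.

```python
import math
from typing import List

# Hypothetical ill-conditioned loss L(w) = 0.5 * (1 * w0^2 + 100 * w1^2).
CURVATURES = [1.0, 100.0]


def grad(w: List[float]) -> List[float]:
    """Gradient of the hypothetical quadratic loss."""
    return [c * wi for c, wi in zip(CURVATURES, w)]


def steps_to_converge(alpha: float, beta: float, tol: float = 1e-6,
                      max_steps: int = 100_000) -> int:
    """Count momentum steps until ||w|| < tol; beta = 0 gives plain gradient descent."""
    w = [1.0, 1.0]
    delta = [0.0, 0.0]  # Delta w_0 = 0
    for n in range(max_steps):
        if math.sqrt(sum(wi * wi for wi in w)) < tol:
            return n
        g = grad(w)
        delta = [beta * d - alpha * gi for d, gi in zip(delta, g)]
        w = [wi + d for wi, d in zip(w, delta)]
    return max_steps


if __name__ == "__main__":
    print("plain gradient descent:", steps_to_converge(alpha=0.01, beta=0.0))
    print("momentum (beta = 0.9): ", steps_to_converge(alpha=0.01, beta=0.9))
```

Plain gradient descent is held back by the shallow direction here, and $\alpha$ cannot simply be raised, since any $\alpha > 0.02$ already makes the steep direction diverge. Momentum keeps accumulating speed along the shallow direction instead.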

RMSProp

Gradient descent also often suffers from diminishing step sizes: when the gradients become small (for example on plateaus), so do the steps $\alpha \nabla L(w_n)$. In order to counter this, we very broadly want to

  • track the recent magnitudes of the gradient, and
  • if they have been small, scale $\alpha$ up to increase the effective learning rate.
  • (As a side effect, if recent gradients have been large, we will temper the effective learning rate.)

While performing our gradient descent step $w_n \to w_{n+1}$, we create and store an auxiliary parameter $v_{n+1}$ as follows:

$$v_{n+1} = \beta v_n + (1 - \beta) \left( \nabla L(w_n) \right)^2$$

and define

$$w_{n+1} = w_n - \frac{\alpha}{\sqrt{v_{n+1}} + \epsilon} \nabla L(w_n),$$

where $\alpha$ as usual is the learning rate, $\beta$ is the decay rate of $v_n$, and $\epsilon$ is a small positive constant that also needs to be fine-tuned. The square, square root, and division are all taken elementwise, so each coordinate of $w$ effectively gets its own learning rate.

We include the constant term $\epsilon$ to ensure numerical stability (it keeps us from dividing by zero when $v_{n+1} = 0$) and to help the sequence $w_n$ actually converge. If we are near the minimum, then the gradients, and hence $v_{n+1}$, will be quite small, meaning the denominator $\sqrt{v_{n+1}} + \epsilon$ will essentially just become $\epsilon$. The update then looks like standard gradient descent with the constant learning rate $\alpha / \epsilon$. Because $w$ converges when $\nabla L(w)$ is just multiplied by a (suitably small) constant (this is the underlying assumption of standard gradient descent, after all), we will achieve convergence when near a minimum.
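A minimal RMSProp sketch following the update above, with elementwise operations on plain Python lists. The loss (a hypothetical quadratic whose two coordinates have curvatures 1 and 100), the hyperparameter defaults, and the initialization $v_0 = 0$ are assumptions for illustration.

```python
import math
from typing import Callable, List

Vector = List[float]


def rmsprop(grad: Callable[[Vector], Vector], w0: Vector, alpha: float = 0.01,
            beta: float = 0.9, eps: float = 1e-8, steps: int = 1000) -> Vector:
    """RMSProp: v <- beta*v + (1-beta)*g^2, then w <- w - alpha*g/(sqrt(v)+eps), elementwise."""
    w = list(w0)
    v = [0.0] * len(w)  # v_0 = 0 (assumed initialization)
    for _ in range(steps):
        g = grad(w)  # gradient at w_n
        v = [beta * vi + (1.0 - beta) * gi * gi for vi, gi in zip(v, g)]  # v_{n+1}
        w = [wi - alpha * gi / (math.sqrt(vi) + eps) for wi, gi, vi in zip(w, g, v)]
    return w


def grad_example(w: Vector) -> Vector:
    """Gradient of the hypothetical loss L(w) = 0.5 * (w0^2 + 100 * w1^2)."""
    return [w[0], 100.0 * w[1]]


if __name__ == "__main__":
    print(rmsprop(grad_example, [1.0, 1.0]))  # both coordinates end up near 0
```

Because both $\nabla L$ and $\sqrt{v_{n+1}}$ scale with the gradient, multiplying one coordinate's gradient by a constant leaves that coordinate's steps essentially unchanged (up to $\epsilon$), so the steep coordinate here does not force a tiny $\alpha$ the way it would in plain gradient descent. Since $v_{n+1} \ge (1-\beta) \left( \nabla L(w_n) \right)^2$, each step moves a coordinate by at most $\alpha / \sqrt{1 - \beta}$.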

Side note: to combine RMSProp with stochastic gradient descent, we feed the stochastic estimate of the gradient, built from the approximated loss functions $L_i$, into the update for $v_n$ in place of the true gradient $\nabla L(w_n)$.

Adam

Adam (Adaptive Moment Estimation) is a gradient descent modification that combines Momentum and RMSProp. We create two auxiliary variables while iterating $w_n$ (where $\alpha$ is the learning rate, $\beta_1$ and $\beta_2$ are decay parameters that need to be fine-tuned, and $\epsilon$ is a parameter serving the same purpose as in RMSProp):

$$m_{n+1} = \beta_1 m_n + (1 - \beta_1) \nabla L(w_n)$$

$$v_{n+1} = \beta_2 v_n + (1 - \beta_2) \left( \nabla L(w_n) \right)^2.$$

Since $m_0 = v_0 = 0$, the early values of $m_n$ and $v_n$ are biased toward zero. To correct for this bias, we define

$$\widehat{m}_n = \frac{m_n}{1 - \beta_1^n}$$

$$\widehat{v}_n = \frac{v_n}{1 - \beta_2^n}.$$
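To see what this correction does, unroll the recursion for $m_n$, assuming the usual starting point $m_0 = 0$ and, purely for illustration, a constant gradient $\nabla L(w_j) = g$ for every $j$:

$$
\begin{aligned}
m_n &= (1-\beta_1) \sum_{j=0}^{n-1} \beta_1^{\,n-1-j} \nabla L(w_j)
     = (1-\beta_1)\, g \sum_{k=0}^{n-1} \beta_1^{\,k}
     = (1-\beta_1^{\,n})\, g, \\
\widehat{m}_n &= \frac{m_n}{1-\beta_1^{\,n}} = g.
\end{aligned}
$$

So without the correction, $m_n$ would underestimate $g$ by the factor $1 - \beta_1^n$, which is small for early $n$ (for instance $1 - \beta_1 = 0.1$ at $n = 1$ when $\beta_1 = 0.9$). The same computation with $v_0 = 0$ gives $\widehat{v}_n = g^2$.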

Then our update rule to get $w_{n+1}$ is

$$w_{n+1} = w_n - \alpha \frac{\widehat{m}_{n+1}}{\sqrt{\widehat{v}_{n+1}} + \epsilon}.$$

It is worth noting that though this formula does not explicitly include $\nabla L(w_n)$, it is accounted for in the $\widehat{m}_{n+1}$ term through $m_{n+1}$.
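A minimal Adam sketch in the same style as an RMSProp implementation, following the equations above with $m_0 = v_0 = 0$. The hypothetical quadratic loss and starting point are assumptions for illustration; $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$ are the defaults suggested in the original Adam paper, while $\alpha = 0.1$ is simply chosen for this toy problem.

```python
import math
from typing import Callable, List

Vector = List[float]


def adam(grad: Callable[[Vector], Vector], w0: Vector, alpha: float = 0.1,
         beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8,
         steps: int = 1000) -> Vector:
    """Run Adam with bias-corrected first and second moment estimates."""
    w = list(w0)
    m = [0.0] * len(w)  # m_0 = 0
    v = [0.0] * len(w)  # v_0 = 0
    for n in range(steps):
        g = grad(w)  # gradient at w_n
        m = [beta1 * mi + (1.0 - beta1) * gi for mi, gi in zip(m, g)]       # m_{n+1}
        v = [beta2 * vi + (1.0 - beta2) * gi * gi for vi, gi in zip(v, g)]  # v_{n+1}
        t = n + 1  # number of gradients folded into m and v so far
        m_hat = [mi / (1.0 - beta1 ** t) for mi in m]
        v_hat = [vi / (1.0 - beta2 ** t) for vi in v]
        w = [wi - alpha * mh / (math.sqrt(vh) + eps)
             for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w


def grad_example(w: Vector) -> Vector:
    """Gradient of the hypothetical loss L(w) = 0.5 * (w0^2 + 100 * w1^2)."""
    return [w[0], 100.0 * w[1]]


if __name__ == "__main__":
    print(adam(grad_example, [1.0, 1.0], steps=1))  # first step: about [0.9, 0.9]
    print(adam(grad_example, [1.0, 1.0]))           # close to [0.0, 0.0]
```

On the very first step the bias corrections make $\widehat{m}_1 / \sqrt{\widehat{v}_1}$ equal to the sign of the gradient (up to $\epsilon$), so every coordinate moves by about $\alpha$ regardless of how steep it is.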