How regularization affects the critical points in linear networks

$\newcommand{\transpose}{\intercal}$ $\DeclareMathOperator*{\minimize}{minimize}$

Let $X_0\in\mathbb{R}^n$ be a random input vector with distribution $p_X$ and covariance matrix $\Sigma_{X_0}=\mathbb{E}[X_0{X_0}^\transpose]$. Assume the input-output model has the linear form \[Z=RX_0+\xi,\] where $R\in\mathbb{R}^{n\times n}$, $\xi\in\mathbb{R}^n$ is the noise and $Z\in\mathbb{R}^n$ is the output. The noise $\xi$ is assumed to have distribution $p_\xi$ and to be independent of the input $X_0$, with $\mathbb{E}[\xi{X_0}^\transpose]=0$. The problem is to use i.i.d. input-output samples $\{({X_0}^{(k)},Z^{(k)})\}_{k=1}^K$ to learn the weights of a linear feed-forward neural network \[\dfrac{dX_t}{dt}=A_tX_t\] so that it matches the input-output relation $R$. Here $A_t$ are the network weights, $t\in[0,T]$ indexes the layer of a network of depth $T$, and $K$ is the total number of training samples.

Consider the following regularized form of the optimization problem: \[\begin{align} \minimize\limits_{A_t}\textsf{: }&\text{J}[A_t]=\mathbb{E}\left[\lambda\int_0^T\dfrac{1}{2}\text{tr}({A_t}^\transpose A_t)dt+\dfrac{1}{2}(X_T-Z)^2\right]\\ \textsf{subject to: }&\dfrac{dX_t}{dt}=A_tX_t,\ X_0\textsf{ is given}, \end{align}\] where $(\cdot)^2$ denotes the dot product of a vector with itself, i.e. $(X_T-Z)^2=(X_T-Z)^\transpose(X_T-Z)$, and $\lambda\in\mathbb{R}^+\cup\{0\}$ is the regularization parameter.
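To make the objective concrete, here is a minimal sketch that evaluates $\text{J}$ on samples. It rests on assumptions that are not in the text above: the continuous-depth network is replaced by a forward-Euler discretization with a hypothetical number of steps `L`, the expectation is replaced by a sample mean, and the function name `objective` and all sizes are purely illustrative.

```python
import numpy as np

def objective(A_layers, X0, Z, lam, T):
    """Empirical, Euler-discretized estimate of J[A_t].

    A_layers: array of shape (L, n, n), weights held constant on each of L steps.
    X0, Z:    arrays of shape (K, n), one input-output sample per row.
    """
    h = T / len(A_layers)                 # step size of the assumed discretization
    X = X0.copy()
    for A in A_layers:
        X = X + h * X @ A.T               # forward Euler step of dX/dt = A_t X
    # lambda * integral of (1/2) tr(A_t^T A_t) dt, approximated by a Riemann sum
    reg = 0.5 * lam * h * sum(np.trace(A.T @ A) for A in A_layers)
    # (1/2) |X_T - Z|^2 averaged over the K samples (stands in for the expectation)
    fit = 0.5 * np.mean(np.sum((X - Z) ** 2, axis=1))
    return reg + fit

rng = np.random.default_rng(0)
n, K, L, T, lam = 3, 100, 10, 1.0, 0.1    # illustrative sizes, not from the text
R = rng.standard_normal((n, n))
X0 = rng.standard_normal((K, n))
Z = X0 @ R.T + 0.01 * rng.standard_normal((K, n))   # Z = R X_0 + xi
A_zero = np.zeros((L, n, n))
J_zero = objective(A_zero, X0, Z, lam, T)  # zero weights give the identity map X_T = X_0
print(J_zero)
```

With all weights set to zero the regularizer vanishes and the network is the identity map, so `J_zero` reduces to the mean of $\frac{1}{2}(X_0-Z)^2$ over the samples.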

To minimize $\text{J}$ over $A_t$, we introduce a Lagrange multiplier $Y_t$, a random process $[0,T]\rightarrow\mathbb{R}^n$, and form the Lagrangian \[\mathcal{L}(X_t,A_t,Y_t)=\mathbb{E}\left[\lambda\int_0^T\dfrac{1}{2}\text{tr}({A_t}^\transpose A_t)dt+\dfrac{1}{2}(X_T-Z)^2+\int_0^T{Y_t}^\transpose\left(\dfrac{dX_t}{dt}-A_tX_t\right) dt\right].\] Setting the first variations of $\mathcal{L}$ to zero gives \[\begin{align} &\textsf{(1) }\dfrac{\partial\mathcal{L}}{\partial Y_t}=0\Rightarrow\mathbb{E}\left[\int_0^T{\delta Y_t}^\transpose\left(\dfrac{dX_t}{dt}-A_tX_t\right)dt\right]=0,\ \forall\ \delta Y_t\\ &\textsf{(2) }\dfrac{\partial\mathcal{L}}{\partial X_t}=0\Rightarrow\left\{\begin{array}{l l}\mathbb{E}[X_t-Z+Y_t]=0&\textsf{, if }t=T\\\\\mathbb{E}\left[-\displaystyle\int_0^T\left(\dfrac{dY_t}{dt}+{A_t}^\transpose Y_t\right)^\transpose \delta X_t\ dt\right]=0,\ \forall\ \delta X_t&\textsf{, if }t\neq T\end{array}\right.\\ &\textsf{(3) }\dfrac{\partial\mathcal{L}}{\partial A_t}=0\Rightarrow\mathbb{E}\left[\int_0^T(\lambda A_t-Y_t{X_t}^\transpose)dt\right]=0 \end{align}\] Condition (2) uses integration by parts: \[\int_0^T{Y_t}^\transpose\dfrac{dX_t}{dt}dt={Y_T}^\transpose X_T-{Y_0}^\transpose X_0-\int_0^T{\dfrac{dY_t}{dt}}^\transpose X_t dt.\] Therefore, from (1), (2) and (3), for $\lambda>0$ we have \[\begin{align} &\textsf{(4) }\dfrac{dX_t}{dt}=A_tX_t\textsf{, with }X_0\textsf{ given.}\\ &\textsf{(5) }\dfrac{dY_t}{dt}=-{A_t}^\transpose Y_t\textsf{, with }Y_T=Z-X_T\\ &\textsf{(6) }A_t=\dfrac{1}{\lambda}\mathbb{E}[Y_t{X_t}^\transpose] \end{align}\]

Next, we solve the three equations (4), (5) and (6). We start by writing the solutions of (4) and (5) in terms of the state-transition matrix $\phi_{t;s}$ of (4) from time $s$ to time $t$: \[\left\{\begin{align}X_t=&\phi_{t;0}X_0&&\textsf{(7)}\\Y_t=&{\phi_{T;t}}^\transpose Y_T\\=&{\phi_{T;t}}^\transpose(Z-X_T)\\=&{\phi_{T;t}}^\transpose(RX_0+\xi-\phi_{T;0}X_0)&&\textsf{(8)}\end{align}\right..\] Differentiating both sides of (6) with respect to $t$ and substituting (4) and (5) for the derivatives, one obtains \[\dfrac{dA_t}{dt}=-{A_t}^\transpose A_t+A_t{A_t}^\transpose.\textsf{ (9)}\]
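The step from (6) to (9) can be spelled out as follows. This is a sketch that assumes $\lambda>0$ and that the weights $A_t$ are deterministic, so they can be moved outside the expectation: \[\begin{align}\dfrac{dA_t}{dt}=&\dfrac{1}{\lambda}\mathbb{E}\left[\dfrac{dY_t}{dt}{X_t}^\transpose+Y_t{\dfrac{dX_t}{dt}}^\transpose\right]\\=&\dfrac{1}{\lambda}\mathbb{E}\left[-{A_t}^\transpose Y_t{X_t}^\transpose+Y_t{X_t}^\transpose{A_t}^\transpose\right]\\=&-{A_t}^\transpose\left(\dfrac{1}{\lambda}\mathbb{E}[Y_t{X_t}^\transpose]\right)+\left(\dfrac{1}{\lambda}\mathbb{E}[Y_t{X_t}^\transpose]\right){A_t}^\transpose\\=&-{A_t}^\transpose A_t+A_t{A_t}^\transpose,\end{align}\] where the second line uses (4) and (5), and the last line uses (6) twice.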

Since the right-hand side of (9) is symmetric, the derivative of $A_t$ is a symmetric matrix. We decompose $A_t$ into its symmetric part $S_t$ and skew-symmetric part $\bar S_t$: \[A_t=\dfrac{A_t+{A_t}^\transpose}{2}+\dfrac{A_t-{A_t}^\transpose}{2}=S_t+\bar S_t,\] and expect the derivative of the skew-symmetric part $\bar S_t$ to be 0. Indeed, the derivatives of $S_t$ and $\bar S_t$ are \[\left\{\begin{align}\dfrac{dS_t}{dt}&=\dfrac{d}{dt}\dfrac{A_t+{A_t}^\transpose}{2}=-{A_t}^\transpose A_t+A_t{A_t}^\transpose\\\dfrac{d\bar S_t}{dt}&=\dfrac{d}{dt}\dfrac{A_t-{A_t}^\transpose}{2}=0\end{align}\right.,\] so $\bar S_t=\bar S_0$ for all $t$. Substituting $A_t=S_t+\bar S_t$ and ${A_t}^\transpose=S_t-\bar S_t$, we can rewrite $\dfrac{dS_t}{dt}$ in terms of $S_t$ and $\bar S_t$: \[\dfrac{dS_t}{dt}=2(\bar S_tS_t-S_t\bar S_t)=MS_t-S_tM,\] where $M=2\bar S_t=2\bar S_0$ is a constant matrix (i.e. not a function of time $t$). The general solution for $S_t$ is therefore \[S_t=e^{tM}S_0e^{-tM},\] and hence \[\begin{align}A_t=&S_t+\bar S_t\\=&e^{2t\bar S_0}S_0e^{-2t\bar S_0}+\bar S_0\\=&e^{2t\bar S_0}(S_0+\bar S_0)e^{-2t\bar S_0}\\=&e^{t(A_0-{A_0}^\transpose)}A_0e^{-t(A_0-{A_0}^\transpose)}.&&\textsf{(10)}\end{align}\] Note that the third equality holds because $\bar S_0$ commutes with $e^{\pm 2t\bar S_0}$ (any two functions of the same matrix commute). In addition, if we substitute (7) and (8) into (6), set $t=0$, and use $\mathbb{E}[\xi{X_0}^\transpose]=0$, we obtain a constraint on $A_0$: \[\lambda A_0={\phi_{T;0}}^\transpose(R-\phi_{T;0})\Sigma_{X_0}.\textsf{ (11)}\]
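The closed form (10) can be sanity-checked numerically. The sketch below is only an illustration under assumptions: `expm` is a small hand-written helper (a truncated Taylor series with scaling and squaring, not a library routine), the matrix `A0` is random, and the derivative in (9) is approximated by a central finite difference.

```python
import numpy as np

def expm(M, terms=20):
    """Matrix exponential via scaling-and-squaring of a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    Ms = M / 2**s                          # scaled so that its 1-norm is at most 1/2
    E = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms + 1):
        term = term @ Ms / k               # Ms^k / k!
        E = E + term
    for _ in range(s):
        E = E @ E                          # undo the scaling by repeated squaring
    return E

def A_closed_form(A0, t):
    """A_t from (10): exp(t(A0 - A0^T)) A0 exp(-t(A0 - A0^T))."""
    M = A0 - A0.T
    return expm(t * M) @ A0 @ expm(-t * M)

def residual_eq9(A0, t, h=1e-5):
    """Max-norm of dA/dt - (A A^T - A^T A), with dA/dt from central differences."""
    dA = (A_closed_form(A0, t + h) - A_closed_form(A0, t - h)) / (2 * h)
    A = A_closed_form(A0, t)
    return np.max(np.abs(dA - (A @ A.T - A.T @ A)))

rng = np.random.default_rng(0)
A0 = rng.standard_normal((3, 3))           # arbitrary test matrix, not from the text
A_t = A_closed_form(A0, 0.7)
res = residual_eq9(A0, t=0.7)
print(res)
```

The residual should be small, limited by the finite-difference step rather than by (10) itself. The skew-symmetric part `A_t - A_t.T` should also agree with `A0 - A0.T`, as the derivation requires.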

Finally, to solve for $\phi_{t;0}$, let $\bar X_t=e^{-t(A_0-{A_0}^\transpose)}X_t$, so that $\bar X_0=X_0$. Substituting $\bar X_t$ and (10) into (4), we have \[\dfrac{d\bar X_t}{dt}={A_0}^\transpose \bar X_t,\] which has the solution $\bar X_t=e^{t{A_0}^\transpose}X_0$. Thus, the solution to (4) is \[X_t=e^{t(A_0-{A_0}^\transpose)}\bar X_t=e^{t(A_0-{A_0}^\transpose)}e^{t{A_0}^\transpose}X_0.\textsf{ (12)}\] Similarly, the solution to (5) is \begin{align}Y_t=&e^{t(A_0-{A_0}^\transpose)}e^{-tA_0}Y_0\\=&e^{t(A_0-{A_0}^\transpose)}e^{(T-t)A_0}e^{-T(A_0-{A_0}^\transpose)}Y_T\\=&e^{t(A_0-{A_0}^\transpose)}e^{(T-t)A_0}e^{-T(A_0-{A_0}^\transpose)}(Z-X_T).&&\textsf{(13)}\end{align} Comparing (7) and (12) also gives $\phi_{t;0}$: \[\phi_{t;0}=e^{t(A_0-{A_0}^\transpose)}e^{t{A_0}^\transpose}.\] Furthermore, the constraint (11) on $A_0$ can be rewritten as \begin{align}\lambda A_0=&\left(e^{TA_0}e^{-T(A_0-{A_0}^\transpose)}\right)\left(R-e^{T(A_0-{A_0}^\transpose)}e^{T{A_0}^\transpose}\right)\Sigma_{X_0}\\=&\left(e^{TA_0}e^{-T(A_0-{A_0}^\transpose)}R-e^{TA_0}e^{T{A_0}^\transpose}\right)\Sigma_{X_0}.\end{align}

Reference: How regularization affects the critical points in linear networks
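The closed-form transition matrix $\phi_{t;0}$ obtained from (12) can be compared against a direct numerical integration of (4) with $A_t$ taken from (10). The following is a minimal sketch under assumptions: `expm` is a hand-written helper (truncated Taylor series with scaling and squaring), the integrator is a classical fourth-order Runge-Kutta scheme with an arbitrarily chosen step count, and `A0`, `X0` and `T` are random or illustrative values rather than quantities from the text.

```python
import numpy as np

def expm(M, terms=20):
    """Matrix exponential via scaling-and-squaring of a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    Ms = M / 2**s
    E = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms + 1):
        term = term @ Ms / k
        E = E + term
    for _ in range(s):
        E = E @ E
    return E

def phi(A0, t):
    """Candidate transition matrix phi_{t;0} = exp(t(A0 - A0^T)) exp(t A0^T)."""
    M = A0 - A0.T
    return expm(t * M) @ expm(t * A0.T)

def rk4_forward(A0, X0, T, steps=400):
    """Integrate dX/dt = A_t X on [0, T], with A_t given by (10), using RK4."""
    M = A0 - A0.T

    def f(t, X):
        return expm(t * M) @ A0 @ expm(-t * M) @ X

    h = T / steps
    X = X0.copy()
    for i in range(steps):
        t = i * h
        k1 = f(t, X)
        k2 = f(t + h / 2, X + h / 2 * k1)
        k3 = f(t + h / 2, X + h / 2 * k2)
        k4 = f(t + h, X + h * k3)
        X = X + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return X

rng = np.random.default_rng(1)
A0 = 0.5 * rng.standard_normal((3, 3))     # arbitrary test weights
X0 = rng.standard_normal(3)
T = 1.0
err = np.max(np.abs(rk4_forward(A0, X0, T) - phi(A0, T) @ X0))
print(err)
```

A small `err` indicates that the integrated state agrees with $\phi_{T;0}X_0$ up to the integrator's accuracy. The same helper also confirms that the transpose of `phi(A0, T)` matches the factor $e^{TA_0}e^{-T(A_0-{A_0}^\transpose)}$ that appears in the rewritten constraint on $A_0$.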
