chainerx.lstm
- chainerx.lstm(c_prev, x)
Long Short-Term Memory units as an activation function.
This function implements LSTM units with forget gates. Let the previous cell state be c_prev and the input array be x. First, the input array x is split into four arrays \(a, i, f, o\) of the same shape along the second axis. This means that the second axis of x must be 4 times the size of the second axis of c_prev. The split input arrays correspond to:
- \(a\) : sources of the cell input
- \(i\) : sources of the input gate
- \(f\) : sources of the forget gate
- \(o\) : sources of the output gate
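The split step can be sketched in NumPy. The layout is an assumption: this sketch follows the interleaved gate order used by Chainer's chainer.functions.lstm, which views the second axis of x as (n_units, 4) and takes gate k from column k, rather than cutting x into four contiguous blocks. This page does not state which layout chainerx.lstm uses, and split_gates is a hypothetical helper.

```python
import numpy as np

def split_gates(x, c_prev):
    """Split x into the four source arrays (a, i, f, o).

    ASSUMPTION: Chainer-style interleaved layout, where the second axis of x
    is viewed as (n_units, 4) and gate k is taken from column k.
    """
    n_units = c_prev.shape[1]
    # The second axis of x must be 4 times that of the cell state.
    if x.shape[1] != 4 * n_units:
        raise ValueError(
            f"x.shape[1] must be 4 * c_prev.shape[1] = {4 * n_units}, "
            f"got {x.shape[1]}")
    # View the second axis as (n_units, 4); trailing axes are kept as-is.
    r = x.reshape((x.shape[0], n_units, 4) + x.shape[2:])
    return tuple(r[:, :, k] for k in range(4))

c_prev = np.zeros((1, 2), dtype=np.float32)
x = np.arange(8, dtype=np.float32).reshape(1, 8)
a, i, f, o = split_gates(x, c_prev)
print(a, i, f, o)  # [[0. 4.]] [[1. 5.]] [[2. 6.]] [[3. 7.]]
```

Each of the four results has the same shape as c_prev, which is why the second axis of x must be exactly four times as large.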
Second, it computes the updated cell state c and the outgoing signal h as
\[\begin{split}c &= \tanh(a) \sigma(i) + c_{\text{prev}} \sigma(f), \\ h &= \tanh(c) \sigma(o),\end{split}\]
where \(\sigma\) is the elementwise sigmoid function. These are returned as a tuple of two arrays.
This function supports variable-length inputs. The mini-batch size of the current input must be equal to or smaller than that of the previous one. When the mini-batch size of x is smaller than that of c_prev, this function only updates c[0:len(x)] and leaves the rest of the cell state, c[len(x):], unchanged. Therefore, sort input sequences in descending order of length before applying the function.
- Parameters
c_prev (array) – Array that holds the previous cell state. The cell state should be a zero array or the output of the previous call of lstm.
x (array) – Array that holds the sources of the cell input, input gate, forget gate and output gate. Its second dimension must be four times the size of the second dimension of the cell state.
- Returns
Two array objects c and h: c is the updated cell state and h is the outgoing signal.
- Return type
tuple
See the original paper proposing LSTM with forget gates: Long Short-Term Memory in Recurrent Neural Networks.
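The update equations and the variable-length rule above can be checked against a short NumPy reference. This is a sketch, not the ChainerX implementation: the interleaved gate layout (taken from Chainer's chainer.functions.lstm) and the assumption that h has len(x) rows while c keeps all len(c_prev) rows are not stated on this page, and lstm_reference is a hypothetical name.

```python
import numpy as np

def sigmoid(z):
    """Elementwise logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-z))

def lstm_reference(c_prev, x):
    """NumPy sketch of c = tanh(a)*sig(i) + c_prev*sig(f), h = tanh(c)*sig(o).

    ASSUMPTIONS: Chainer-style interleaved gate layout; h has len(x) rows.
    """
    batch, n_units = x.shape[0], c_prev.shape[1]
    if batch > c_prev.shape[0]:
        raise ValueError("len(x) must not exceed len(c_prev)")
    if x.shape[1] != 4 * n_units:
        raise ValueError("x.shape[1] must be 4 * c_prev.shape[1]")
    r = x.reshape((batch, n_units, 4) + x.shape[2:])
    a, i, f, o = (r[:, :, k] for k in range(4))
    # Only the first len(x) rows of the cell state take part in the update.
    c_head = np.tanh(a) * sigmoid(i) + c_prev[:batch] * sigmoid(f)
    h = np.tanh(c_head) * sigmoid(o)
    # The remaining rows c_prev[len(x):] are carried over unchanged.
    c = np.concatenate([c_head, c_prev[batch:]], axis=0)
    return c, h

# Three cell states in the mini-batch, but only two sequences still active.
c_prev = np.ones((3, 2))
x = np.zeros((2, 8))
c, h = lstm_reference(c_prev, x)
print(c.shape, h.shape)  # (3, 2) (2, 2)
```

With x all zeros, every gate is \(\sigma(0) = 0.5\) and \(\tanh(a) = 0\), so the first two rows of c become 0.5, while the third row keeps its previous value 1.0.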
Example
Assume that y is the current incoming signal, c is the previous cell state, and h is the previous outgoing signal from an lstm function. Each of y, c and h has n_units channels. A minimal call, with the previous cell state and the input array both initialized to zeros, is:
>>> import chainerx
>>> n_units = 100
>>> c_prev = chainerx.zeros((1, n_units), chainerx.float32)
>>> x = chainerx.zeros((1, 4 * n_units), chainerx.float32)
>>> c, h = chainerx.lstm(c_prev, x)
In practice, the input array x, that is, the input sources \(a, i, f, o\), is computed from the current incoming signal y and the previous outgoing signal h. Different parameters are used for different kinds of input sources.
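The typical preparation of x can be sketched end to end in NumPy: x is formed from y and the previous h through separate weight matrices, and the sequences are sorted by descending length so that the active mini-batch shrinks from step to step. W, V, b, lstm_step and seqs are hypothetical names introduced for illustration, the gate layout follows the Chainer-style assumption, and in ChainerX code the update would be chainerx.lstm applied to chainerx arrays rather than this NumPy stand-in.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(c_prev, x):
    """Compact NumPy stand-in for the LSTM update (interleaved gates assumed)."""
    batch, n_units = x.shape[0], c_prev.shape[1]
    a, i, f, o = (x.reshape(batch, n_units, 4)[:, :, k] for k in range(4))
    c_head = np.tanh(a) * sigmoid(i) + c_prev[:batch] * sigmoid(f)
    h = np.tanh(c_head) * sigmoid(o)
    return np.concatenate([c_head, c_prev[batch:]], axis=0), h

rng = np.random.default_rng(0)
n_units = 4
# Hypothetical parameters: W acts on the incoming signal y, V on the previous
# outgoing signal h, and b is a bias; together they produce all 4 * n_units
# sources, with different columns (parameters) for each kind of source.
W = 0.1 * rng.standard_normal((n_units, 4 * n_units))
V = 0.1 * rng.standard_normal((n_units, 4 * n_units))
b = np.zeros(4 * n_units)

# Three sequences of incoming signals, already sorted by descending length.
seqs = [rng.standard_normal((length, n_units)) for length in (3, 2, 1)]

c = np.zeros((len(seqs), n_units))  # initial cell state: a zero array
h = np.zeros((len(seqs), n_units))  # initial outgoing signal
history = []
for t in range(len(seqs[0])):
    # Sequences still active at step t form a prefix of the sorted list.
    y = np.stack([s[t] for s in seqs if len(s) > t])
    x = y @ W + h[:len(y)] @ V + b  # shape: (len(y), 4 * n_units)
    c, h = lstm_step(c, x)
    history.append(c.copy())
    print(t, x.shape, c.shape, h.shape)
# 0 (3, 16) (3, 4) (3, 4)
# 1 (2, 16) (3, 4) (2, 4)
# 2 (1, 16) (3, 4) (1, 4)
```

Sorting matters because only the leading rows of the cell state are updated: once a sequence ends, its row of c stops changing, which only works if finished sequences sit at the end of the mini-batch.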