chainerx.lstm
- chainerx.lstm(c_prev, x)
Long Short-Term Memory units as an activation function.
This function implements LSTM units with forget gates. Let c_prev be the previous cell state and x the input array. First, the input array x is split into four arrays \(a, i, f, o\) of the same shape along the second axis. This means that the second axis of x must be four times the size of the second axis of c_prev. The split input arrays correspond to:
\(a\) : sources of cell input
\(i\) : sources of input gate
\(f\) : sources of forget gate
\(o\) : sources of output gate
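As a rough illustration of the split step, the NumPy sketch below reads "split along the second axis" as four contiguous blocks of c_prev.shape[1] columns each, in the order a, i, f, o. That layout and the helper name split_gates are assumptions made for illustration. The library may arrange the gate columns differently (for example, interleaved), so this sketch is not guaranteed to match chainerx.lstm column for column.

```python
import numpy as np


def split_gates(x, c_prev):
    """Split x into the gate sources (a, i, f, o).

    Assumption: each gate occupies a contiguous block of c_prev.shape[1]
    columns, in the order a, i, f, o.
    """
    n_units = c_prev.shape[1]
    # The second axis of x must be four times that of c_prev.
    if x.shape[1] != 4 * n_units:
        raise ValueError("x.shape[1] must be 4 * c_prev.shape[1]")
    # An integer section count makes np.split cut axis 1 into 4 equal parts.
    a, i, f, o = np.split(x, 4, axis=1)
    return a, i, f, o


c_prev = np.zeros((2, 3), dtype=np.float32)
x = np.arange(24, dtype=np.float32).reshape(2, 12)
a, i, f, o = split_gates(x, c_prev)  # each has shape (2, 3)
```

Passing an x whose second axis is not 4 * c_prev.shape[1] raises a ValueError, which mirrors the shape requirement stated above.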
Second, it computes the updated cell state c and the outgoing signal h as
\[\begin{split}c &= \tanh(a) \sigma(i) + c_{\text{prev}} \sigma(f), \\ h &= \tanh(c) \sigma(o),\end{split}\]
where \(\sigma\) is the elementwise sigmoid function. These are returned as a tuple of two arrays.
This function supports variable-length inputs. The mini-batch size of the current input must be equal to or smaller than that of the previous one. When the mini-batch size of x is smaller than that of c, this function updates only c[0:len(x)] and leaves the rest of c, c[len(x):], unchanged. Therefore, sort input sequences in descending order of length before applying the function.
- Parameters
c_prev (array) – Array that holds the previous cell state. The cell state should be a zero array or the output of the previous call of LSTM.
x (array) – Array that holds the sources of cell input, input gate, forget gate and output gate. Its second dimension must be four times the size of the second dimension of the cell state.
- Returns
Two array objects c and h. c is the updated cell state. h indicates the outgoing signal.
- Return type
tuple
See the original paper proposing LSTM with forget gates: Long Short-Term Memory in Recurrent Neural Networks.
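For reference, the update equations and the variable-length rule described above can be sketched in NumPy. This is not the chainerx implementation: the helper names sigmoid and lstm_step are hypothetical, the gates are assumed to be contiguous column blocks in the order a, i, f, o, and the returned h is assumed to cover only the len(x) active rows.

```python
import numpy as np


def sigmoid(z):
    # Elementwise logistic sigmoid.
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(c_prev, x):
    """NumPy sketch of one LSTM step with forget gates.

    Assumptions: gates are contiguous blocks in the order a, i, f, o,
    and x may have fewer rows than c_prev (variable-length batches).
    """
    n_units = c_prev.shape[1]
    batch = x.shape[0]
    if x.shape[1] != 4 * n_units:
        raise ValueError("x.shape[1] must be 4 * c_prev.shape[1]")
    if batch > c_prev.shape[0]:
        raise ValueError("x must not have more rows than c_prev")
    a, i, f, o = np.split(x, 4, axis=1)
    # c = tanh(a) * sigma(i) + c_prev * sigma(f), for the active rows only.
    c_active = np.tanh(a) * sigmoid(i) + c_prev[:batch] * sigmoid(f)
    # h = tanh(c) * sigma(o); in this sketch h has len(x) rows.
    h = np.tanh(c_active) * sigmoid(o)
    # Rows c_prev[len(x):] are carried over unchanged.
    c = np.concatenate([c_active, c_prev[batch:]], axis=0)
    return c, h


# Usage: three sequences of lengths 3, 2 and 1, sorted in descending order.
lengths = [3, 2, 1]
n_units = 2
rng = np.random.default_rng(0)
c = np.zeros((len(lengths), n_units))
for t in range(max(lengths)):
    # Sequences longer than t are still active at step t; because the
    # lengths are sorted, they are exactly the first batch_t rows.
    batch_t = sum(1 for n in lengths if n > t)
    x_t = rng.standard_normal((batch_t, 4 * n_units))
    c, h = lstm_step(c, x_t)
```

Sorting by descending length is what makes the active sequences a prefix of the batch at every step, so slicing the first len(x) rows is enough.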
Example
Assume that y is the current incoming signal, c is the previous cell state, and h is the previous outgoing signal from an lstm function. Each of y, c and h has n_units channels. A minimal call, with zero arrays standing in for the real c_prev and x, is:
>>> import chainerx
>>> n_units = 100
>>> c_prev = chainerx.zeros((1, n_units), chainerx.float32)
>>> x = chainerx.zeros((1, 4 * n_units), chainerx.float32)  # second axis is 4 * n_units
>>> c, h = chainerx.lstm(c_prev, x)
In practice, the input array x (that is, the input sources \(a, i, f, o\)) is computed from the current incoming signal y and the previous outgoing signal h. Different parameters are used for the different kinds of input sources.
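One common way to build x is sketched below in NumPy, under assumed shapes: two separate weight matrices, here the hypothetical W and V, map y and h to 4 * n_units columns each, and their sum plus an optional bias b gives x. The names, sizes and random initialization are illustrative only and are not part of chainerx.lstm.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, batch = 4, 2

# Hypothetical parameters (not part of chainerx.lstm): W projects the current
# incoming signal y, V projects the previous outgoing signal h, b is a bias.
# Each projection yields all 4 * n_units gate sources (a, i, f, o) at once.
W = rng.standard_normal((4 * n_units, n_units)) * 0.1
V = rng.standard_normal((4 * n_units, n_units)) * 0.1
b = np.zeros(4 * n_units)

y = rng.standard_normal((batch, n_units))  # current incoming signal
h_prev = np.zeros((batch, n_units))        # previous outgoing signal
c_prev = np.zeros((batch, n_units))        # previous cell state

# x has shape (batch, 4 * n_units), matching the shape lstm expects for c_prev.
x = y @ W.T + h_prev @ V.T + b
```

The resulting x, once converted to a chainerx array with the same dtype as c_prev (for example via chainerx.array), could then be passed together with c_prev to chainerx.lstm.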