chainerx.slstm
chainerx.slstm(c_prev1, c_prev2, x1, x2)
S-LSTM units as an activation function.
This function implements the S-LSTM unit, an extension of the LSTM unit to tree structures. It is applied to binary trees in which each node has two child nodes. It takes four arguments: the previous cell states c_prev1 and c_prev2, and the input arrays x1 and x2.
First, both input arrays x1 and x2 are split into eight arrays \(a_1, i_1, f_1, o_1\) and \(a_2, i_2, f_2, o_2\), all of which have the same size along the second axis. This means that the second axis of x1 and x2 must be 4 times the length of the second axis of c_prev1 and c_prev2. The split input arrays correspond to:
\(a_i\) : sources of cell input
\(i_i\) : sources of input gate
\(f_i\) : sources of forget gate
\(o_i\) : sources of output gate
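The description above does not spell out how the second axis is partitioned into the four gate arrays. The sketch below assumes the interleaved layout used by Chainer's own LSTM helpers: the input is viewed with shape (batch, n_units, 4) and gate k is read from index k of the new axis. Whether chainerx.slstm uses exactly this layout is an assumption here, not something this page states, and extract_gates is a hypothetical NumPy helper written only for illustration.

```python
import numpy as np


def extract_gates(x):
    """Split x of shape (batch, 4 * n_units, ...) into (a, i, f, o).

    Assumption: interleaved layout, as in Chainer's LSTM helpers. x is
    viewed as (batch, n_units, 4, ...) and gate k is taken from index k.
    """
    batch, width = x.shape[0], x.shape[1]
    if width % 4 != 0:
        raise ValueError("second axis of x must be a multiple of 4")
    r = x.reshape((batch, width // 4, 4) + x.shape[2:])
    return tuple(r[:, :, k] for k in range(4))


n_units = 3
x1 = np.arange(4 * n_units, dtype=np.float32).reshape(1, 4 * n_units)
a1, i1, f1, o1 = extract_gates(x1)
print(a1)  # [[0. 4. 8.]]
print(i1)  # [[1. 5. 9.]]
```

Under this layout each gate array has shape (batch, n_units), matching the cell state, which is why the second axis of x1 and x2 must be exactly four times as wide.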
It computes the updated cell state c and the outgoing signal h as
\[\begin{split}c &= \tanh(a_1 + a_2) \sigma(i_1 + i_2) + c_{\text{prev}1} \sigma(f_1) + c_{\text{prev}2} \sigma(f_2), \\ h &= \tanh(c) \sigma(o_1 + o_2),\end{split}\]
where \(\sigma\) is the elementwise sigmoid function. The function returns c and h as a tuple.
- Parameters
c_prev1 (array) – Array that holds the previous cell state of the first child node. The cell state should be a zero array or the output of a previous call of LSTM.
c_prev2 (array) – Array that holds the previous cell state of the second child node.
x1 (array) – Array that holds the sources of cell input, input gate, forget gate and output gate from the first child node. Its second dimension must be four times the size of the second dimension of the cell state.
x2 (array) – Array that holds the input sources from the second child node.
- Returns
Two array objects c and h. c is the cell state and h is the outgoing signal.
- Return type
tuple
See details in the paper: Long Short-Term Memory Over Tree Structures.
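To make the update equations for c and h concrete, here is a minimal NumPy reference sketch of the same forward computation. slstm_reference and its helpers are hypothetical names, not part of chainerx, and the gate split again assumes the interleaved layout used by Chainer's LSTM helpers. The sketch mirrors only the forward formulas; it does not reproduce chainerx's device handling or backpropagation.

```python
import numpy as np


def _sigmoid(z):
    # Elementwise logistic sigmoid.
    return 1.0 / (1.0 + np.exp(-z))


def _extract_gates(x):
    # Assumed interleaved layout: view as (batch, n_units, 4), gate k at index k.
    r = x.reshape((x.shape[0], x.shape[1] // 4, 4) + x.shape[2:])
    return tuple(r[:, :, k] for k in range(4))


def slstm_reference(c_prev1, c_prev2, x1, x2):
    """Hypothetical NumPy sketch of the S-LSTM update (not chainerx's code)."""
    if x1.shape[1] != 4 * c_prev1.shape[1] or x2.shape[1] != 4 * c_prev2.shape[1]:
        raise ValueError("second axis of x1/x2 must be 4x the cell-state width")
    a1, i1, f1, o1 = _extract_gates(x1)
    a2, i2, f2, o2 = _extract_gates(x2)
    # c = tanh(a1 + a2) sigma(i1 + i2) + c_prev1 sigma(f1) + c_prev2 sigma(f2)
    c = (np.tanh(a1 + a2) * _sigmoid(i1 + i2)
         + c_prev1 * _sigmoid(f1)
         + c_prev2 * _sigmoid(f2))
    # h = tanh(c) sigma(o1 + o2)
    h = np.tanh(c) * _sigmoid(o1 + o2)
    return c, h


n_units = 100
c1 = np.ones((1, n_units), np.float32)
c2 = np.ones((1, n_units), np.float32)
x1 = np.ones((1, 4 * n_units), np.float32)
x2 = np.ones((1, 4 * n_units), np.float32)
c, h = slstm_reference(c1, c2, x1, x2)
print(c.shape, h.shape)  # (1, 100) (1, 100)
```

Note that each child keeps its own forget gate, while the cell-input, input-gate and output-gate sources of the two children are summed. With all-ones inputs, every element of c equals tanh(2)σ(2) + 2σ(1) and every element of h equals tanh(c)σ(2), which is a convenient sanity check.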
Example
Assume c1 and c2 are the previous cell states of the children, and h1 and h2 are the previous outgoing signals from the children. Each of c1, c2, h1 and h2 has n_units channels. A minimal example, with every input filled with ones, is:
>>> import chainerx
>>> n_units = 100
>>> c1 = chainerx.ones((1, n_units), chainerx.float32)
>>> c2 = chainerx.ones((1, n_units), chainerx.float32)
>>> x1 = chainerx.ones((1, 4 * n_units), chainerx.float32)
>>> x2 = chainerx.ones((1, 4 * n_units), chainerx.float32)
>>> c, h = chainerx.slstm(c1, c2, x1, x2)
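The example fills x1 and x2 with ones and never uses h1 and h2. In a tree model, x1 and x2 are commonly produced from the children's outgoing signals by learned projections from n_units to 4 * n_units channels. The NumPy sketch below illustrates only the shapes involved; W_left, b_left, W_right and b_right are hypothetical parameters with arbitrary initial values, not part of the chainerx API.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units = 100

# Previous outgoing signals of the two children (placeholder values).
h1 = rng.standard_normal((1, n_units)).astype(np.float32)
h2 = rng.standard_normal((1, n_units)).astype(np.float32)

# Hypothetical per-child projections mapping n_units -> 4 * n_units.
W_left = (0.01 * rng.standard_normal((n_units, 4 * n_units))).astype(np.float32)
b_left = np.zeros(4 * n_units, np.float32)
W_right = (0.01 * rng.standard_normal((n_units, 4 * n_units))).astype(np.float32)
b_right = np.zeros(4 * n_units, np.float32)

# Gate sources for each child: one row per batch element, 4 * n_units wide.
x1 = h1 @ W_left + b_left
x2 = h2 @ W_right + b_right
print(x1.shape, x2.shape)  # (1, 400) (1, 400)
```

The resulting arrays have the 4 * n_units width that chainerx.slstm expects for x1 and x2; they would need to be converted to chainerx arrays (for example with chainerx.array) before being passed in.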