chainerx.tree_lstm
- chainerx.tree_lstm(*inputs)
TreeLSTM unit as an activation function.
This function implements TreeLSTM units for both N-ary TreeLSTM and Child-Sum TreeLSTM. Let \(c_{\text{1}}, c_{\text{2}}, \dots, c_{\text{N}}\) be the children's cell states and \(x\) the incoming signal. First, the incoming signal \(x\) is split along the second axis into (3 + N) arrays \(a, i, o, f_{\text{1}}, f_{\text{2}}, \dots, f_{\text{N}}\) of the same shape. This means that the second axis of \(x\) must be (3 + N) times as long as that of each \(c_{n}\). The split input signals correspond to
\(a\) : sources of cell input
\(i\) : sources of input gate
\(o\) : sources of output gate
\(f_{n}\) : sources of the forget gate for the n-th child
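The split described above can be sketched with NumPy. This is only an illustration: it assumes the (3 + N) pieces are contiguous blocks along the second axis in the order a, i, o, f_1, ..., f_N, which this page does not state explicitly, and the names `n_units` and `n_children` are chosen here for the example.

```python
import numpy as np

n_units = 10     # length of each child cell state along the second axis
n_children = 2   # N, the number of child cell states

# A dummy incoming signal whose second axis has (3 + N) * n_units entries.
x = np.arange(4 * (3 + n_children) * n_units, dtype=np.float32).reshape(4, -1)

# Split along the second axis into (3 + N) arrays of equal shape.
# Assumption: the pieces are contiguous blocks ordered a, i, o, f_1, ..., f_N.
a, i, o, *f = np.split(x, 3 + n_children, axis=1)

print(a.shape, i.shape, o.shape, [f_n.shape for f_n in f])
```

Each piece has the same shape as a child cell state, which is why the second axis of the input must be exactly (3 + N) times as long.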
Second, it computes outputs as
\[\begin{split}c &= \tanh(a) \, \text{sigmoid}(i) \\ & + c_{\text{1}} \, \text{sigmoid}(f_{\text{1}}) \\ & + c_{\text{2}} \, \text{sigmoid}(f_{\text{2}}) \\ & + \dots \\ & + c_{\text{N}} \, \text{sigmoid}(f_{\text{N}}), \\ h &= \tanh(c) \, \text{sigmoid}(o).\end{split}\]
These are returned as a tuple of two arrays, (c, h).
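The update can be written out as a small NumPy reference, which may help when checking shapes or values by hand. It is a sketch, not the ChainerX implementation: `tree_lstm_reference` and `sigmoid` are hypothetical helpers defined here, and the contiguous a, i, o, f_1, ..., f_N layout of the input is an assumption.

```python
import numpy as np


def sigmoid(z):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def tree_lstm_reference(*inputs):
    """NumPy sketch of the TreeLSTM update (hypothetical, not part of chainerx).

    The last argument is the incoming signal x; all preceding arguments
    are the child cell states c_1, ..., c_N.
    """
    *cs, x = inputs
    n = len(cs)
    # Assumption: x is laid out as contiguous blocks a, i, o, f_1, ..., f_N.
    a, i, o, *fs = np.split(x, 3 + n, axis=1)
    # c = tanh(a) sigmoid(i) + sum_n c_n sigmoid(f_n)
    c = np.tanh(a) * sigmoid(i)
    for c_n, f_n in zip(cs, fs):
        c = c + c_n * sigmoid(f_n)
    # h = tanh(c) sigmoid(o)
    h = np.tanh(c) * sigmoid(o)
    return c, h


c1 = np.ones((4, 10), dtype=np.float32)
c2 = np.ones((4, 10), dtype=np.float32)
x = np.ones((4, 50), dtype=np.float32)
c, h = tree_lstm_reference(c1, c2, x)
print(c.shape, h.shape)
```

Both outputs have the shape of a single child cell state, regardless of how many children are passed in.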
- Parameters
inputs (list of array) – Variable arguments which include all cell vectors from child nodes, and an input vector. Each of the cell vectors and the input vector is an array. The second dimension of the input vector must be (N + 3) times the size of that of each cell, where N denotes the total number of cells.
- Returns
Two array objects c and h. c is the updated cell state. h indicates the outgoing signal.
- Return type
tuple
See the papers for details: Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks and A Fast Unified Model for Parsing and Sentence Understanding. Tai et al.’s N-ary TreeLSTM is slightly extended in Bowman et al., and this function is based on the variant by Bowman et al. Specifically, eq. 10 in Tai et al. has only one \(W\) matrix applied to \(x\), shared across all children. Bowman et al.’s model, on the other hand, has multiple matrices, each of which affects the forget gate for one child’s cell individually.
Example
Assuming y is the current input signal, c is the previous cell state, and h is the previous output signal from a tree_lstm() function. Each of y, c and h has n_units channels. Using a 2-ary (binary) TreeLSTM, the most typical preparation of x is
>>> import chainerx
>>> c1 = chainerx.ones((4, 10), dtype=chainerx.float32)
>>> c2 = chainerx.ones((4, 10), dtype=chainerx.float32)
>>> x = chainerx.ones((4, 50), dtype=chainerx.float32)
>>> c, h = chainerx.tree_lstm(c1, c2, x)
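The example above fills x with ones. In a model, x would normally be computed from the current input y and the children's output signals. The NumPy sketch below shows one common pattern, offered as an assumption rather than anything prescribed by this page: a separate projection for y and for each child's h, each producing (3 + N) * n_units columns, summed together. The weight names `W`, `V1`, and `V2` are hypothetical, and biases are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_units, n_children = 4, 10, 2
out_size = (3 + n_children) * n_units  # 50 columns for a binary TreeLSTM

# Hypothetical projection weights: one for the input y, one per child.
W = rng.standard_normal((n_units, out_size)).astype(np.float32)
V1 = rng.standard_normal((n_units, out_size)).astype(np.float32)
V2 = rng.standard_normal((n_units, out_size)).astype(np.float32)

# Current input signal and the two children's output signals.
y = rng.standard_normal((batch, n_units)).astype(np.float32)
h1 = rng.standard_normal((batch, n_units)).astype(np.float32)
h2 = rng.standard_normal((batch, n_units)).astype(np.float32)

# The summed projections form the incoming signal x, whose second axis
# is (3 + N) times as long as that of each cell state.
x = y @ W + h1 @ V1 + h2 @ V2
print(x.shape)
```

Because each child has its own projection, the columns that feed each forget gate depend on every child individually, in the spirit of the Bowman et al. variant described above.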