aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/graph_deepspeech_v0_4_1.cpp110
1 files changed, 52 insertions, 58 deletions
diff --git a/examples/graph_deepspeech_v0_4_1.cpp b/examples/graph_deepspeech_v0_4_1.cpp
index 6af85a585a..84650a6627 100644
--- a/examples/graph_deepspeech_v0_4_1.cpp
+++ b/examples/graph_deepspeech_v0_4_1.cpp
@@ -142,61 +142,56 @@ public:
get_weights_accessor(data_path, "ones.npy"))
.set_name("add_y");
- // TODO(COMPMID-2103): Use sub stream for FC weights and bias in LSTM cells
// Create LSTM Fully Connected weights and bias descriptors
- //TensorDescriptor lstm_weights_descriptor = TensorDescriptor(TensorShape(4096U, 8192U), common_params.data_type).set_layout(common_params.data_layout);
- //TensorDescriptor lstm_bias_descriptor = TensorDescriptor(TensorShape(8192U), common_params.data_type).set_layout(common_params.data_layout);
- //SubStream lstm_fc_weights(graph);
- //SubStream lstm_fc_bias(graph);
-
- //lstm_fc_weights << InputLayer(lstm_weights_descriptor,
- // get_weights_accessor(data_path, "rnn_lstm_cell_kernel_transpose.npy", weights_layout))
- // .set_name("h5/transpose");
- //lstm_fc_bias << InputLayer(lstm_bias_descriptor,
- // get_weights_accessor(data_path, "rnn_lstm_cell_MatMul_bias.npy"))
- // .set_name("MatMul_3_bias");
+ TensorDescriptor lstm_weights_descriptor = TensorDescriptor(TensorShape(4096U, 8192U), common_params.data_type).set_layout(common_params.data_layout);
+ TensorDescriptor lstm_bias_descriptor = TensorDescriptor(TensorShape(8192U), common_params.data_type).set_layout(common_params.data_layout);
+ SubStream lstm_fc_weights(graph);
+ SubStream lstm_fc_bias(graph);
+ lstm_fc_weights << ConstantLayer(lstm_weights_descriptor,
+ get_weights_accessor(data_path, "rnn_lstm_cell_kernel_transpose.npy", weights_layout))
+ .set_name("h5/transpose");
+ lstm_fc_bias << ConstantLayer(lstm_bias_descriptor,
+ get_weights_accessor(data_path, "rnn_lstm_cell_MatMul_bias.npy"))
+ .set_name("MatMul_3_bias");
// LSTM Block
- std::pair<SubStream, SubStream> new_state_1 = add_lstm_cell(data_path, unstack_nid, 0, previous_state, previous_state, add_y);
- std::pair<SubStream, SubStream> new_state_2 = add_lstm_cell(data_path, unstack_nid, 1, new_state_1.first, new_state_1.second, add_y);
- std::pair<SubStream, SubStream> new_state_3 = add_lstm_cell(data_path, unstack_nid, 2, new_state_2.first, new_state_2.second, add_y);
- std::pair<SubStream, SubStream> new_state_4 = add_lstm_cell(data_path, unstack_nid, 3, new_state_3.first, new_state_3.second, add_y);
- std::pair<SubStream, SubStream> new_state_5 = add_lstm_cell(data_path, unstack_nid, 4, new_state_4.first, new_state_4.second, add_y);
- std::pair<SubStream, SubStream> new_state_6 = add_lstm_cell(data_path, unstack_nid, 5, new_state_5.first, new_state_5.second, add_y);
- std::pair<SubStream, SubStream> new_state_7 = add_lstm_cell(data_path, unstack_nid, 6, new_state_6.first, new_state_6.second, add_y);
- std::pair<SubStream, SubStream> new_state_8 = add_lstm_cell(data_path, unstack_nid, 7, new_state_7.first, new_state_7.second, add_y);
- std::pair<SubStream, SubStream> new_state_9 = add_lstm_cell(data_path, unstack_nid, 8, new_state_8.first, new_state_8.second, add_y);
- std::pair<SubStream, SubStream> new_state_10 = add_lstm_cell(data_path, unstack_nid, 9, new_state_9.first, new_state_9.second, add_y);
- std::pair<SubStream, SubStream> new_state_11 = add_lstm_cell(data_path, unstack_nid, 10, new_state_10.first, new_state_10.second, add_y);
- std::pair<SubStream, SubStream> new_state_12 = add_lstm_cell(data_path, unstack_nid, 11, new_state_11.first, new_state_11.second, add_y);
- std::pair<SubStream, SubStream> new_state_13 = add_lstm_cell(data_path, unstack_nid, 12, new_state_12.first, new_state_12.second, add_y);
- std::pair<SubStream, SubStream> new_state_14 = add_lstm_cell(data_path, unstack_nid, 13, new_state_13.first, new_state_13.second, add_y);
- std::pair<SubStream, SubStream> new_state_15 = add_lstm_cell(data_path, unstack_nid, 14, new_state_14.first, new_state_14.second, add_y);
- std::pair<SubStream, SubStream> new_state_16 = add_lstm_cell(data_path, unstack_nid, 15, new_state_15.first, new_state_15.second, add_y);
-
- if(n_steps > 1)
- {
- // Concatenate new states on height
- const int axis = 1;
- graph << StackLayer(axis,
- std::move(new_state_1.second),
- std::move(new_state_2.second),
- std::move(new_state_3.second),
- std::move(new_state_4.second),
- std::move(new_state_5.second),
- std::move(new_state_6.second),
- std::move(new_state_7.second),
- std::move(new_state_8.second),
- std::move(new_state_9.second),
- std::move(new_state_10.second),
- std::move(new_state_11.second),
- std::move(new_state_12.second),
- std::move(new_state_13.second),
- std::move(new_state_14.second),
- std::move(new_state_15.second),
- std::move(new_state_16.second))
- .set_name("concat");
- }
+ std::pair<SubStream, SubStream> new_state_1 = add_lstm_cell(data_path, unstack_nid, 0, previous_state, previous_state, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_2 = add_lstm_cell(data_path, unstack_nid, 1, new_state_1.first, new_state_1.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_3 = add_lstm_cell(data_path, unstack_nid, 2, new_state_2.first, new_state_2.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_4 = add_lstm_cell(data_path, unstack_nid, 3, new_state_3.first, new_state_3.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_5 = add_lstm_cell(data_path, unstack_nid, 4, new_state_4.first, new_state_4.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_6 = add_lstm_cell(data_path, unstack_nid, 5, new_state_5.first, new_state_5.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_7 = add_lstm_cell(data_path, unstack_nid, 6, new_state_6.first, new_state_6.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_8 = add_lstm_cell(data_path, unstack_nid, 7, new_state_7.first, new_state_7.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_9 = add_lstm_cell(data_path, unstack_nid, 8, new_state_8.first, new_state_8.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_10 = add_lstm_cell(data_path, unstack_nid, 9, new_state_9.first, new_state_9.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_11 = add_lstm_cell(data_path, unstack_nid, 10, new_state_10.first, new_state_10.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_12 = add_lstm_cell(data_path, unstack_nid, 11, new_state_11.first, new_state_11.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_13 = add_lstm_cell(data_path, unstack_nid, 12, new_state_12.first, new_state_12.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_14 = add_lstm_cell(data_path, unstack_nid, 13, new_state_13.first, new_state_13.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_15 = add_lstm_cell(data_path, unstack_nid, 14, new_state_14.first, new_state_14.second, add_y, lstm_fc_weights, lstm_fc_bias);
+ std::pair<SubStream, SubStream> new_state_16 = add_lstm_cell(data_path, unstack_nid, 15, new_state_15.first, new_state_15.second, add_y, lstm_fc_weights, lstm_fc_bias);
+
+ // Concatenate new states on height
+ const int axis = 1;
+ graph << StackLayer(axis,
+ std::move(new_state_1.second),
+ std::move(new_state_2.second),
+ std::move(new_state_3.second),
+ std::move(new_state_4.second),
+ std::move(new_state_5.second),
+ std::move(new_state_6.second),
+ std::move(new_state_7.second),
+ std::move(new_state_8.second),
+ std::move(new_state_9.second),
+ std::move(new_state_10.second),
+ std::move(new_state_11.second),
+ std::move(new_state_12.second),
+ std::move(new_state_13.second),
+ std::move(new_state_14.second),
+ std::move(new_state_15.second),
+ std::move(new_state_16.second))
+ .set_name("concat");
graph << FullyConnectedLayer(
2048U,
@@ -251,10 +246,9 @@ private:
unsigned int unstack_idx,
SubStream previous_state_c,
SubStream previous_state_h,
- SubStream add_y)
- // TODO(COMPMID-2103): Use sub streams for FC weights and bias
- //SubStream lstm_fc_weights,
- //SubStream lstm_fc_bias)
+ SubStream add_y,
+ SubStream lstm_fc_weights,
+ SubStream lstm_fc_bias)
{
const std::string cell_name("rnn/lstm_cell_" + std::to_string(unstack_idx));
const DataLayoutDimension concat_dim = (common_params.data_layout == DataLayout::NHWC) ? DataLayoutDimension::CHANNEL : DataLayoutDimension::WIDTH;
@@ -269,8 +263,8 @@ private:
graph << FullyConnectedLayer(
8192U,
- get_weights_accessor(data_path, "rnn_lstm_cell_kernel_transpose.npy", DataLayout::NHWC),
- get_weights_accessor(data_path, "rnn_lstm_cell_MatMul_bias.npy"))
+ lstm_fc_weights,
+ lstm_fc_bias)
.set_name(cell_name + "/BiasAdd");
// Split Layer