From 23ae2eae1caefba4948e6afda154a66238b26c2a Mon Sep 17 00:00:00 2001 From: Andre Ghattas Date: Wed, 7 Aug 2019 12:18:38 +0100 Subject: IVGCVBENCH-1337 Added additional layer parameters to dot file and -v option * Generic layer parameters now show up in dot file * Convolution layer parameters have also been added to dot file * ExecucteNetwork has an additional -v flag which generated dot file if there Change-Id: I210bb19b45384eb3639b7e488c7a89049fa6f18d Signed-off-by: Andre Ghattas Signed-off-by: Szilard Papp --- src/armnn/layers/Convolution2dLayer.cpp | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'src/armnn/layers/Convolution2dLayer.cpp') diff --git a/src/armnn/layers/Convolution2dLayer.cpp b/src/armnn/layers/Convolution2dLayer.cpp index 2c7a570790..4300d55e1e 100644 --- a/src/armnn/layers/Convolution2dLayer.cpp +++ b/src/armnn/layers/Convolution2dLayer.cpp @@ -9,7 +9,7 @@ #include #include #include - +#include #include using namespace armnnUtils; @@ -20,6 +20,27 @@ namespace armnn Convolution2dLayer::Convolution2dLayer(const Convolution2dDescriptor& param, const char* name) : LayerWithParameters(1, 1, LayerType::Convolution2d, param, name) { + +} + +void Convolution2dLayer::SerializeLayerParameters(ParameterStringifyFunction& fn) const +{ + //using DescriptorType = Parameters; + const std::vector& inputShapes = + { + GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + m_Weight->GetTensorInfo().GetShape() + }; + const TensorShape filterShape = inputShapes[1]; + DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout); + unsigned int filterWidth = filterShape[dataLayoutIndex.GetWidthIndex()]; + unsigned int filterHeight = filterShape[dataLayoutIndex.GetHeightIndex()]; + unsigned int outChannels = filterShape[0]; + + fn("OutputChannels",std::to_string(outChannels)); + fn("FilterWidth",std::to_string(filterWidth)); + fn("FilterHeight",std::to_string(filterHeight)); + LayerWithParameters::SerializeLayerParameters(fn); } std::unique_ptr Convolution2dLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const -- cgit v1.2.1