aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_alexnet.cpp
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-03-21 17:45:31 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commited99f411d52949720a4d64d91664cd71e46b79d5 (patch)
treed903b523dea830aeb48d59a66b8da59e4dcf707a /examples/graph_alexnet.cpp
parent6528aa20e768f2d801328aa164d672b7fdfe266f (diff)
downloadComputeLibrary-ed99f411d52949720a4d64d91664cd71e46b79d5.tar.gz
COMPMID-1018 - Add Winograd support in VGG16 and Alexnet examples
Change-Id: I4a2deee9e4b2c54ea79d2895cfeca44190133b24 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125453 Reviewed-by: Pablo Tello <pablo.tello@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'examples/graph_alexnet.cpp')
-rw-r--r--examples/graph_alexnet.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/examples/graph_alexnet.cpp b/examples/graph_alexnet.cpp
index a396c7686c..f887f97a12 100644
--- a/examples/graph_alexnet.cpp
+++ b/examples/graph_alexnet.cpp
@@ -57,8 +57,10 @@ public:
const int int_target_hint = argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0;
TargetHint target_hint = set_target_hint(int_target_hint);
- const bool is_gemm_convolution5x5 = Graph::gpu_target() == arm_compute::GPUTarget::MIDGARD || target_hint == TargetHint::NEON;
- ConvolutionMethodHint convolution_5x5_hint = is_gemm_convolution5x5 ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
+ const bool is_gemm_convolution5x5 = Graph::gpu_target() == arm_compute::GPUTarget::MIDGARD || target_hint == TargetHint::NEON;
+ const bool is_winograd_convolution3x3 = target_hint == TargetHint::OPENCL;
+ ConvolutionMethodHint convolution_5x5_hint = is_gemm_convolution5x5 ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
+ ConvolutionMethodHint convolution_3x3_hint = is_winograd_convolution3x3 ? ConvolutionMethodHint::WINOGRAD : ConvolutionMethodHint::GEMM;
// Parse arguments
if(argc < 2)
@@ -114,7 +116,7 @@ public:
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
<< NormalizationLayer(NormalizationLayerInfo(NormType::CROSS_MAP, 5, 0.0001f, 0.75f))
<< PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)))
- << ConvolutionMethodHint::GEMM
+ << convolution_3x3_hint
// Layer 3
<< ConvolutionLayer(
3U, 3U, 384U,