From 2df12b3abf2171869582581dd275d556bad87411 Mon Sep 17 00:00:00 2001 From: saoste01 Date: Wed, 28 Nov 2018 16:57:20 +0000 Subject: IVGCVSW-2254 Add Reference workload for Maximum Change-Id: Id7302c6b1df995ebe6eb8eb94bab38bee1b31b0b --- src/backends/backendsCommon/StringMapping.hpp | 1 + src/backends/reference/RefLayerSupport.cpp | 13 +++++++++++++ src/backends/reference/RefLayerSupport.hpp | 6 ++++++ src/backends/reference/RefWorkloadFactory.cpp | 2 +- src/backends/reference/workloads/CMakeLists.txt | 1 + .../reference/workloads/ElementwiseFunction.cpp | 3 +++ src/backends/reference/workloads/Maximum.hpp | 22 ++++++++++++++++++++++ .../reference/workloads/RefElementwiseWorkload.cpp | 3 +++ .../reference/workloads/RefElementwiseWorkload.hpp | 14 ++++++++++++++ 9 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 src/backends/reference/workloads/Maximum.hpp diff --git a/src/backends/backendsCommon/StringMapping.hpp b/src/backends/backendsCommon/StringMapping.hpp index 6312e68945..f6af8217a9 100644 --- a/src/backends/backendsCommon/StringMapping.hpp +++ b/src/backends/backendsCommon/StringMapping.hpp @@ -19,6 +19,7 @@ public: enum Id { RefAdditionWorkload_Execute, RefSubtractionWorkload_Execute, + RefMaximumWorkload_Execute, RefMultiplicationWorkload_Execute, RefDivisionWorkload_Execute, MAX_STRING_ID diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 00e4c5c09c..7222af6402 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -308,6 +308,19 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, &FalseFuncU8<>); } +bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional reasonIfUnsupported) const +{ + ignore_unused(input1); + ignore_unused(output); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input0.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); +} + bool RefLayerSupport::IsMeanSupported(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index defa962847..73e5394fc7 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -116,6 +116,12 @@ public: const TensorInfo* cellToOutputWeights, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsMaximumSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional reasonIfUnsupported = EmptyOptional()) const override; + + bool IsMeanSupported(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index c93dc31ea9..eef5b24df7 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -261,7 +261,7 @@ std::unique_ptr RefWorkloadFactory::CreateSubtraction( std::unique_ptr RefWorkloadFactory::CreateMaximum( const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return MakeWorkload(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateMean( diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 86c5f908b9..b9c150f6d4 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND armnnRefBackendWorkloads_sources ElementwiseFunction.hpp FullyConnected.cpp FullyConnected.hpp + Maximum.hpp Merger.hpp Pad.cpp Pad.hpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index bea3d2fb89..bb15049faa 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -7,6 +7,8 @@ #include "Broadcast.hpp" #include +#include "Maximum.hpp" + namespace armnn { @@ -27,3 +29,4 @@ template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; \ No newline at end of file diff --git a/src/backends/reference/workloads/Maximum.hpp b/src/backends/reference/workloads/Maximum.hpp new file mode 100644 index 0000000000..524afffc44 --- /dev/null +++ b/src/backends/reference/workloads/Maximum.hpp @@ -0,0 +1,22 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnn +{ + template + struct maximum + { + T + operator () (const T& inputData0, const T& inputData1) const + { + return std::max(inputData0, inputData1); + } + }; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 8e312a7dd1..60a1b990f7 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -67,3 +67,6 @@ template class armnn::BaseUint8ElementwiseWorkload>; template class armnn::BaseUint8ElementwiseWorkload>; + +template class armnn::BaseFloat32ElementwiseWorkload>; +template class armnn::BaseUint8ElementwiseWorkload>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 156613a49f..2772b77631 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -9,6 +9,7 @@ #include #include #include +#include "Maximum.hpp" namespace armnn { @@ -119,4 +120,17 @@ using RefDivisionUint8Workload = DivisionQueueDescriptor, StringMapping::RefDivisionWorkload_Execute>; + +using RefMaximumFloat32Workload = + RefElementwiseWorkload, + DataType::Float32, + MaximumQueueDescriptor, + StringMapping::RefMaximumWorkload_Execute>; + +using RefMaximumUint8Workload = + RefElementwiseWorkload, + DataType::QuantisedAsymm8, + MaximumQueueDescriptor, + StringMapping::RefMaximumWorkload_Execute>; + } // armnn -- cgit v1.2.1