diff options
-rw-r--r-- | src/backends/backendsCommon/StringMapping.hpp | 1 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 13 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 6 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/backends/reference/workloads/ElementwiseFunction.cpp | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/Maximum.hpp | 22 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.cpp | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.hpp | 14 |
9 files changed, 64 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/StringMapping.hpp b/src/backends/backendsCommon/StringMapping.hpp index 6312e68945..f6af8217a9 100644 --- a/src/backends/backendsCommon/StringMapping.hpp +++ b/src/backends/backendsCommon/StringMapping.hpp @@ -19,6 +19,7 @@ public: enum Id { RefAdditionWorkload_Execute, RefSubtractionWorkload_Execute, + RefMaximumWorkload_Execute, RefMultiplicationWorkload_Execute, RefDivisionWorkload_Execute, MAX_STRING_ID diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 00e4c5c09c..7222af6402 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -308,6 +308,19 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, &FalseFuncU8<>); } +bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported) const +{ + ignore_unused(input1); + ignore_unused(output); + return IsSupportedForDataTypeRef(reasonIfUnsupported, + input0.GetDataType(), + &TrueFunc<>, + &TrueFunc<>); +} + bool RefLayerSupport::IsMeanSupported(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index defa962847..73e5394fc7 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -116,6 +116,12 @@ public: const TensorInfo* cellToOutputWeights, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMaximumSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + + bool IsMeanSupported(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index c93dc31ea9..eef5b24df7 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -261,7 +261,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction( std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMaximum( const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info); + return MakeWorkload<RefMaximumFloat32Workload, RefMaximumUint8Workload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean( diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 86c5f908b9..b9c150f6d4 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND armnnRefBackendWorkloads_sources ElementwiseFunction.hpp FullyConnected.cpp FullyConnected.hpp + Maximum.hpp Merger.hpp Pad.cpp Pad.hpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index bea3d2fb89..bb15049faa 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -7,6 +7,8 @@ #include "Broadcast.hpp" #include <functional> +#include "Maximum.hpp" + namespace armnn { @@ -27,3 +29,4 @@ template struct armnn::ElementwiseFunction<std::plus<float>>; template struct armnn::ElementwiseFunction<std::minus<float>>; template struct armnn::ElementwiseFunction<std::multiplies<float>>; template struct armnn::ElementwiseFunction<std::divides<float>>; +template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
\ No newline at end of file diff --git a/src/backends/reference/workloads/Maximum.hpp b/src/backends/reference/workloads/Maximum.hpp new file mode 100644 index 0000000000..524afffc44 --- /dev/null +++ b/src/backends/reference/workloads/Maximum.hpp @@ -0,0 +1,22 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <iostream> + +namespace armnn +{ + template<typename T> + struct maximum + { + T + operator () (const T& inputData0, const T& inputData1) const + { + return std::max(inputData0, inputData1); + } + }; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 8e312a7dd1..60a1b990f7 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -67,3 +67,6 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MultiplicationQueueDes template class armnn::BaseFloat32ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>; template class armnn::BaseUint8ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>; + +template class armnn::BaseFloat32ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>; +template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 156613a49f..2772b77631 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -9,6 +9,7 @@ #include <backendsCommon/StringMapping.hpp> #include <backendsCommon/Workload.hpp> #include <backendsCommon/WorkloadData.hpp> +#include "Maximum.hpp" namespace armnn { @@ -119,4 +120,17 @@ using RefDivisionUint8Workload = DivisionQueueDescriptor, StringMapping::RefDivisionWorkload_Execute>; + +using RefMaximumFloat32Workload = + RefElementwiseWorkload<armnn::maximum<float>, + DataType::Float32, + MaximumQueueDescriptor, + StringMapping::RefMaximumWorkload_Execute>; + +using RefMaximumUint8Workload = + RefElementwiseWorkload<armnn::maximum<float>, + DataType::QuantisedAsymm8, + MaximumQueueDescriptor, + StringMapping::RefMaximumWorkload_Execute>; + } // armnn |