aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h
diff options
context:
space:
mode:
authorPablo Marquez Tello <pablo.tello@arm.com>2021-03-03 12:12:35 +0000
committerPablo Marquez Tello <pablo.tello@arm.com>2021-04-19 15:02:29 +0000
commitfe7ae817755577be29f4c07aa27d8ef9e821da45 (patch)
tree459b1b22f59cf5144cd72b839fbfdf21fa341479 /arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h
parent60c3b0e6821a80d78ffca5be30e05d062d071cd2 (diff)
downloadComputeLibrary-fe7ae817755577be29f4c07aa27d8ef9e821da45.tar.gz
CLInstanceNormalizationLayer NHWC optimisation
* Make changes to split the workload into two kernels. One kernel precomputes mean and variance and the second kernel just loads these precomputed values. * The new approach runs 30% faster than the original code for NHWC workloads like 32x192x256. * Resolves MLCE-337 Change-Id: I8356fcefa2d131ab4dcb32268ce7142421d073e4 Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5355 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h37
1 files changed, 32 insertions, 5 deletions
diff --git a/arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h b/arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h
index d41f3fedf6..a6e5b1622b 100644
--- a/arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h
+++ b/arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,24 +25,44 @@
#define ARM_COMPUTE_CLINSTANCENORMALIZATIONLAYER_H
#include "arm_compute/core/Error.h"
-#include "arm_compute/runtime/CL/ICLSimpleFunction.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/IFunction.h"
+
+#include <memory>
namespace arm_compute
{
class CLCompileContext;
class ICLTensor;
class ITensorInfo;
+class ICLKernel;
+class CLRuntimeContext;
/** Basic function to perform an instance normalization.
*
* This function runs the following kernels:
* -# @ref CLInstanceNormalizationLayerKernel
*/
-class CLInstanceNormalizationLayer : public ICLSimpleFunction
+class CLInstanceNormalizationLayer : public IFunction
{
public:
- /** Default constructor */
- CLInstanceNormalizationLayer();
+ /** Constructor
+ *
+ * @param[in] ctx Runtime context to be used by the function
+ */
+ CLInstanceNormalizationLayer(CLRuntimeContext *ctx = nullptr);
+
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLInstanceNormalizationLayer(const CLInstanceNormalizationLayer &) = delete;
+ /** Default move constructor */
+ CLInstanceNormalizationLayer(CLInstanceNormalizationLayer &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLInstanceNormalizationLayer &operator=(const CLInstanceNormalizationLayer &) = delete;
+ /** Default move assignment operator */
+ CLInstanceNormalizationLayer &operator=(CLInstanceNormalizationLayer &&) = default;
+ /** Default destructor */
+ ~CLInstanceNormalizationLayer();
+
/** Set the input and output tensors.
*
* @param[in, out] input Source tensor. In case of @p output tensor = nullptr this tensor will store the result of the normalization.
@@ -79,6 +99,13 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *output, float gamma = 1.0f, float beta = 0.0f, float epsilon = 1e-12f, bool use_mixed_precision = true);
+ void run() override;
+
+private:
+ std::unique_ptr<ICLKernel> _inst_norm_kernel; /**< Kernel to run */
+ std::unique_ptr<ICLKernel> _mean_var_kernel; /**< Kernel to run */
+ CLTensor _mean_var_tensor;
+ CLRuntimeContext *_ctx; /**< Context to use */
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLINSTANCENORMALIZATIONLAYER_H */