aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
diff options
context:
space:
mode:
authorPablo Marquez Tello <pablo.tello@arm.com>2023-11-21 10:10:01 +0000
committerPablo Marquez Tello <pablo.tello@arm.com>2023-11-27 17:16:45 +0000
commit8d4cdd43a74574e0f99f83f1adb1d391c0c85abe (patch)
tree614000681778c2f390897888ce69dfdd62561799 /src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
parent835577e1477003789c392d8faab4a3bb8f4040ba (diff)
downloadComputeLibrary-8d4cdd43a74574e0f99f83f1adb1d391c0c85abe.tar.gz
BatchNorm changes to enable fp16 in armv8a multi_isa builds
* Moved NCHW kernels fp16 and fp32 to their corresponding files src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp and src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp * Changes in filelist.json to include the new fp16 and fp32 files * Moved the template batch_normalization_nchw to impl.h as we need to instantiate it from fp16.cpp and fp32.cpp * Pooling layer: removed the guard __ARM_FEATURE_FP16_VECTOR_ARITHMETIC that prevented the FP16 kernel execution. * Partially resolves MLCE-1102 Change-Id: Ia8c85e9ffb76c9e387f9ae2685e5df5e52c8dc27 Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10777 Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h')
-rw-r--r--src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h36
1 files changed, 12 insertions, 24 deletions
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
index 2e8ff0dc9a..679ade0fae 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_NEBATCHNORMALIZATIONLAYERKERNEL_H
-#define ARM_COMPUTE_NEBATCHNORMALIZATIONLAYERKERNEL_H
+#ifndef ACL_SRC_CORE_NEON_KERNELS_NEBATCHNORMALIZATIONLAYERKERNEL_H
+#define ACL_SRC_CORE_NEON_KERNELS_NEBATCHNORMALIZATIONLAYERKERNEL_H
#include "arm_compute/function_info/ActivationLayerInfo.h"
@@ -110,31 +110,19 @@ private:
/** Configure execution function in case of fused activation **/
void configure_fused();
- /** Template function to run batch normalization on fp32
- *
- * @tparam T Specialization data type
- * @tparam fused_activation Boolean that flags if its a fused activation or not
- * @tparam F Activation function functor to run
- *
- * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
- */
- template <typename T, bool fused_activation, typename F>
- void batch_normalization_nchw(const Window &window);
- /** Template function to run batch normalization on fp32 on tensors with NHWC format
- *
- * @tparam T Specialization data type
- * @tparam fused_activation Boolean that flags if its a fused activation or not
- * @tparam F Activation function functor to run
- *
- * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
- */
- template <typename T, bool fused_activation, typename F>
- void batch_normalization_nhwc(const Window &window);
/** Common signature for all the batch normalization functions
*
* @param[in] window Region on which to execute the kernel.
*/
- using BatchNormFunctionPtr = void (NEBatchNormalizationLayerKernel::*)(const Window &window);
+ using BatchNormFunctionPtr = void (*)(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info);
private:
BatchNormFunctionPtr _func;
@@ -148,4 +136,4 @@ private:
ActivationLayerInfo _act_info;
};
} // namespace arm_compute
-#endif /*ARM_COMPUTE_NEBATCHNORMALIZATIONLAYERKERNEL_H */
+#endif // ACL_SRC_CORE_NEON_KERNELS_NEBATCHNORMALIZATIONLAYERKERNEL_H