diff options
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.cpp')
-rw-r--r-- | src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.cpp | 74 |
1 files changed, 63 insertions, 11 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.cpp b/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.cpp index dabd495d59..1849c5d411 100644 --- a/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.cpp +++ b/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.cpp @@ -7,36 +7,88 @@ #include "backends/ClTensorHandle.hpp" #include "backends/CpuTensorHandle.hpp" #include "backends/ArmComputeTensorUtils.hpp" +#include "backends/ClLayerSupport.hpp" namespace armnn { using namespace armcomputetensorutils; +arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input, + const TensorInfo& output, + const TensorInfo& mean, + const TensorInfo& var, + const TensorInfo& beta, + const TensorInfo& gamma, + const BatchNormalizationDescriptor &desc) +{ + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + const arm_compute::TensorInfo aclMeanInfo = BuildArmComputeTensorInfo(mean); + const arm_compute::TensorInfo aclVarInfo = BuildArmComputeTensorInfo(var); + const arm_compute::TensorInfo aclBetaInfo = BuildArmComputeTensorInfo(beta); + const arm_compute::TensorInfo aclGammaInfo = BuildArmComputeTensorInfo(gamma); + + return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo, + &aclOutputInfo, + &aclMeanInfo, + &aclVarInfo, + &aclBetaInfo, + &aclGammaInfo, + desc.m_Eps); +} + ClBatchNormalizationFloat32Workload::ClBatchNormalizationFloat32Workload( const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Float32Workload<BatchNormalizationQueueDescriptor>(descriptor, info) + : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info) { - BuildArmComputeTensor(m_Mean, m_Data.m_Mean->GetTensorInfo()); - BuildArmComputeTensor(m_Variance, m_Data.m_Variance->GetTensorInfo()); - BuildArmComputeTensor(m_Gamma, m_Data.m_Gamma->GetTensorInfo()); - BuildArmComputeTensor(m_Beta, m_Data.m_Beta->GetTensorInfo()); + m_Mean = std::make_unique<arm_compute::CLTensor>(); + BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo()); + + m_Variance = std::make_unique<arm_compute::CLTensor>(); + BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo()); + + m_Gamma = std::make_unique<arm_compute::CLTensor>(); + BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo()); + + m_Beta = std::make_unique<arm_compute::CLTensor>(); + BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo()); m_Data.ValidateInputsOutputs("ClBatchNormalizationFloat32Workload", 1, 1); arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); - m_Layer.configure(&input, &output, &m_Mean, &m_Variance, &m_Beta, &m_Gamma, m_Data.m_Parameters.m_Eps); - InitialiseArmComputeClTensorData(m_Mean, m_Data.m_Mean->GetConstTensor<float>()); - InitialiseArmComputeClTensorData(m_Variance, m_Data.m_Variance->GetConstTensor<float>()); - InitialiseArmComputeClTensorData(m_Beta, m_Data.m_Beta->GetConstTensor<float>()); - InitialiseArmComputeClTensorData(m_Gamma, m_Data.m_Gamma->GetConstTensor<float>()); + m_Layer.configure(&input, + &output, + m_Mean.get(), + m_Variance.get(), + m_Beta.get(), + m_Gamma.get(), + m_Data.m_Parameters.m_Eps); + + InitializeArmComputeClTensorDataForFloatTypes(*m_Mean, m_Data.m_Mean); + InitializeArmComputeClTensorDataForFloatTypes(*m_Variance, m_Data.m_Variance); + InitializeArmComputeClTensorDataForFloatTypes(*m_Beta, m_Data.m_Beta); + InitializeArmComputeClTensorDataForFloatTypes(*m_Gamma, m_Data.m_Gamma); + + // Force Compute Library to perform the necessary copying and reshaping, after which + // delete all the input tensors that will no longer be needed + m_Layer.prepare(); + FreeUnusedTensors(); } void ClBatchNormalizationFloat32Workload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT(Compute::GpuAcc, "ClBatchNormalizationFloat32Workload_Execute"); + ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloat32Workload_Execute"); m_Layer.run(); } +void ClBatchNormalizationFloat32Workload::FreeUnusedTensors() +{ + FreeTensorIfUnused(m_Mean); + FreeTensorIfUnused(m_Variance); + FreeTensorIfUnused(m_Gamma); + FreeTensorIfUnused(m_Beta); +} + } //namespace armnn
\ No newline at end of file |