diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp')
-rw-r--r-- | src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp index c421b0f212..fbc1f07111 100644 --- a/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp +++ b/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp @@ -12,15 +12,22 @@ namespace armnn { +RefBatchNormalizationFloat32Workload::RefBatchNormalizationFloat32Workload( + const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) + : Float32Workload<BatchNormalizationQueueDescriptor>(descriptor, info), + m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))), + m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))), + m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))), + m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {} void RefBatchNormalizationFloat32Workload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationFloat32Workload_Execute"); - const float* var = m_Data.m_Variance->GetConstTensor<float>(); - const float* mean = m_Data.m_Mean->GetConstTensor<float>(); - const float* gamma = m_Data.m_Gamma->GetConstTensor<float>(); - const float* beta = m_Data.m_Beta->GetConstTensor<float>(); + const float* var = m_Variance->GetConstTensor<float>(); + const float* mean = m_Mean->GetConstTensor<float>(); + const float* gamma = m_Gamma->GetConstTensor<float>(); + const float* beta = m_Beta->GetConstTensor<float>(); auto inputData = GetInputTensorDataFloat(0, m_Data); auto outputData = GetOutputTensorDataFloat(0, m_Data); |