diff options
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp')
-rw-r--r-- | src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp index c421b0f212..fbc1f07111 100644 --- a/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp +++ b/src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp @@ -12,15 +12,22 @@ namespace armnn { +RefBatchNormalizationFloat32Workload::RefBatchNormalizationFloat32Workload( + const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) + : Float32Workload<BatchNormalizationQueueDescriptor>(descriptor, info), + m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))), + m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))), + m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))), + m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {} void RefBatchNormalizationFloat32Workload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationFloat32Workload_Execute"); - const float* var = m_Data.m_Variance->GetConstTensor<float>(); - const float* mean = m_Data.m_Mean->GetConstTensor<float>(); - const float* gamma = m_Data.m_Gamma->GetConstTensor<float>(); - const float* beta = m_Data.m_Beta->GetConstTensor<float>(); + const float* var = m_Variance->GetConstTensor<float>(); + const float* mean = m_Mean->GetConstTensor<float>(); + const float* gamma = m_Gamma->GetConstTensor<float>(); + const float* beta = m_Beta->GetConstTensor<float>(); auto inputData = GetInputTensorDataFloat(0, m_Data); auto outputData = GetOutputTensorDataFloat(0, m_Data); |