diff options
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp')
-rw-r--r-- | src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp | 23 |
1 files changed, 15 insertions, 8 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp index 8a48523765..4a8e296619 100644 --- a/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp +++ b/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp @@ -14,23 +14,30 @@ namespace armnn { +RefBatchNormalizationUint8Workload::RefBatchNormalizationUint8Workload( + const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) + : Uint8Workload<BatchNormalizationQueueDescriptor>(descriptor, info), + m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))), + m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))), + m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))), + m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {} void RefBatchNormalizationUint8Workload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationUint8Workload_Execute"); const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& varInfo = GetTensorInfo(m_Data.m_Variance); - const TensorInfo& meanInfo = GetTensorInfo(m_Data.m_Mean); - const TensorInfo& gammaInfo = GetTensorInfo(m_Data.m_Gamma); - const TensorInfo& betaInfo = GetTensorInfo(m_Data.m_Beta); + const TensorInfo& varInfo = GetTensorInfo(m_Variance.get()); + const TensorInfo& meanInfo = GetTensorInfo(m_Mean.get()); + const TensorInfo& gammaInfo = GetTensorInfo(m_Gamma.get()); + const TensorInfo& betaInfo = GetTensorInfo(m_Beta.get()); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); auto input = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0); - auto var = Dequantize(m_Data.m_Variance->GetConstTensor<uint8_t>(), varInfo); - auto mean = Dequantize(m_Data.m_Mean->GetConstTensor<uint8_t>(), meanInfo); - auto gamma = Dequantize(m_Data.m_Gamma->GetConstTensor<uint8_t>(), gammaInfo); - auto beta = Dequantize(m_Data.m_Beta->GetConstTensor<uint8_t>(), betaInfo); + auto var = Dequantize(m_Variance->GetConstTensor<uint8_t>(), varInfo); + auto mean = Dequantize(m_Mean->GetConstTensor<uint8_t>(), meanInfo); + auto gamma = Dequantize(m_Gamma->GetConstTensor<uint8_t>(), gammaInfo); + auto beta = Dequantize(m_Beta->GetConstTensor<uint8_t>(), betaInfo); std::vector<float> results(outputInfo.GetNumElements()); BatchNormImpl(m_Data, var.data(), mean.data(), gamma.data(), beta.data(), results.data(), input.data()); |