aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp')
-rw-r--r--src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp23
1 files changed, 15 insertions, 8 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp
index 8a48523765..4a8e296619 100644
--- a/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp
@@ -14,23 +14,30 @@
namespace armnn
{
+RefBatchNormalizationUint8Workload::RefBatchNormalizationUint8Workload(
+ const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : Uint8Workload<BatchNormalizationQueueDescriptor>(descriptor, info),
+ m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))),
+ m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))),
+ m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))),
+ m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {}
void RefBatchNormalizationUint8Workload::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationUint8Workload_Execute");
const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorInfo& varInfo = GetTensorInfo(m_Data.m_Variance);
- const TensorInfo& meanInfo = GetTensorInfo(m_Data.m_Mean);
- const TensorInfo& gammaInfo = GetTensorInfo(m_Data.m_Gamma);
- const TensorInfo& betaInfo = GetTensorInfo(m_Data.m_Beta);
+ const TensorInfo& varInfo = GetTensorInfo(m_Variance.get());
+ const TensorInfo& meanInfo = GetTensorInfo(m_Mean.get());
+ const TensorInfo& gammaInfo = GetTensorInfo(m_Gamma.get());
+ const TensorInfo& betaInfo = GetTensorInfo(m_Beta.get());
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
auto input = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0);
- auto var = Dequantize(m_Data.m_Variance->GetConstTensor<uint8_t>(), varInfo);
- auto mean = Dequantize(m_Data.m_Mean->GetConstTensor<uint8_t>(), meanInfo);
- auto gamma = Dequantize(m_Data.m_Gamma->GetConstTensor<uint8_t>(), gammaInfo);
- auto beta = Dequantize(m_Data.m_Beta->GetConstTensor<uint8_t>(), betaInfo);
+ auto var = Dequantize(m_Variance->GetConstTensor<uint8_t>(), varInfo);
+ auto mean = Dequantize(m_Mean->GetConstTensor<uint8_t>(), meanInfo);
+ auto gamma = Dequantize(m_Gamma->GetConstTensor<uint8_t>(), gammaInfo);
+ auto beta = Dequantize(m_Beta->GetConstTensor<uint8_t>(), betaInfo);
std::vector<float> results(outputInfo.GetNumElements());
BatchNormImpl(m_Data, var.data(), mean.data(), gamma.data(), beta.data(), results.data(), input.data());