diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-03 16:54:25 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-06-04 15:13:51 +0000 |
commit | 3122bd574a3d29774c535ca2136de361da626e88 (patch) | |
tree | c2fcc19be67f5a35c30d042b80ba3157ef87bd21 /src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp | |
parent | 550fe36f687e73c78b57ebfeee9f98fd35f40f24 (diff) | |
download | armnn-3122bd574a3d29774c535ca2136de361da626e88.tar.gz |
IVGCVSW-3212 Refactor the Reference BatchNormalization workloads to
handle Float32 and QAsymm8 types
* Removed the type-specific workload implementations
* Added type-independent RefBatchNormalizationWorkload implementation
* Reworked BachNormImpl to use decoders/encoders
* Improved the validation of the BatchNorm queue descriptor
* Fixed unit tests where necessary
Change-Id: Icf3fa1332292d38ec2fa0b1cb984cab78426034b
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp | 47 |
1 files changed, 0 insertions, 47 deletions
diff --git a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp b/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp deleted file mode 100644 index e248ad4b9d..0000000000 --- a/src/backends/reference/workloads/RefBatchNormalizationUint8Workload.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefBatchNormalizationUint8Workload.hpp" - -#include "BatchNormImpl.hpp" -#include "RefWorkloadUtils.hpp" - -#include "Profiling.hpp" - -#include <vector> - -namespace armnn -{ -RefBatchNormalizationUint8Workload::RefBatchNormalizationUint8Workload( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : Uint8Workload<BatchNormalizationQueueDescriptor>(descriptor, info), - m_Mean(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Mean))), - m_Variance(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Variance))), - m_Beta(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Beta))), - m_Gamma(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Gamma))) {} - -void RefBatchNormalizationUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchNormalizationUint8Workload_Execute"); - - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& varInfo = GetTensorInfo(m_Variance.get()); - const TensorInfo& meanInfo = GetTensorInfo(m_Mean.get()); - const TensorInfo& gammaInfo = GetTensorInfo(m_Gamma.get()); - const TensorInfo& betaInfo = GetTensorInfo(m_Beta.get()); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - - auto input = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0); - auto var = Dequantize(m_Variance->GetConstTensor<uint8_t>(), varInfo); - auto mean = Dequantize(m_Mean->GetConstTensor<uint8_t>(), meanInfo); - auto gamma = Dequantize(m_Gamma->GetConstTensor<uint8_t>(), gammaInfo); - auto beta = Dequantize(m_Beta->GetConstTensor<uint8_t>(), betaInfo); - - std::vector<float> results(outputInfo.GetNumElements()); - BatchNormImpl(m_Data, var.data(), mean.data(), gamma.data(), beta.data(), results.data(), input.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); -} - -} //namespace armnn |