diff options
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp')
-rw-r--r-- | src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp b/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp index ddbd0f05c0..a45614a284 100644 --- a/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp +++ b/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp @@ -10,21 +10,31 @@ namespace armnn { -class ClBatchNormalizationFloat32Workload : public Float32Workload<BatchNormalizationQueueDescriptor> +arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input, + const TensorInfo& output, + const TensorInfo& mean, + const TensorInfo& var, + const TensorInfo& beta, + const TensorInfo& gamma, + const BatchNormalizationDescriptor& desc); + +class ClBatchNormalizationFloat32Workload : public FloatWorkload<BatchNormalizationQueueDescriptor> { public: ClBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info); - using Float32Workload<BatchNormalizationQueueDescriptor>::Float32Workload; + using FloatWorkload<BatchNormalizationQueueDescriptor>::FloatWorkload; void Execute() const override; private: mutable arm_compute::CLBatchNormalizationLayer m_Layer; - arm_compute::CLTensor m_Mean; - arm_compute::CLTensor m_Variance; - arm_compute::CLTensor m_Gamma; - arm_compute::CLTensor m_Beta; + std::unique_ptr<arm_compute::CLTensor> m_Mean; + std::unique_ptr<arm_compute::CLTensor> m_Variance; + std::unique_ptr<arm_compute::CLTensor> m_Gamma; + std::unique_ptr<arm_compute::CLTensor> m_Beta; + + void FreeUnusedTensors(); }; } //namespace armnn |