aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp')
-rw-r--r--src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp22
1 files changed, 16 insertions, 6 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp b/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp
index ddbd0f05c0..a45614a284 100644
--- a/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp
+++ b/src/armnn/backends/ClWorkloads/ClBatchNormalizationFloat32Workload.hpp
@@ -10,21 +10,31 @@
namespace armnn
{
-class ClBatchNormalizationFloat32Workload : public Float32Workload<BatchNormalizationQueueDescriptor>
+arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const TensorInfo& mean,
+ const TensorInfo& var,
+ const TensorInfo& beta,
+ const TensorInfo& gamma,
+ const BatchNormalizationDescriptor& desc);
+
+class ClBatchNormalizationFloat32Workload : public FloatWorkload<BatchNormalizationQueueDescriptor>
{
public:
ClBatchNormalizationFloat32Workload(const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info);
- using Float32Workload<BatchNormalizationQueueDescriptor>::Float32Workload;
+ using FloatWorkload<BatchNormalizationQueueDescriptor>::FloatWorkload;
void Execute() const override;
private:
mutable arm_compute::CLBatchNormalizationLayer m_Layer;
- arm_compute::CLTensor m_Mean;
- arm_compute::CLTensor m_Variance;
- arm_compute::CLTensor m_Gamma;
- arm_compute::CLTensor m_Beta;
+ std::unique_ptr<arm_compute::CLTensor> m_Mean;
+ std::unique_ptr<arm_compute::CLTensor> m_Variance;
+ std::unique_ptr<arm_compute::CLTensor> m_Gamma;
+ std::unique_ptr<arm_compute::CLTensor> m_Beta;
+
+ void FreeUnusedTensors();
};
} //namespace armnn