diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLMeanStdDev.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLMeanStdDev.h | 52 |
1 files changed, 44 insertions, 8 deletions
diff --git a/arm_compute/runtime/CL/functions/CLMeanStdDev.h b/arm_compute/runtime/CL/functions/CLMeanStdDev.h index 7622138236..2e46563423 100644 --- a/arm_compute/runtime/CL/functions/CLMeanStdDev.h +++ b/arm_compute/runtime/CL/functions/CLMeanStdDev.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,7 +27,10 @@ #include "arm_compute/core/CL/OpenCL.h" #include "arm_compute/core/CL/kernels/CLFillBorderKernel.h" #include "arm_compute/core/CL/kernels/CLMeanStdDevKernel.h" +#include "arm_compute/runtime/CL/CLMemoryGroup.h" +#include "arm_compute/runtime/CL/functions/CLReductionOperation.h" #include "arm_compute/runtime/IFunction.h" +#include "arm_compute/runtime/IMemoryManager.h" namespace arm_compute { @@ -36,23 +39,56 @@ class CLMeanStdDev : public IFunction { public: /** Default Constructor. */ - CLMeanStdDev(); + CLMeanStdDev(std::shared_ptr<IMemoryManager> memory_manager = nullptr); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLMeanStdDev(const CLMeanStdDev &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CLMeanStdDev &operator=(const CLMeanStdDev &) = delete; + /** Allow instances of this class to be moved */ + CLMeanStdDev(CLMeanStdDev &&) = default; + /** Allow instances of this class to be moved */ + CLMeanStdDev &operator=(CLMeanStdDev &&) = default; + /** Default destructor */ + ~CLMeanStdDev() = default; /** Initialise the kernel's inputs and outputs. * - * @param[in, out] input Input image. Data types supported: U8. (Written to only for border filling) + * @param[in, out] input Input image. Data types supported: U8/F16/F32. (Written to only for border filling) * @param[out] mean Output average pixel value. - * @param[out] stddev (Optional)Output standard deviation of pixel values. + * @param[out] stddev (Optional) Output standard deviation of pixel values. */ void configure(ICLImage *input, float *mean, float *stddev = nullptr); + /** Static function to check if given info will lead to a valid configuration of @ref CLMeanStdDev + * + * @param[in] input Input image. Data types supported: U8/F16/F32. + * @param[in] mean Output average pixel value. + * @param[in] stddev (Optional) Output standard deviation of pixel values. + * + * @return a status + */ + static Status validate(ITensorInfo *input, float *mean, float *stddev = nullptr); // Inherited methods overridden: void run() override; private: - CLMeanStdDevKernel _mean_stddev_kernel; /**< Kernel that standard deviation calculation. */ - CLFillBorderKernel _fill_border_kernel; /**< Kernel that fills the border with zeroes. */ - cl::Buffer _global_sum; /**< Variable that holds the global sum among calls in order to ease reduction */ - cl::Buffer _global_sum_squared; /**< Variable that holds the global sum of squared values among calls in order to ease reduction */ + template <typename T> + void run_float(); + void run_int(); + + CLMemoryGroup _memory_group; /**< Function's memory group */ + DataType _data_type; /**< Input data type. */ + unsigned int _num_pixels; /**< Number of image's pixels. */ + bool _run_stddev; /**< Flag for knowing if we should run stddev reduction function. */ + CLReductionOperation _reduction_operation_mean; /**< Reduction operation function for computing mean value. */ + CLReductionOperation _reduction_operation_stddev; /**< Reduction operation function for computing standard deviation. */ + CLTensor _reduction_output_mean; /**< Reduction operation output tensor for mean value. */ + CLTensor _reduction_output_stddev; /**< Reduction operation output tensor for standard deviation value. */ + float *_mean; /**< Pointer that holds the mean value. */ + float *_stddev; /**< Pointer that holds the standard deviation value. */ + CLMeanStdDevKernel _mean_stddev_kernel; /**< Kernel that standard deviation calculation. */ + CLFillBorderKernel _fill_border_kernel; /**< Kernel that fills the border with zeroes. */ + cl::Buffer _global_sum; /**< Variable that holds the global sum among calls in order to ease reduction */ + cl::Buffer _global_sum_squared; /**< Variable that holds the global sum of squared values among calls in order to ease reduction */ }; } #endif /*__ARM_COMPUTE_CLMEANSTDDEV_H__ */ |