aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLMeanStdDev.cpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-06-15 16:15:26 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitd1794ebfa10d05af7d2458c5d506152fd38068d3 (patch)
treee3f286aaba86b1f0bcda3390ad4d8af96b965fc7 /src/runtime/CL/functions/CLMeanStdDev.cpp
parent7777b1aa865d3c17dcef31573d44fae421176109 (diff)
downloadComputeLibrary-d1794ebfa10d05af7d2458c5d506152fd38068d3.tar.gz
COMPMID-1226 Extend CLMeanStdDev to support FP32 / FP16
- Extend support for FP16 in CLReduction. - For F16/F32 MeanStdDev we perform one reduction operation for mean and one for stddev and we calculate the final result in the host CPU. Change-Id: Iad2099f26c0ba7969737d22f00c6c275634d875c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/135870 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLMeanStdDev.cpp')
-rw-r--r--src/runtime/CL/functions/CLMeanStdDev.cpp134
1 files changed, 124 insertions, 10 deletions
diff --git a/src/runtime/CL/functions/CLMeanStdDev.cpp b/src/runtime/CL/functions/CLMeanStdDev.cpp
index 838f7e73d2..157f306d0c 100644
--- a/src/runtime/CL/functions/CLMeanStdDev.cpp
+++ b/src/runtime/CL/functions/CLMeanStdDev.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,35 +21,149 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/runtime/CL/functions/CLMeanStdDev.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/functions/CLMeanStdDev.h"
using namespace arm_compute;
-CLMeanStdDev::CLMeanStdDev()
- : _mean_stddev_kernel(),
+CLMeanStdDev::CLMeanStdDev(std::shared_ptr<IMemoryManager> memory_manager) // NOLINT
+ : _memory_group(std::move(memory_manager)),
+ _data_type(),
+ _num_pixels(),
+ _run_stddev(),
+ _reduction_operation_mean(),
+ _reduction_operation_stddev(),
+ _reduction_output_mean(),
+ _reduction_output_stddev(),
+ _mean(nullptr),
+ _stddev(nullptr),
+ _mean_stddev_kernel(),
_fill_border_kernel(),
_global_sum(),
_global_sum_squared()
{
}
+Status CLMeanStdDev::validate(ITensorInfo *input, float *mean, float *stddev)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_TENSOR_NOT_2D(input);
+ if(is_data_type_float(input->data_type()))
+ {
+ ARM_COMPUTE_UNUSED(mean);
+ ARM_COMPUTE_UNUSED(stddev);
+
+ TensorShape output_shape = TensorShape{ 1, input->dimension(1) };
+ TensorInfo output_shape_info = TensorInfo(output_shape, 1, DataType::U8);
+ return CLReductionOperation::validate(input, &output_shape_info, 0, ReductionOperation::SUM);
+ }
+ else
+ {
+ return CLMeanStdDevKernel::validate(input, mean, nullptr, stddev, nullptr);
+ }
+}
+
void CLMeanStdDev::configure(ICLImage *input, float *mean, float *stddev)
{
- _global_sum = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+ // In the case of F16/F32 we call reduction operation for calculating CLMeanStdDev
+ _data_type = input->info()->data_type();
- if(stddev != nullptr)
+ if(is_data_type_float(_data_type))
{
- _global_sum_squared = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+ _num_pixels = input->info()->dimension(0) * input->info()->dimension(1);
+
+ _memory_group.manage(&_reduction_output_mean);
+ _reduction_operation_mean.configure(input, &_reduction_output_mean, 0, ReductionOperation::SUM);
+ _reduction_output_mean.allocator()->allocate();
+ _mean = mean;
+
+ if(stddev != nullptr)
+ {
+ _memory_group.manage(&_reduction_output_stddev);
+ _reduction_operation_stddev.configure(input, &_reduction_output_stddev, 0, ReductionOperation::SUM_SQUARE);
+ _reduction_output_stddev.allocator()->allocate();
+ _stddev = stddev;
+ _run_stddev = true;
+ }
}
+ else
+ {
+ _global_sum = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
- _mean_stddev_kernel.configure(input, mean, &_global_sum, stddev, &_global_sum_squared);
- _fill_border_kernel.configure(input, _mean_stddev_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast<uint8_t>(0)));
+ if(stddev != nullptr)
+ {
+ _global_sum_squared = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_ulong));
+ }
+
+ _mean_stddev_kernel.configure(input, mean, &_global_sum, stddev, &_global_sum_squared);
+ _fill_border_kernel.configure(input, _mean_stddev_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast<uint8_t>(0)));
+ }
}
-void CLMeanStdDev::run()
+template <typename T>
+void CLMeanStdDev::run_float()
+{
+ _memory_group.acquire();
+
+ // Perform reduction on x-axis
+ _reduction_operation_mean.run();
+ if(_run_stddev)
+ {
+ _reduction_operation_stddev.run();
+ _reduction_output_stddev.map(true);
+ }
+
+ _reduction_output_mean.map(true);
+
+ auto mean = static_cast<T>(0);
+
+ // Calculate final result for mean
+ for(unsigned int i = 0; i < _reduction_output_mean.info()->dimension(1); ++i)
+ {
+ mean += *reinterpret_cast<T *>(_reduction_output_mean.buffer() + _reduction_output_mean.info()->offset_element_in_bytes(Coordinates(0, i)));
+ }
+
+ mean /= _num_pixels;
+ *_mean = mean;
+
+ if(_run_stddev)
+ {
+ auto stddev = static_cast<T>(0);
+ // Calculate final result for stddev
+ for(unsigned int i = 0; i < _reduction_output_stddev.info()->dimension(1); ++i)
+ {
+ stddev += *reinterpret_cast<T *>(_reduction_output_stddev.buffer() + _reduction_output_stddev.info()->offset_element_in_bytes(Coordinates(0, i)));
+ }
+ *_stddev = std::sqrt((stddev / _num_pixels) - (mean * mean));
+
+ _reduction_output_stddev.unmap();
+ }
+ _reduction_output_mean.unmap();
+
+ _memory_group.release();
+}
+
+void CLMeanStdDev::run_int()
{
CLScheduler::get().enqueue(_fill_border_kernel);
CLScheduler::get().enqueue(_mean_stddev_kernel);
}
+
+void CLMeanStdDev::run()
+{
+ switch(_data_type)
+ {
+ case DataType::F16:
+ run_float<half>();
+ break;
+ case DataType::F32:
+ run_float<float>();
+ break;
+ case DataType::U8:
+ run_int();
+ break;
+ default:
+ ARM_COMPUTE_ERROR_ON("Not supported");
+ }
+}