aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2018-10-01 11:32:48 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commite2ec330b015fc34331f0023283b964ef97e1c5bb (patch)
tree76ce917f5d243ceb473dcdc2263e09781954b635 /src/backends/cl/workloads/ClMultiplicationWorkload.cpp
parentb30c53342756699dfaa6effd0a94aa9e69c2063a (diff)
downloadarmnn-e2ec330b015fc34331f0023283b964ef97e1c5bb.tar.gz
IVGCVSW-1207 - Remove typing from ClMultiplicationWorkload
Don't need this now as it uses the compute library validation function, and all of the code for the supported types is identical. Adds Uint8 support to Cl backend, and unit test cases. Change-Id: I35d4edacc1aca241e95d1b19ae525a23d9513c99
Diffstat (limited to 'src/backends/cl/workloads/ClMultiplicationWorkload.cpp')
-rw-r--r--src/backends/cl/workloads/ClMultiplicationWorkload.cpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/backends/cl/workloads/ClMultiplicationWorkload.cpp b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
new file mode 100644
index 0000000000..9d23caa695
--- /dev/null
+++ b/src/backends/cl/workloads/ClMultiplicationWorkload.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClMultiplicationWorkload.hpp"
+#include <backends/cl/ClTensorHandle.hpp>
+#include <backends/CpuTensorHandle.hpp>
+#include "ClWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+arm_compute::Status ClMultiplicationWorkloadValidate(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output)
+{
+ const arm_compute::TensorInfo aclInput1 = armcomputetensorutils::BuildArmComputeTensorInfo(input0);
+ const arm_compute::TensorInfo aclInput2 = armcomputetensorutils::BuildArmComputeTensorInfo(input1);
+ const arm_compute::TensorInfo aclOutput = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ // At the time of writing, configure() will fail if a rounding policy other than TO_ZERO is supplied to it,
+ // when providing a scale of 1.0 for F32 tensors, even though the provided rounding policy appears to be
+ // ignored for F32 tensors.
+ return arm_compute::CLPixelWiseMultiplication::validate(&aclInput1,
+ &aclInput2,
+ &aclOutput,
+ 1.0f,
+ arm_compute::ConvertPolicy::SATURATE,
+ arm_compute::RoundingPolicy::TO_ZERO);
+}
+
+
+ClMultiplicationWorkload::ClMultiplicationWorkload(const MultiplicationQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<MultiplicationQueueDescriptor>(descriptor, info)
+{
+ m_Data.ValidateInputsOutputs("ClMultiplicationWorkload", 2, 1);
+
+ arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ // Construct
+ m_PixelWiseMultiplication.configure(&input0,
+ &input1,
+ &output,
+ 1.0f,
+ arm_compute::ConvertPolicy::SATURATE,
+ arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
+}
+
+void ClMultiplicationWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClMultiplicationWorkload_Execute");
+
+ // Executes the layer.
+ m_PixelWiseMultiplication.run();
+}
+
+} //namespace armnn