aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r--src/backends/cl/workloads/ClSoftmaxBaseWorkload.cpp7
-rw-r--r--src/backends/cl/workloads/ClSoftmaxBaseWorkload.hpp4
-rw-r--r--src/backends/cl/workloads/ClSoftmaxFloatWorkload.cpp2
3 files changed, 9 insertions, 4 deletions
diff --git a/src/backends/cl/workloads/ClSoftmaxBaseWorkload.cpp b/src/backends/cl/workloads/ClSoftmaxBaseWorkload.cpp
index b1dc404a6f..2f6d380f94 100644
--- a/src/backends/cl/workloads/ClSoftmaxBaseWorkload.cpp
+++ b/src/backends/cl/workloads/ClSoftmaxBaseWorkload.cpp
@@ -6,6 +6,7 @@
#include "ClSoftmaxBaseWorkload.hpp"
#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <aclCommon/ArmComputeUtils.hpp>
#include <arm_compute/runtime/CL/functions/CLSoftmaxLayer.h>
@@ -13,12 +14,14 @@ namespace armnn
{
arm_compute::Status ClSoftmaxWorkloadValidate(const TensorInfo& input,
- const TensorInfo& output)
+ const TensorInfo& output,
+ const SoftmaxDescriptor& descriptor)
{
const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
- return arm_compute::CLSoftmaxLayer::validate(&aclInputInfo, &aclOutputInfo);
+ unsigned int aclAxis = ComputeSoftmaxAclAxis(input);
+ return arm_compute::CLSoftmaxLayer::validate(&aclInputInfo, &aclOutputInfo, descriptor.m_Beta, aclAxis);
}
}
diff --git a/src/backends/cl/workloads/ClSoftmaxBaseWorkload.hpp b/src/backends/cl/workloads/ClSoftmaxBaseWorkload.hpp
index b800056cdf..8d73060162 100644
--- a/src/backends/cl/workloads/ClSoftmaxBaseWorkload.hpp
+++ b/src/backends/cl/workloads/ClSoftmaxBaseWorkload.hpp
@@ -5,6 +5,7 @@
#pragma once
+#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <arm_compute/core/Error.h>
@@ -12,6 +13,7 @@ namespace armnn
{
arm_compute::Status ClSoftmaxWorkloadValidate(const TensorInfo& input,
- const TensorInfo& output);
+ const TensorInfo& output,
+ const SoftmaxDescriptor& descriptor);
} // namespace armnn
diff --git a/src/backends/cl/workloads/ClSoftmaxFloatWorkload.cpp b/src/backends/cl/workloads/ClSoftmaxFloatWorkload.cpp
index c78ab039ef..f2f8d17901 100644
--- a/src/backends/cl/workloads/ClSoftmaxFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClSoftmaxFloatWorkload.cpp
@@ -14,7 +14,7 @@ namespace armnn
{
ClSoftmaxFloatWorkload::ClSoftmaxFloatWorkload(const SoftmaxQueueDescriptor& descriptor, const WorkloadInfo& info,
- std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
+ std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
: FloatWorkload<SoftmaxQueueDescriptor>(descriptor, info)
, m_SoftmaxLayer(memoryManager)
{