Diffstat (limited to 'src/backends/gpuFsa/workloads/GpuFsaConstantWorkload.cpp')
-rw-r--r-- | src/backends/gpuFsa/workloads/GpuFsaConstantWorkload.cpp | 114 |
1 file changed, 114 insertions(+), 0 deletions(-)
diff --git a/src/backends/gpuFsa/workloads/GpuFsaConstantWorkload.cpp b/src/backends/gpuFsa/workloads/GpuFsaConstantWorkload.cpp
new file mode 100644
index 0000000000..39d3c0ddab
--- /dev/null
+++ b/src/backends/gpuFsa/workloads/GpuFsaConstantWorkload.cpp
@@ -0,0 +1,114 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "GpuFsaConstantWorkload.hpp"
+#include "GpuFsaWorkloadUtils.hpp"
+
+#include <Half.hpp>
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <gpuFsa/GpuFsaTensorHandle.hpp>
+#include <armnn/backends/TensorHandle.hpp>
+
+namespace armnn
+{
+
+arm_compute::Status GpuFsaConstantWorkloadValidate(const TensorInfo& output)
+{
+    const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+    std::array<arm_compute::DataType,8> supportedTypes = {
+            arm_compute::DataType::F16,
+            arm_compute::DataType::F32,
+            arm_compute::DataType::QASYMM8,
+            arm_compute::DataType::QASYMM8_SIGNED,
+            arm_compute::DataType::QSYMM16,
+            arm_compute::DataType::QSYMM8,
+            arm_compute::DataType::QSYMM8_PER_CHANNEL,
+            arm_compute::DataType::S32
+    };
+    auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
+
+    if (it != end(supportedTypes))
+    {
+        return arm_compute::Status{};
+    }
+    else
+    {
+        return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
+    }
+}
+
+GpuFsaConstantWorkload::GpuFsaConstantWorkload(const ConstantQueueDescriptor& descriptor,
+                                               const WorkloadInfo& info,
+                                               const arm_compute::CLCompileContext&)
+    : GpuFsaBaseWorkload<ConstantQueueDescriptor>(descriptor, info)
+    , m_RanOnce(false)
+{
+}
+
+void GpuFsaConstantWorkload::Execute() const
+{
+    // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
+    // on the first inference, then reused for subsequent inferences.
+    // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
+    // have been configured at the time.
+    if (!m_RanOnce)
+    {
+        const ConstantQueueDescriptor& data = this->m_Data;
+
+        ARMNN_ASSERT(data.m_LayerOutput != nullptr);
+        arm_compute::CLTensor& output = static_cast<GpuFsaTensorHandle*>(data.m_Outputs[0])->GetTensor();
+        arm_compute::DataType computeDataType = static_cast<GpuFsaTensorHandle*>(data.m_Outputs[0])->GetDataType();
+
+        switch (computeDataType)
+        {
+            case arm_compute::DataType::F16:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
+                break;
+            }
+            case arm_compute::DataType::F32:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
+                break;
+            }
+            case arm_compute::DataType::QASYMM8:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
+                break;
+            }
+            case arm_compute::DataType::QASYMM8_SIGNED:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
+                break;
+            }
+            case arm_compute::DataType::QSYMM16:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
+                break;
+            }
+            case arm_compute::DataType::QSYMM8:
+            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
+                break;
+            }
+            case arm_compute::DataType::S32:
+            {
+                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
+                break;
+            }
+            default:
+            {
+                ARMNN_ASSERT_MSG(false, "Unknown data type");
+                break;
+            }
+        }
+
+        m_RanOnce = true;
+    }
+}
+
+} //namespace armnn
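
For context, a minimal caller-side sketch (not part of the patch) of how the validate helper and workload added above might be exercised: the output TensorInfo is checked first, and the workload is only constructed and run if the data type is supported. The `descriptor`, `info`, and `compileContext` names are hypothetical stand-ins for objects the GpuFsa workload factory would normally supply.

    // Hypothetical usage sketch; descriptor, info and compileContext are assumed
    // to be provided by the surrounding backend code.
    armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32);

    arm_compute::Status status = armnn::GpuFsaConstantWorkloadValidate(outputInfo);
    if (bool(status))
    {
        auto workload = std::make_unique<armnn::GpuFsaConstantWorkload>(descriptor, info, compileContext);
        workload->Execute(); // copies the constant data into the CL tensor on the first run only
    }
    else
    {
        std::cerr << status.error_description() << std::endl;
    }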