From e662a940d3378cfe669ff7e259a6911713fc0df9 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Mon, 14 Oct 2019 15:12:00 +0100 Subject: IVGCVSW-3975 Add reference workload for LOG_SOFTMAX Signed-off-by: Aron Virginas-Tar Change-Id: I10bb7133e0e2d6d7199abdf39562b1226bbbd3e7 --- src/backends/backendsCommon/common.mk | 1 + src/backends/backendsCommon/test/CMakeLists.txt | 2 + src/backends/backendsCommon/test/LayerTests.hpp | 1 + .../test/layerTests/LogSoftmaxTestImpl.cpp | 251 ++++++++++++ .../test/layerTests/LogSoftmaxTestImpl.hpp | 33 ++ src/backends/reference/RefLayerSupport.cpp | 32 +- src/backends/reference/RefLayerSupport.hpp | 5 + src/backends/reference/RefWorkloadFactory.cpp | 430 +++++++++++---------- src/backends/reference/RefWorkloadFactory.hpp | 187 ++++----- src/backends/reference/backend.mk | 2 + src/backends/reference/test/RefLayerTests.cpp | 11 + src/backends/reference/workloads/CMakeLists.txt | 4 + src/backends/reference/workloads/LogSoftmax.cpp | 91 +++++ src/backends/reference/workloads/LogSoftmax.hpp | 20 + .../reference/workloads/RefLogSoftmaxWorkload.cpp | 36 ++ .../reference/workloads/RefLogSoftmaxWorkload.hpp | 21 + src/backends/reference/workloads/RefWorkloads.hpp | 1 + 17 files changed, 823 insertions(+), 305 deletions(-) create mode 100644 src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.cpp create mode 100644 src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.hpp create mode 100644 src/backends/reference/workloads/LogSoftmax.cpp create mode 100644 src/backends/reference/workloads/LogSoftmax.hpp create mode 100644 src/backends/reference/workloads/RefLogSoftmaxWorkload.cpp create mode 100644 src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk index 3da2259966..754a3a096c 100644 --- a/src/backends/backendsCommon/common.mk +++ b/src/backends/backendsCommon/common.mk @@ -57,6 +57,7 @@ COMMON_TEST_SOURCES := \ test/layerTests/GreaterTestImpl.cpp \ test/layerTests/InstanceNormalizationTestImpl.cpp \ test/layerTests/L2NormalizationTestImpl.cpp \ + test/layerTests/LogSoftmaxTestImpl.cpp \ test/layerTests/LstmTestImpl.cpp \ test/layerTests/MaximumTestImpl.cpp \ test/layerTests/MinimumTestImpl.cpp \ diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt index f7d58bf3d5..d353a77d15 100644 --- a/src/backends/backendsCommon/test/CMakeLists.txt +++ b/src/backends/backendsCommon/test/CMakeLists.txt @@ -91,6 +91,8 @@ list(APPEND armnnBackendsCommonUnitTests_sources layerTests/L2NormalizationTestImpl.cpp layerTests/L2NormalizationTestImpl.hpp layerTests/LayerTestResult.hpp + layerTests/LogSoftmaxTestImpl.cpp + layerTests/LogSoftmaxTestImpl.hpp layerTests/LstmTestImpl.cpp layerTests/LstmTestImpl.hpp layerTests/MaximumTestImpl.cpp diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp index 239d0d5e79..eb413140da 100644 --- a/src/backends/backendsCommon/test/LayerTests.hpp +++ b/src/backends/backendsCommon/test/LayerTests.hpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.cpp new file mode 100644 index 0000000000..0b73d37305 --- /dev/null +++ b/src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.cpp @@ -0,0 +1,251 @@ +// +// Copyright © 2019 Arm Ltd. 
All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "LogSoftmaxTestImpl.hpp" + +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include + +#include + +namespace +{ + +template> +LayerTestResult LogSoftmaxTestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::TensorInfo& inputInfo, + const armnn::TensorInfo& outputInfo, + const std::vector& inputValues, + const std::vector& expectedOutputValues, + armnn::LogSoftmaxQueueDescriptor descriptor, + float qScale = 1.0f, + int32_t qOffset = 0) +{ + LayerTestResult result(outputInfo); + result.outputExpected = + MakeTensor(outputInfo, QuantizedVector(qScale, qOffset, expectedOutputValues)); + + std::unique_ptr inputHandle = workloadFactory.CreateTensorHandle(inputInfo); + std::unique_ptr outputHandle = workloadFactory.CreateTensorHandle(outputInfo); + + armnn::WorkloadInfo info; + + AddInputToWorkload(descriptor, info, inputInfo, inputHandle.get()); + AddOutputToWorkload(descriptor, info, outputInfo, outputHandle.get()); + + std::unique_ptr workload = workloadFactory.CreateLogSoftmax(descriptor, info); + + inputHandle->Allocate(); + outputHandle->Allocate(); + + auto inputTensor = MakeTensor(inputInfo, QuantizedVector(qScale, qOffset, inputValues)); + CopyDataToITensorHandle(inputHandle.get(), inputTensor.origin()); + + workload->Execute(); + + CopyDataFromITensorHandle(result.output.origin(), outputHandle.get()); + + return result; +} + +} // anonymous namespace + +template +LayerTestResult LogSoftmaxTest1( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + const armnn::TensorShape inputOutputShape{1, 1, 2, 4}; + + armnn::TensorInfo inputTensorInfo(inputOutputShape, ArmnnType); + armnn::TensorInfo outputTensorInfo(inputOutputShape, ArmnnType); + + std::vector inputValues + { + 0.f, -6.f, 2.f, 4.f, + 3.f, -2.f, 10.f, 1.f + }; + + std::vector expectedOutputValues + { + -4.14297f, -10.14297f, -2.14297f, -0.14297f, + -7.00104f, -12.00104f, -0.00105f, -9.00104f + }; + + armnn::LogSoftmaxQueueDescriptor descriptor; + descriptor.m_Parameters.m_Beta = 1.0f; // default beta + descriptor.m_Parameters.m_Axis = -1; // default axis + + return LogSoftmaxTestImpl( + workloadFactory, + memoryManager, + inputTensorInfo, + outputTensorInfo, + inputValues, + expectedOutputValues, + descriptor); +} + +template +LayerTestResult LogSoftmaxTest2( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + const armnn::TensorShape inputOutputShape{1, 1, 2, 4}; + + armnn::TensorInfo inputTensorInfo(inputOutputShape, ArmnnType); + armnn::TensorInfo outputTensorInfo(inputOutputShape, ArmnnType); + + std::vector inputValues + { + 0.f, -6.f, 2.f, 4.f, + 3.f, -2.f, 10.f, 1.f + }; + + std::vector expectedOutputValues + { + -4.14297f, -10.14297f, -2.14297f, -0.14297f, + -7.00104f, -12.00104f, -0.00105f, -9.00104f + }; + + armnn::LogSoftmaxQueueDescriptor descriptor; + descriptor.m_Parameters.m_Beta = 1.0f; // default beta + descriptor.m_Parameters.m_Axis = 3; // positive axis + + return LogSoftmaxTestImpl( + workloadFactory, + memoryManager, + inputTensorInfo, + outputTensorInfo, + inputValues, + expectedOutputValues, + descriptor); +} + +template +LayerTestResult LogSoftmaxTest3( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + const 
armnn::TensorShape inputOutputShape{1, 1, 2, 4}; + + armnn::TensorInfo inputTensorInfo(inputOutputShape, ArmnnType); + armnn::TensorInfo outputTensorInfo(inputOutputShape, ArmnnType); + + std::vector inputValues + { + 0.0f, -0.6f, 0.2f, 0.4f, + 0.3f, -0.2f, 1.0f, 0.1f + }; + + std::vector expectedOutputValues + { + -4.14297f, -10.14297f, -2.14297f, -0.14297f, + -7.00104f, -12.00104f, -0.00105f, -9.00104f + }; + + armnn::LogSoftmaxQueueDescriptor descriptor; + descriptor.m_Parameters.m_Beta = 10.0f; // non-default beta + descriptor.m_Parameters.m_Axis = 3; // positive axis + + return LogSoftmaxTestImpl( + workloadFactory, + memoryManager, + inputTensorInfo, + outputTensorInfo, + inputValues, + expectedOutputValues, + descriptor); +} + +template +LayerTestResult LogSoftmaxTest4( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) +{ + const armnn::TensorShape inputOutputShape{1, 1, 2, 4}; + + armnn::TensorInfo inputTensorInfo(inputOutputShape, ArmnnType); + armnn::TensorInfo outputTensorInfo(inputOutputShape, ArmnnType); + + std::vector inputValues + { + 0.f, -6.f, 2.f, 4.f, + 3.f, -2.f, 10.f, 1.f + }; + + std::vector expectedOutputValues + { + -3.048587f, -4.018149f, -8.000336f, -0.048587f, + -0.048587f, -0.018149f, -0.000335f, -3.048587f + }; + + armnn::LogSoftmaxQueueDescriptor descriptor; + descriptor.m_Parameters.m_Beta = 1.0f; // default beta + descriptor.m_Parameters.m_Axis = -2; // negative axis + + return LogSoftmaxTestImpl( + workloadFactory, + memoryManager, + inputTensorInfo, + outputTensorInfo, + inputValues, + expectedOutputValues, + descriptor); +} + +template LayerTestResult, 4> +LogSoftmaxTest1( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest2( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest3( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest4( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest1( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest2( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest3( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template LayerTestResult, 4> +LogSoftmaxTest4( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); diff --git a/src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.hpp new file mode 100644 index 0000000000..18a14ccd11 --- /dev/null +++ b/src/backends/backendsCommon/test/layerTests/LogSoftmaxTestImpl.hpp @@ -0,0 +1,33 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "LayerTestResult.hpp" + +#include + +#include +#include + +template> +LayerTestResult LogSoftmaxTest1( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template> +LayerTestResult LogSoftmaxTest2( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template> +LayerTestResult LogSoftmaxTest3( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); + +template> +LayerTestResult LogSoftmaxTest4( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager); diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 0d6b16cdf8..9342b29f47 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -897,6 +897,32 @@ bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported) const +{ + ignore_unused(descriptor); + + std::array supportedTypes = + { + DataType::Float32, + DataType::Float16 + }; + + bool supported = true; + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference LogSoftmax: input type not supported"); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference LogSoftmax: output type not supported"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference LogSoftmax: input and output types do not match"); + + return supported; +} + bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, @@ -1499,13 +1525,13 @@ bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input, }; supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, - "Reference concatenation: output type not supported"); + "Reference Softmax: output type not supported"); supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, - "Reference concatenation: input type not supported"); + "Reference Softmax: input type not supported"); supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, - "Reference concatenation: input type not supported"); + "Reference Softmax: input type not supported"); return supported; } diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 36080f7da4..5c71e8d337 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -149,6 +149,11 @@ public: const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported) const override; + bool IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 8c082749a4..1f6d1d7e8b 100644 --- 
a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -23,10 +23,9 @@ static const BackendId s_Id{RefBackendId()}; } template std::unique_ptr RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, - const WorkloadInfo& info) const + const WorkloadInfo& info) const { - return armnn::MakeWorkloadHelper(descriptor, - info); + return MakeWorkloadHelper(descriptor, info); } template @@ -95,285 +94,277 @@ std::unique_ptr RefWorkloadFactory::CreateTensorHandle(const Tens return std::make_unique(tensorInfo, m_MemoryManager); } -std::unique_ptr RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - if (info.m_InputTensorInfos.empty() ) - { - throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length"); - } - if (info.m_OutputTensorInfos.empty()) - { - throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length"); - } - - if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes()) - { - throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count."); - } - - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - if (info.m_InputTensorInfos.empty() ) - { - throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length"); - } - if (info.m_OutputTensorInfos.empty()) - { - throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length"); - } - if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes()) - { - throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count."); - } - - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor, - const WorkloadInfo& info) const + const WorkloadInfo& info) const { return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateBatchNormalization( + const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, 
info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return CreateConcat(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateFullyConnected( - const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateConvertFp16ToFp32( + const ConvertFp16ToFp32QueueDescriptor& descriptor, + const WorkloadInfo& info) const { - if (IsQSymm16(info)) - { - return std::make_unique(descriptor, info); - } - return MakeWorkloadHelper(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateConvertFp32ToFp16( + const ConvertFp32ToFp16QueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateConvolution2d( - const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const { return std::make_unique(descriptor, info); } +std::unique_ptr RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + if (IsQSymm16(info)) + { + return std::make_unique(descriptor, info); + } + return MakeWorkload(descriptor, info); +} + std::unique_ptr RefWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor, const WorkloadInfo& info) const { return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateDepthwiseConvolution2d( - const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateDepthwiseConvolution2d( + const DepthwiseConvolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const { return std::make_unique(descriptor, info); } +std::unique_ptr RefWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + std::unique_ptr RefWorkloadFactory::CreateDetectionPostProcess( - const armnn::DetectionPostProcessQueueDescriptor& descriptor, const armnn::WorkloadInfo& info) const + const DetectionPostProcessQueueDescriptor& descriptor, + const WorkloadInfo& info) const { return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateNormalization( - const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); 
} -std::unique_ptr RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateMultiplication( - const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateFakeQuantization( + const FakeQuantizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return MakeWorkload(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateBatchNormalization( - const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateFullyConnected( + const FullyConnectedQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - if (descriptor.m_Inputs.empty()) - { - throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor."); - } - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - if (descriptor.m_Inputs.empty()) - { - throw InvalidArgumentException("RefWorkloadFactory: CreateMemImport() expected an input tensor."); - } - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - ResizeQueueDescriptor resizeDescriptor; - resizeDescriptor.m_Parameters.m_Method = ResizeMethod::Bilinear; - resizeDescriptor.m_Parameters.m_DataLayout = descriptor.m_Parameters.m_DataLayout; - resizeDescriptor.m_Parameters.m_TargetWidth = descriptor.m_Parameters.m_TargetWidth; - resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight; + if (info.m_InputTensorInfos.empty() ) + { + throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length"); + } + if (info.m_OutputTensorInfos.empty()) + { + throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length"); + } - return CreateResize(resizeDescriptor, info); + if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes()) + { + throw 
InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count."); + } + + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateFakeQuantization( - const FakeQuantizationQueueDescriptor& descriptor, +std::unique_ptr RefWorkloadFactory::CreateInstanceNormalization( + const InstanceNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const + const WorkloadInfo& info) const { return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return std::make_unique(descriptor, info); -} - -std::unique_ptr RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor, - const WorkloadInfo& info) const -{ - return std::make_unique(descriptor, info); -} - -std::unique_ptr RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateSpaceToDepth(const armnn::SpaceToDepthQueueDescriptor& descriptor, - const armnn::WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor, +std::unique_ptr RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (descriptor.m_Inputs.empty()) + { + throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor."); + } + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateConvertFp16ToFp32( - const ConvertFp16ToFp32QueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (descriptor.m_Inputs.empty()) + { + throw InvalidArgumentException("RefWorkloadFactory: CreateMemImport() expected an input tensor."); + } + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateConvertFp32ToFp16( - 
const ConvertFp32ToFp16QueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return CreateConcat(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateDivision( - const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateSubtraction( - const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateMaximum( - const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateMean( - const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); -} + if (info.m_InputTensorInfos.empty() ) + { + throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length"); + } + if (info.m_OutputTensorInfos.empty()) + { + throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length"); + } + if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes()) + { + throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count."); + } -std::unique_ptr RefWorkloadFactory::CreateMinimum( - const MinimumQueueDescriptor& descriptor, const WorkloadInfo& info) const -{ - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor, - const WorkloadInfo& info) const + const WorkloadInfo& info) const { if (IsQSymm16(info)) { @@ -386,81 +377,99 @@ std::unique_ptr RefWorkloadFactory::CreatePad(const PadQueueDescripto return MakeWorkload(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + if (IsQSymm16(info)) + { + return std::make_unique(descriptor, info); + } + return MakeWorkloadHelper(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr 
RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return nullptr; } -std::unique_ptr RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor, +std::unique_ptr RefWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + +std::unique_ptr RefWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + +std::unique_ptr RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - if (IsQSymm16(info)) - { - return std::make_unique(descriptor, info); - } - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + ResizeQueueDescriptor resizeDescriptor; + resizeDescriptor.m_Parameters.m_Method = ResizeMethod::Bilinear; + resizeDescriptor.m_Parameters.m_DataLayout = descriptor.m_Parameters.m_DataLayout; + resizeDescriptor.m_Parameters.m_TargetWidth = descriptor.m_Parameters.m_TargetWidth; + resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight; + + return CreateResize(resizeDescriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor, - const armnn::WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return nullptr; + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr 
RefWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateTransposeConvolution2d( - const TransposeConvolution2dQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor, @@ -469,22 +478,23 @@ std::unique_ptr RefWorkloadFactory::CreateStack(const StackQueueDescr return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor, - const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } -std::unique_ptr RefWorkloadFactory::CreateInstanceNormalization( - const InstanceNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const +std::unique_ptr RefWorkloadFactory::CreateTransposeConvolution2d( + const TransposeConvolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + return std::make_unique(descriptor, info); } } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 0a1fab127c..41e9b28ea2 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -60,174 +60,177 @@ public: DataLayout dataLayout, const bool IsMemoryManaged = true) const override; - std::unique_ptr CreateInput(const InputQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; - - std::unique_ptr CreateOutput(const OutputQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateAbs(const AbsQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; std::unique_ptr CreateActivation(const ActivationQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateAddition(const AdditionQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateSplitter(const SplitterQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateBatchToSpaceNd(const 
BatchToSpaceNdQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - ARMNN_DEPRECATED_MSG("Use CreateConcat instead") - std::unique_ptr CreateMerger(const MergerQueueDescriptor& descriptor, + std::unique_ptr CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateConstant(const ConstantQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreatePermute(const PermuteQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor, + const WorkloadInfo& info) const override; std::unique_ptr CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateDebug(const DebugQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor, const WorkloadInfo& info) const override; std::unique_ptr CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateDequantize(const DequantizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateNormalization(const NormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; - - std::unique_ptr CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; - - std::unique_ptr CreateAddition(const AdditionQueueDescriptor& descriptor, + std::unique_ptr CreateDivision(const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateEqual(const EqualQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateMemCopy(const MemCopyQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateMemImport(const MemImportQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateFloor(const FloorQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateResize(const ResizeQueueDescriptor& descriptor, + std::unique_ptr CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + + std::unique_ptr CreateGather(const GatherQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr 
CreateGreater(const GreaterQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateInput(const InputQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + + std::unique_ptr CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; std::unique_ptr CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateConcat(const ConcatQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateConstant(const ConstantQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateLstm(const LstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateReshape(const ReshapeQueueDescriptor& descriptor, + std::unique_ptr CreateMaximum(const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateMean(const MeanQueueDescriptor& descriptor, + const WorkloadInfo& Info) const override; - std::unique_ptr CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateMemCopy(const MemCopyQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateFloor(const FloorQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateLstm(const LstmQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateConcat instead") + std::unique_ptr CreateMerger(const MergerQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateMinimum(const MinimumQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateDivision(const DivisionQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateNormalization(const NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateSubtraction(const SubtractionQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateOutput(const OutputQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateMaximum(const MaximumQueueDescriptor& descriptor, + std::unique_ptr CreatePad(const PadQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + + std::unique_ptr CreatePermute(const 
PermuteQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateMean(const MeanQueueDescriptor& descriptor, - const WorkloadInfo& Info) const override; + std::unique_ptr CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreatePad(const PadQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateEqual(const EqualQueueDescriptor& descriptor, + std::unique_ptr CreatePrelu(const PreluQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; - - std::unique_ptr CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateQuantize(const QuantizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateMinimum(const MinimumQueueDescriptor& descriptor, + std::unique_ptr CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateGreater(const GreaterQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateResize(const ResizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateDebug(const DebugQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateResize instead") + std::unique_ptr CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; std::unique_ptr CreateRsqrt(const RsqrtQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; - - std::unique_ptr CreateGather(const GatherQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateSlice(const SliceQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateDequantize(const DequantizeQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateQuantize(const QuantizeQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreatePrelu(const PreluQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateSplitter(const SplitterQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; std::unique_ptr CreateStack(const StackQueueDescriptor& descriptor, const WorkloadInfo& info) const override; - std::unique_ptr CreateAbs(const AbsQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr 
CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateSlice(const SliceQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateSubtraction(const SubtractionQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; - std::unique_ptr CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor, - const WorkloadInfo& info) const override; + std::unique_ptr CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; private: - template std::unique_ptr MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index f45b01549a..49b07a41d2 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -35,6 +35,7 @@ BACKEND_SOURCES := \ workloads/FullyConnected.cpp \ workloads/Gather.cpp \ workloads/InstanceNorm.cpp \ + workloads/LogSoftmax.cpp \ workloads/LstmUtils.cpp \ workloads/Mean.cpp \ workloads/Concatenate.cpp \ @@ -63,6 +64,7 @@ BACKEND_SOURCES := \ workloads/RefGatherWorkload.cpp \ workloads/RefInstanceNormalizationWorkload.cpp \ workloads/RefL2NormalizationWorkload.cpp \ + workloads/RefLogSoftmaxWorkload.cpp \ workloads/RefLstmWorkload.cpp \ workloads/RefMeanWorkload.cpp \ workloads/RefNormalizationWorkload.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index cef3a800ac..5de9b752ca 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -868,6 +868,17 @@ ARMNN_AUTO_TEST_CASE(L2Normalization2dShape, L2Normalization2dShapeTest); ARMNN_AUTO_TEST_CASE(L2NormalizationDefaultEpsilon, L2NormalizationDefaultEpsilonTest, DataLayout::NCHW) ARMNN_AUTO_TEST_CASE(L2NormalizationNonDefaultEpsilon, L2NormalizationNonDefaultEpsilonTest, DataLayout::NCHW) +// LogSoftmax +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat32_1, LogSoftmaxTest1) +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat32_2, LogSoftmaxTest2) +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat32_3, LogSoftmaxTest3) +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat32_4, LogSoftmaxTest4) + +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_1, LogSoftmaxTest1) +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_2, LogSoftmaxTest2) +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_3, LogSoftmaxTest3) +ARMNN_AUTO_TEST_CASE(LogSoftmaxFloat16_4, LogSoftmaxTest4) + // Pad ARMNN_AUTO_TEST_CASE(PadFloat322d, PadFloat322dTest) ARMNN_AUTO_TEST_CASE(PadFloat322dCustomPadding, PadFloat322dCustomPaddingTest) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 9a5f427d37..b8eb95c729 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -37,6 +37,8 @@ list(APPEND armnnRefBackendWorkloads_sources Gather.hpp InstanceNorm.cpp InstanceNorm.hpp + LogSoftmax.cpp + LogSoftmax.hpp LstmUtils.hpp LstmUtils.cpp Maximum.hpp @@ -95,6 +97,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefInstanceNormalizationWorkload.hpp RefL2NormalizationWorkload.cpp RefL2NormalizationWorkload.hpp + RefLogSoftmaxWorkload.cpp + RefLogSoftmaxWorkload.hpp RefLstmWorkload.cpp RefLstmWorkload.hpp RefMeanWorkload.cpp diff --git a/src/backends/reference/workloads/LogSoftmax.cpp b/src/backends/reference/workloads/LogSoftmax.cpp new file mode 100644 index 
0000000000..3fa3dc0d8c
--- /dev/null
+++ b/src/backends/reference/workloads/LogSoftmax.cpp
@@ -0,0 +1,91 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "LogSoftmax.hpp"
+
+#include <armnnUtils/TensorUtils.hpp>
+
+#include <cmath>
+
+#include <boost/assert.hpp>
+#include <boost/core/ignore_unused.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+
+namespace
+{
+
+inline bool ValidateAxis(int axis, unsigned int numDimensions)
+{
+    const int sNumDimensions = boost::numeric_cast<int>(numDimensions);
+    return axis < sNumDimensions && axis >= -sNumDimensions;
+}
+
+} // anonymous namespace
+
+namespace armnn
+{
+
+void LogSoftmax(Decoder<float>& input,
+                Encoder<float>& output,
+                const TensorInfo& inputInfo,
+                const LogSoftmaxDescriptor& descriptor)
+{
+    const unsigned int numDimensions = inputInfo.GetNumDimensions();
+
+    bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
+    BOOST_ASSERT_MSG(axisIsValid,
+        "Axis index is not in range [-numDimensions, numDimensions).");
+    boost::ignore_unused(axisIsValid);
+
+    unsigned int uAxis = descriptor.m_Axis < 0 ?
+        numDimensions - boost::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
+        boost::numeric_cast<unsigned int>(descriptor.m_Axis);
+
+    const TensorShape& inputShape = inputInfo.GetShape();
+    const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
+    const unsigned int axisSize  = inputShape[uAxis];
+    const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
+                                                                     uAxis + 1,
+                                                                     inputShape.GetNumDimensions());
+
+    for (unsigned int outer = 0; outer < outerSize; ++outer)
+    {
+        for (unsigned int inner = 0; inner < innerSize; ++inner)
+        {
+            // Find max
+            input[outer * axisSize * innerSize + inner];
+            float maxValue = input.Get();
+            for (unsigned int i = 1u; i < axisSize; ++i)
+            {
+                input[(outer * axisSize + i) * innerSize + inner];
+                maxValue = std::max(maxValue, input.Get());
+            }
+
+            // Compute sum
+            float sum = 0.0f;
+            for (unsigned int i = 0u; i < axisSize; ++i)
+            {
+                input[(outer * axisSize + i) * innerSize + inner];
+                sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
+            }
+
+            // Compute log sum
+            const float logSum = std::log(sum);
+
+            // Compute result
+            for (unsigned int i = 0u; i < axisSize; ++i)
+            {
+                const unsigned int index = (outer * axisSize + i) * innerSize + inner;
+
+                input [index];
+                output[index];
+
+                output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
+            }
+        }
+    }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/LogSoftmax.hpp b/src/backends/reference/workloads/LogSoftmax.hpp
new file mode 100644
index 0000000000..2e383992c9
--- /dev/null
+++ b/src/backends/reference/workloads/LogSoftmax.hpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <armnn/Descriptors.hpp>
+
+namespace armnn
+{
+
+void LogSoftmax(Decoder<float>& input,
+                Encoder<float>& output,
+                const TensorInfo& inputInfo,
+                const LogSoftmaxDescriptor& descriptor);
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefLogSoftmaxWorkload.cpp b/src/backends/reference/workloads/RefLogSoftmaxWorkload.cpp
new file mode 100644
index 0000000000..a987e79dda
--- /dev/null
+++ b/src/backends/reference/workloads/RefLogSoftmaxWorkload.cpp
@@ -0,0 +1,36 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLogSoftmaxWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "LogSoftmax.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <Profiling.hpp>
+
+#include <boost/assert.hpp>
+
+namespace armnn
+{
+
+void RefLogSoftmaxWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefLogSoftmaxWorkload_Execute");
+
+    const TensorInfo& inputInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    std::unique_ptr<Decoder<float>> decoder = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+    std::unique_ptr<Encoder<float>> encoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    BOOST_ASSERT(decoder != nullptr);
+    BOOST_ASSERT(encoder != nullptr);
+
+    LogSoftmax(*decoder, *encoder, inputInfo, m_Data.m_Parameters);
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp b/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp
new file mode 100644
index 0000000000..f5048d90b3
--- /dev/null
+++ b/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefLogSoftmaxWorkload : public BaseWorkload<LogSoftmaxQueueDescriptor>
+{
+public:
+    using BaseWorkload<LogSoftmaxQueueDescriptor>::BaseWorkload;
+    virtual void Execute() const override;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 39dfa0517b..79d1935823 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -38,6 +38,7 @@
 #include "RefGatherWorkload.hpp"
 #include "RefInstanceNormalizationWorkload.hpp"
 #include "RefL2NormalizationWorkload.hpp"
+#include "RefLogSoftmaxWorkload.hpp"
 #include "RefLstmWorkload.hpp"
 #include "RefMeanWorkload.hpp"
 #include "RefNormalizationWorkload.hpp"
-- 
cgit v1.2.1
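
Note on the kernel added in LogSoftmax.cpp: in the Decoder/Encoder idiom, input[index] only repositions the iterator and Get()/Set() perform the actual read or write, so statements that look value-less, such as input[index];, are intentional. The computation itself is the usual numerically stable formulation, output_i = beta * (x_i - max(x)) - log(sum_j exp(beta * (x_j - max(x)))), applied independently along the chosen axis. The standalone sketch below (plain C++ with no Arm NN dependencies; LogSoftmaxRows is an illustrative helper, not part of the Arm NN API) reproduces the expected values used in LogSoftmaxTest1, i.e. beta = 1 and axis = -1 over a 1x1x2x4 tensor.

// Standalone sketch of the log-softmax computation performed by the
// reference workload: per row, subtract the maximum, scale by beta,
// then subtract the log of the sum of exponentials.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Computes log_softmax(beta * x) independently for each row of a
// rowCount x rowSize matrix stored contiguously in 'data'.
void LogSoftmaxRows(std::vector<float>& data, size_t rowCount, size_t rowSize, float beta)
{
    for (size_t r = 0; r < rowCount; ++r)
    {
        float* row = data.data() + r * rowSize;

        // Subtracting the row maximum keeps exp() from overflowing.
        float maxValue = row[0];
        for (size_t i = 1; i < rowSize; ++i)
        {
            maxValue = std::max(maxValue, row[i]);
        }

        float sum = 0.0f;
        for (size_t i = 0; i < rowSize; ++i)
        {
            sum += std::exp((row[i] - maxValue) * beta);
        }
        const float logSum = std::log(sum);

        for (size_t i = 0; i < rowSize; ++i)
        {
            row[i] = (row[i] - maxValue) * beta - logSum;
        }
    }
}

int main()
{
    // Same values as LogSoftmaxTest1 above: shape 1x1x2x4, softmax over the last axis.
    std::vector<float> values = { 0.f, -6.f,  2.f, 4.f,
                                  3.f, -2.f, 10.f, 1.f };
    LogSoftmaxRows(values, 2, 4, 1.0f);
    for (float v : values)
    {
        std::printf("%f\n", v); // expect roughly -4.14297 -10.14297 -2.14297 -0.14297 ...
    }
    return 0;
}

The max subtraction does not change the result, because the shift cancels between the scaled input term and the log-sum-exp term; it only keeps exp() in range for large inputs such as the 10.0f in the test data.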