From b4540bef0b0327683fe8e63f727c1212800dc2a9 Mon Sep 17 00:00:00 2001
From: David Beck
Date: Mon, 24 Sep 2018 13:18:27 +0100
Subject: IVGCVSW-1898 : Ref backend folder structure

* Reference backend is renamed to backends/reference as per
  https://confluence.arm.com/display/MLENG/Pluggable+backends

Change-Id: I27a13c274eb60995dfb459e3c49c0e2f60bcd32c
---
 src/backends/reference/RefLayerSupport.cpp | 398 +++++++++++++++++++++++++++++
 1 file changed, 398 insertions(+)
 create mode 100644 src/backends/reference/RefLayerSupport.cpp
(limited to 'src/backends/reference/RefLayerSupport.cpp')

diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
new file mode 100644
index 0000000000..d56cdebeda
--- /dev/null
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -0,0 +1,398 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "LayerSupportCommon.hpp"
+#include "RefLayerSupport.hpp"
+#include <armnn/Descriptors.hpp>
+#include <armnn/Types.hpp>
+#include <armnn/Tensor.hpp>
+
+#include <boost/core/ignore_unused.hpp>
+#include "InternalTypes.hpp"
+
+using namespace boost;
+
+namespace armnn
+{
+
+template<typename Float32Func, typename Uint8Func, typename ... Params>
+bool IsSupportedForDataTypeRef(std::string* reasonIfUnsupported,
+                               DataType dataType,
+                               Float32Func floatFuncPtr,
+                               Uint8Func uint8FuncPtr,
+                               Params&&... params)
+{
+    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                         dataType,
+                                         &FalseFunc<Params...>,
+                                         floatFuncPtr,
+                                         uint8FuncPtr,
+                                         std::forward<Params>(params)...);
+}
+
+bool IsActivationSupportedRef(const TensorInfo& input,
+                              const TensorInfo& output,
+                              const ActivationDescriptor& descriptor,
+                              std::string* reasonIfUnsupported)
+{
+    ignore_unused(output);
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsAdditionSupportedRef(const TensorInfo& input0,
+                            const TensorInfo& input1,
+                            const TensorInfo& output,
+                            std::string* reasonIfUnsupported)
+{
+    ignore_unused(input1);
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
+                                      const TensorInfo& output,
+                                      const TensorInfo& mean,
+                                      const TensorInfo& var,
+                                      const TensorInfo& beta,
+                                      const TensorInfo& gamma,
+                                      const BatchNormalizationDescriptor& descriptor,
+                                      std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsConstantSupportedRef(const TensorInfo& output,
+                            std::string* reasonIfUnsupported)
+{
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     output.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsConvolution2dSupportedRef(const TensorInfo& input,
+                                 const TensorInfo& output,
+                                 const Convolution2dDescriptor& descriptor,
+                                 const TensorInfo& weights,
+                                 const boost::optional<TensorInfo>& biases,
+                                 std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    ignore_unused(output);
+    ignore_unused(weights);
+    ignore_unused(biases);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
+                                        const TensorInfo& output,
+                                        const DepthwiseConvolution2dDescriptor& descriptor,
+                                        const TensorInfo& weights,
+                                        const boost::optional<TensorInfo>& biases,
+                                        std::string* reasonIfUnsupported)
+{
+    ignore_unused(output);
+    ignore_unused(descriptor);
+    ignore_unused(weights);
+    ignore_unused(biases);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsDivisionSupportedRef(const TensorInfo& input0,
+                            const TensorInfo& input1,
+                            const TensorInfo& output,
+                            std::string* reasonIfUnsupported)
+{
+    ignore_unused(input1);
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsSubtractionSupportedRef(const TensorInfo& input0,
+                               const TensorInfo& input1,
+                               const TensorInfo& output,
+                               std::string* reasonIfUnsupported)
+{
+    ignore_unused(input1);
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsFullyConnectedSupportedRef(const TensorInfo& input,
+                                  const TensorInfo& output,
+                                  const TensorInfo& weights,
+                                  const TensorInfo& biases,
+                                  const FullyConnectedDescriptor& descriptor,
+                                  std::string* reasonIfUnsupported)
+{
+    ignore_unused(output);
+    ignore_unused(descriptor);
+    ignore_unused(weights);
+    ignore_unused(biases);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsInputSupportedRef(const TensorInfo& input,
+                         std::string* reasonIfUnsupported)
+{
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsL2NormalizationSupportedRef(const TensorInfo& input,
+                                   const TensorInfo& output,
+                                   std::string* reasonIfUnsupported)
+{
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &FalseFuncU8<>);
+}
+
+bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
+                          const OriginsDescriptor& descriptor,
+                          std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     inputs[0]->GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsMultiplicationSupportedRef(const TensorInfo& input0,
+                                  const TensorInfo& input1,
+                                  const TensorInfo& output,
+                                  std::string* reasonIfUnsupported)
+{
+    ignore_unused(input1);
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsNormalizationSupportedRef(const TensorInfo& input,
+                                 const TensorInfo& output,
+                                 const NormalizationDescriptor& descriptor,
+                                 std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &FalseFuncU8<>);
+}
+
+bool IsOutputSupportedRef(const TensorInfo& output,
+                          std::string* reasonIfUnsupported)
+{
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     output.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsPermuteSupportedRef(const TensorInfo& input,
+                           const TensorInfo& output,
+                           const PermuteDescriptor& descriptor,
+                           std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsPooling2dSupportedRef(const TensorInfo& input,
+                             const TensorInfo& output,
+                             const Pooling2dDescriptor& descriptor,
+                             std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsResizeBilinearSupportedRef(const TensorInfo& input,
+                                  std::string* reasonIfUnsupported)
+{
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsSoftmaxSupportedRef(const TensorInfo& input,
+                           const TensorInfo& output,
+                           const SoftmaxDescriptor& descriptor,
+                           std::string* reasonIfUnsupported)
+{
+    ignore_unused(output);
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsSplitterSupportedRef(const TensorInfo& input,
+                            const ViewsDescriptor& descriptor,
+                            std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsFakeQuantizationSupportedRef(const TensorInfo& input,
+                                    const FakeQuantizationDescriptor& descriptor,
+                                    std::string* reasonIfUnsupported)
+{
+    ignore_unused(descriptor);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &FalseFuncU8<>);
+}
+
+bool IsReshapeSupportedRef(const TensorInfo& input,
+                           std::string* reasonIfUnsupported)
+{
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
+bool IsFloorSupportedRef(const TensorInfo& input,
+                         const TensorInfo& output,
+                         std::string* reasonIfUnsupported)
+{
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input.GetDataType(),
+                                     &TrueFunc<>,
+                                     &FalseFuncU8<>);
+}
+
+bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
+                        const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
+                        const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
+                        const TensorInfo& output, const LstmDescriptor& descriptor,
+                        const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
+                        const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
+                        const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
+                        const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
+                        const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
+                        const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
+                        const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
+                        const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
+                        const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
+{
+    ignore_unused(input);
+    ignore_unused(outputStateIn);
+    ignore_unused(cellStateIn);
+    ignore_unused(scratchBuffer);
+    ignore_unused(outputStateOut);
+    ignore_unused(cellStateOut);
+    ignore_unused(output);
+    ignore_unused(descriptor);
+    ignore_unused(inputToForgetWeights);
+    ignore_unused(inputToCellWeights);
+    ignore_unused(inputToOutputWeights);
+    ignore_unused(recurrentToForgetWeights);
+    ignore_unused(recurrentToCellWeights);
+    ignore_unused(recurrentToOutputWeights);
+    ignore_unused(forgetGateBias);
+    ignore_unused(cellBias);
+    ignore_unused(outputGateBias);
+    ignore_unused(inputToInputWeights);
+    ignore_unused(recurrentToInputWeights);
+    ignore_unused(cellToInputWeights);
+    ignore_unused(inputGateBias);
+    ignore_unused(projectionWeights);
+    ignore_unused(projectionBias);
+    ignore_unused(cellToForgetWeights);
+    ignore_unused(cellToOutputWeights);
+    return false;
+}
+
+bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
+                                     const TensorInfo& output,
+                                     std::string* reasonIfUnsupported)
+{
+    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                          input.GetDataType(),
+                                          &TrueFunc<>,
+                                          &FalseInputFuncF32<>,
+                                          &FalseFuncU8<>) &&
+            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                          output.GetDataType(),
+                                          &FalseOutputFuncF16<>,
+                                          &TrueFunc<>,
+                                          &FalseFuncU8<>));
+}
+
+bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
+                                     const TensorInfo& output,
+                                     std::string* reasonIfUnsupported)
+{
+    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                          input.GetDataType(),
+                                          &FalseInputFuncF16<>,
+                                          &TrueFunc<>,
+                                          &FalseFuncU8<>) &&
+            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                          output.GetDataType(),
+                                          &TrueFunc<>,
+                                          &FalseOutputFuncF32<>,
+                                          &FalseFuncU8<>));
+}
+
+bool IsMeanSupportedRef(const TensorInfo& input,
+                        const TensorInfo& output,
+                        const MeanDescriptor& descriptor,
+                        std::string* reasonIfUnsupported)
+{
+    return false;
+}
+
+}
--
cgit v1.2.1
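
For context only, here is a minimal usage sketch of the free functions added by this patch; it is not part of the commit. It assumes RefLayerSupport.hpp is on the caller's include path, and the tensor shape, data type, and default-constructed descriptor are illustrative choices, not values taken from the patch.

// Illustrative sketch only -- not part of the commit above.
#include "RefLayerSupport.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <iostream>
#include <string>

int main()
{
    using namespace armnn;

    // Arbitrary example tensor, used here as both input and output info.
    TensorInfo info(TensorShape({ 1, 16 }), DataType::Float32);
    ActivationDescriptor descriptor; // default-constructed descriptor

    std::string reason;
    const bool supported = IsActivationSupportedRef(info, info, descriptor, &reason);

    std::cout << "Activation supported by the reference backend: "
              << std::boolalpha << supported << std::endl;
    if (!supported)
    {
        std::cout << "Reason: " << reason << std::endl;
    }
    return 0;
}

In this revision the per-layer checks only dispatch on the input data type (Float32 or quantised 8-bit), so the output info and descriptor are ignored by most of them.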