// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include "RefLayerSupport.hpp" #include #include #include #include #include #include "InternalTypes.hpp" using namespace boost; namespace armnn { bool RefLayerSupport::IsActivationSupported(const TensorInfo& input, const TensorInfo& output, const ActivationDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsActivationSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsAdditionSupportedRef(input0, input1, output, reasonIfUnsupported); } bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, const TensorInfo& var, const TensorInfo& beta, const TensorInfo& gamma, const BatchNormalizationDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsBatchNormalizationSupportedRef(input, output, mean, var, beta, gamma, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsConstantSupported(const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsConstantSupportedRef(output, reasonIfUnsupported); } bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsConvertFp16ToFp32SupportedRef(input, output, reasonIfUnsupported); } bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsConvertFp32ToFp16SupportedRef(input, output, reasonIfUnsupported); } bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input, const TensorInfo& output, const Convolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional& biases, Optional reasonIfUnsupported) const { return armnn::IsConvolution2dSupportedRef(input, output, descriptor, weights, biases, reasonIfUnsupported); } bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, const TensorInfo& output, const DepthwiseConvolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional& biases, Optional reasonIfUnsupported) const { return armnn::IsDepthwiseConvolutionSupportedRef(input, output, descriptor, weights, biases, reasonIfUnsupported); } bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsDivisionSupportedRef(input0, input1, output, reasonIfUnsupported); } bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input, const FakeQuantizationDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsFakeQuantizationSupportedRef(input, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsFloorSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsFloorSupportedRef(input, output, reasonIfUnsupported); } bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& weights, const TensorInfo& biases, const FullyConnectedDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsFullyConnectedSupportedRef(input, output, weights, biases, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsInputSupported(const TensorInfo& input, Optional reasonIfUnsupported) const { return armnn::IsInputSupportedRef(input, reasonIfUnsupported); } bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input, const TensorInfo& output, const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsL2NormalizationSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, const TensorInfo& forgetGateBias, const TensorInfo& cellBias, const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, const TensorInfo* cellToOutputWeights, Optional reasonIfUnsupported) const { return armnn::IsLstmSupportedRef(input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut, output, descriptor, inputToForgetWeights, inputToCellWeights, inputToOutputWeights, recurrentToForgetWeights, recurrentToCellWeights, recurrentToOutputWeights, forgetGateBias, cellBias, outputGateBias, inputToInputWeights, recurrentToInputWeights, cellToInputWeights, inputGateBias, projectionWeights, projectionBias, cellToForgetWeights, cellToOutputWeights, reasonIfUnsupported); } bool RefLayerSupport::IsMeanSupported(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsMeanSupportedRef(input, output, descriptor,reasonIfUnsupported); } bool RefLayerSupport::IsMergerSupported(const std::vector inputs, const OriginsDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsMergerSupportedRef(inputs, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsMultiplicationSupportedRef(input0, input1, output, reasonIfUnsupported); } bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const NormalizationDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsNormalizationSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsOutputSupported(const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsOutputSupportedRef(output, reasonIfUnsupported); } bool RefLayerSupport::IsPadSupported(const TensorInfo& input, const TensorInfo& output, const PadDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsPadSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input, const TensorInfo& output, const PermuteDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsPermuteSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input, const TensorInfo& output, const Pooling2dDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsPooling2dSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input, Optional reasonIfUnsupported) const { return armnn::IsReshapeSupportedRef(input, reasonIfUnsupported); } bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input, Optional reasonIfUnsupported) const { return armnn::IsResizeBilinearSupportedRef(input, reasonIfUnsupported); } bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input, const TensorInfo& output, const SoftmaxDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsSoftmaxSupportedRef(input, output, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input, const ViewsDescriptor& descriptor, Optional reasonIfUnsupported) const { return armnn::IsSplitterSupportedRef(input, descriptor, reasonIfUnsupported); } bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) const { return armnn::IsSubtractionSupportedRef(input0, input1, output, reasonIfUnsupported); } // // Implementation functions // // TODO: Functions kept for backward compatibility. Remove once transition to plugable backends is complete! template bool IsSupportedForDataTypeRef(Optional reasonIfUnsupported, DataType dataType, Float32Func floatFuncPtr, Uint8Func uint8FuncPtr, Params&&... params) { return IsSupportedForDataTypeGeneric(reasonIfUnsupported, dataType, &FalseFunc, floatFuncPtr, uint8FuncPtr, std::forward(params)...); } bool IsActivationSupportedRef(const TensorInfo& input, const TensorInfo& output, const ActivationDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsAdditionSupportedRef(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) { ignore_unused(input1); ignore_unused(output); return IsSupportedForDataTypeRef(reasonIfUnsupported, input0.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsBatchNormalizationSupportedRef(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, const TensorInfo& var, const TensorInfo& beta, const TensorInfo& gamma, const BatchNormalizationDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsConstantSupportedRef(const TensorInfo& output, Optional reasonIfUnsupported) { return IsSupportedForDataTypeRef(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsConvolution2dSupportedRef(const TensorInfo& input, const TensorInfo& output, const Convolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional& biases, Optional reasonIfUnsupported) { ignore_unused(descriptor); ignore_unused(output); ignore_unused(weights); ignore_unused(biases); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input, const TensorInfo& output, const DepthwiseConvolution2dDescriptor& descriptor, const TensorInfo& weights, const Optional& biases, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); ignore_unused(weights); ignore_unused(biases); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsDivisionSupportedRef(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) { ignore_unused(input1); ignore_unused(output); return IsSupportedForDataTypeRef(reasonIfUnsupported, input0.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsSubtractionSupportedRef(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) { ignore_unused(input1); ignore_unused(output); return IsSupportedForDataTypeRef(reasonIfUnsupported, input0.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsFullyConnectedSupportedRef(const TensorInfo& input, const TensorInfo& output, const TensorInfo& weights, const TensorInfo& biases, const FullyConnectedDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); ignore_unused(weights); ignore_unused(biases); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsInputSupportedRef(const TensorInfo& input, Optional reasonIfUnsupported) { return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsL2NormalizationSupportedRef(const TensorInfo& input, const TensorInfo& output, const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &FalseFuncU8<>); } bool IsMergerSupportedRef(const std::vector inputs, const OriginsDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, inputs[0]->GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsMultiplicationSupportedRef(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported) { ignore_unused(input1); ignore_unused(output); return IsSupportedForDataTypeRef(reasonIfUnsupported, input0.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsNormalizationSupportedRef(const TensorInfo& input, const TensorInfo& output, const NormalizationDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &FalseFuncU8<>); } bool IsOutputSupportedRef(const TensorInfo& output, Optional reasonIfUnsupported) { return IsSupportedForDataTypeRef(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsPermuteSupportedRef(const TensorInfo& input, const TensorInfo& output, const PermuteDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsPooling2dSupportedRef(const TensorInfo& input, const TensorInfo& output, const Pooling2dDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsResizeBilinearSupportedRef(const TensorInfo& input, Optional reasonIfUnsupported) { return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsSoftmaxSupportedRef(const TensorInfo& input, const TensorInfo& output, const SoftmaxDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsSplitterSupportedRef(const TensorInfo& input, const ViewsDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsFakeQuantizationSupportedRef(const TensorInfo& input, const FakeQuantizationDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &FalseFuncU8<>); } bool IsReshapeSupportedRef(const TensorInfo& input, Optional reasonIfUnsupported) { return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsFloorSupportedRef(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) { ignore_unused(output); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &FalseFuncU8<>); } bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, const TensorInfo& forgetGateBias, const TensorInfo& cellBias, const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, const TensorInfo* cellToOutputWeights, Optional reasonIfUnsupported) { ignore_unused(input); ignore_unused(outputStateIn); ignore_unused(cellStateIn); ignore_unused(scratchBuffer); ignore_unused(outputStateOut); ignore_unused(cellStateOut); ignore_unused(output); ignore_unused(descriptor); ignore_unused(inputToForgetWeights); ignore_unused(inputToCellWeights); ignore_unused(inputToOutputWeights); ignore_unused(recurrentToForgetWeights); ignore_unused(recurrentToCellWeights); ignore_unused(recurrentToOutputWeights); ignore_unused(forgetGateBias); ignore_unused(cellBias); ignore_unused(outputGateBias); ignore_unused(inputToInputWeights); ignore_unused(recurrentToInputWeights); ignore_unused(cellToInputWeights); ignore_unused(inputGateBias); ignore_unused(projectionWeights); ignore_unused(projectionBias); ignore_unused(cellToForgetWeights); ignore_unused(cellToOutputWeights); ignore_unused(reasonIfUnsupported); return false; } bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) { return (IsSupportedForDataTypeGeneric(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &FalseInputFuncF32<>, &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &FalseOutputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>)); } bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) { return (IsSupportedForDataTypeGeneric(reasonIfUnsupported, input.GetDataType(), &FalseInputFuncF16<>, &TrueFunc<>, &FalseFuncU8<>) && IsSupportedForDataTypeGeneric(reasonIfUnsupported, output.GetDataType(), &TrueFunc<>, &FalseOutputFuncF32<>, &FalseFuncU8<>)); } bool IsMeanSupportedRef(const TensorInfo& input, const TensorInfo& output, const MeanDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); return IsSupportedForDataTypeRef(reasonIfUnsupported, input.GetDataType(), &TrueFunc<>, &TrueFunc<>); } bool IsPadSupportedRef(const TensorInfo& input, const TensorInfo& output, const PadDescriptor& descriptor, Optional reasonIfUnsupported) { ignore_unused(output); ignore_unused(descriptor); return false; } } // namespace armnn