diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefLayerSupport.cpp | 170 |
1 files changed, 148 insertions, 22 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index a39bfb6c4d..ce4abbf921 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -13,24 +13,25 @@ #include <vector> #include <array> +#include <tuple> namespace armnn { -static bool RunTosaLayerChecks(TosaSerializationOperator* op, - const std::vector<TosaSerializationTensor*>& inputs, - const std::vector<TosaSerializationTensor*>& outputs, - const std::vector<Attribute>& supportedAttributes, - const std::vector<DType>& supportedTypes, - Optional<string&> reasonIfUnsupported) +static bool RunTosaLayerChecksSingleDataType(TosaSerializationOperator* op, + const std::vector<TosaSerializationTensor*>& inputs, + const std::vector<TosaSerializationTensor*>& outputs, + const std::vector<Attribute>& supportedAttributes, + const std::vector<DType>& supportedTypes, + Optional<string&> reasonIfUnsupported) { bool supported = true; - std::string opCode = std::to_string(op->GetOp()); + std::string opString = TosaOpToString(op->GetOp()); // Check Attribute from operator (GetAttribute) supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported, - std::string("TOSA Reference Operator: " + opCode + + std::string("TOSA Reference Operator: " + opString + " has an unsupported attribute.").c_str()); for (auto input : inputs) @@ -40,14 +41,14 @@ static bool RunTosaLayerChecks(TosaSerializationOperator* op, // Check Dtype from tensor (GetDtype) supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes), reasonIfUnsupported, - std::string("TOSA Reference Operator: " + opCode + " for input: " + + std::string("TOSA Reference Operator: " + opString + " for input: " + input->GetName() + " has an unsupported data type: " + dataTypeCode).c_str()); // Check Shape from tensor (GetShape) supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input), reasonIfUnsupported, - std::string("Tosa Reference Operator: " + opCode + " for input: " + + std::string("Tosa Reference Operator: " + opString + " for input: " + input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); } @@ -58,20 +59,72 @@ static bool RunTosaLayerChecks(TosaSerializationOperator* op, // Check Dtype from tensor (GetDtype) supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes), reasonIfUnsupported, - std::string("TOSA Reference Operator: " + opCode + " for output: " + + std::string("TOSA Reference Operator: " + opString + " for output: " + output->GetName() + " has an unsupported data type: " + dataTypeCode).c_str()); // Check Shape from tensor (GetShape) supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output), reasonIfUnsupported, - std::string("Tosa Reference Operator: " + opCode + " for output: " + + std::string("Tosa Reference Operator: " + opString + " for output: " + output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); } return supported; } +static bool RunTosaLayerChecksInputOutputDataType(TosaSerializationOperator* op, + const std::vector<TosaSerializationTensor*>& inputs, + const std::vector<TosaSerializationTensor*>& outputs, + const std::vector<Attribute>& supportedAttributes, + const std::vector<std::tuple<DType,DType>>& supportedMappingTypes, + Optional<string&> reasonIfUnsupported) +{ + bool supported = true; + + std::string opString = TosaOpToString(op->GetOp()); + + // Check Attribute from operator (GetAttribute) + supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opString + + " has an unsupported attribute.").c_str()); + + supported &= CheckSupportRule(TosaAssertSize(inputs, outputs), reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opString + + " must have 1-to-1 mapping of inputs-to-outputs.").c_str()); + + for (uint32_t i = 0; i < inputs.size(); i++) + { + auto input = inputs[i]; + auto output = outputs[i]; + std::string inputDataTypeCode = std::to_string(input->GetDtype()); + std::string outputDataTypeCode = std::to_string(output->GetDtype()); + std::tuple<DType, DType> mappingType(input->GetDtype(), output->GetDtype()); + + // Check Dtype from tensor (GetDtype) + supported &= CheckSupportRule(TosaContainerContains(mappingType, supportedMappingTypes), + reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opString + " for input: " + + input->GetName() + " and output: " + output->GetName() + + " has an unsupported input data type: " + inputDataTypeCode + + " to output data type: " + outputDataTypeCode).c_str()); + + // Check Shape from tensor (GetShape) + supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input), + reasonIfUnsupported, + std::string("Tosa Reference Operator: " + opString + " for input: " + + input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); + + // Check Shape from tensor (GetShape) + supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output), + reasonIfUnsupported, + std::string("Tosa Reference Operator: " + opString + " for output: " + + output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); + } + + return supported; +} + static bool IsTosaLayerSupported(TosaSerializationOperator* op, const std::vector<TosaSerializationTensor*>& inputs, const std::vector<TosaSerializationTensor*>& outputs, @@ -81,8 +134,6 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, { case tosa::Op_ADD: { - bool supported = true; - std::vector<Attribute> supportedAttributes = { Attribute_NONE @@ -97,14 +148,84 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, }; // Check the attribute, data types and bounds for inputs and outputs. - supported = RunTosaLayerChecks(op, - inputs, - outputs, - supportedAttributes, - supportedTypes, - reasonIfUnsupported); - - return supported; + return RunTosaLayerChecksSingleDataType(op, + inputs, + outputs, + supportedAttributes, + supportedTypes, + reasonIfUnsupported); + } + case tosa::Op_AVG_POOL2D: + { + std::vector<Attribute> supportedAttributes = + { + Attribute_PoolAttribute + }; + + std::vector<std::tuple<DType, DType>> supportedTypesMapping = + { + std::tuple<DType, DType>(DType_FP16, DType_FP16), + std::tuple<DType, DType>(DType_FP16, DType_FP32), + std::tuple<DType, DType>(DType_FP32, DType_FP32), + std::tuple<DType, DType>(DType_INT8, DType_INT32), + std::tuple<DType, DType>(DType_INT16, DType_INT32) + }; + + // Check the attribute, data types and bounds for inputs and outputs. + return RunTosaLayerChecksInputOutputDataType(op, + inputs, + outputs, + supportedAttributes, + supportedTypesMapping, + reasonIfUnsupported); + } + case tosa::Op_MAX_POOL2D: + { + std::vector<Attribute> supportedAttributes = + { + Attribute_PoolAttribute + }; + + std::vector<DType> supportedTypes = + { + DType_FP16, + DType_FP32, + DType_INT8, + DType_INT16 + }; + + // Check the attribute, data types and bounds for inputs and outputs. + return RunTosaLayerChecksSingleDataType(op, + inputs, + outputs, + supportedAttributes, + supportedTypes, + reasonIfUnsupported); + } + case tosa::Op_PAD: + { + std::vector<Attribute> supportedAttributes = + { + Attribute_PadAttribute + }; + + std::vector<DType> supportedTypes = + { + DType_FP16, + DType_FP32, + DType_INT8, + DType_INT16, + DType_INT32, + DType_BOOL + }; + + // Check the attribute, data types and bounds for inputs and outputs. + return RunTosaLayerChecksSingleDataType(op, + inputs, + outputs, + supportedAttributes, + supportedTypes, + reasonIfUnsupported); } default: SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend."); @@ -136,6 +257,11 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type, case LayerType::Input: case LayerType::Output: return true; + case LayerType::Pooling2d: + // Setup inputs and outputs + inputInfos.push_back(&infos[0]); + outputInfos.push_back(&infos[1]); + break; default: break; } |