diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefLayerSupport.cpp | 103 |
1 files changed, 65 insertions, 38 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index c2b0b1b0b9..a39bfb6c4d 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -17,6 +17,61 @@ namespace armnn { +static bool RunTosaLayerChecks(TosaSerializationOperator* op, + const std::vector<TosaSerializationTensor*>& inputs, + const std::vector<TosaSerializationTensor*>& outputs, + const std::vector<Attribute>& supportedAttributes, + const std::vector<DType>& supportedTypes, + Optional<string&> reasonIfUnsupported) +{ + bool supported = true; + + std::string opCode = std::to_string(op->GetOp()); + + // Check Attribute from operator (GetAttribute) + supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opCode + + " has an unsupported attribute.").c_str()); + + for (auto input : inputs) + { + std::string dataTypeCode = std::to_string(input->GetDtype()); + + // Check Dtype from tensor (GetDtype) + supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes), + reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opCode + " for input: " + + input->GetName() + " has an unsupported data type: " + + dataTypeCode).c_str()); + + // Check Shape from tensor (GetShape) + supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input), + reasonIfUnsupported, + std::string("Tosa Reference Operator: " + opCode + " for input: " + + input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); + } + + for (auto output : outputs) + { + std::string dataTypeCode = std::to_string(output->GetDtype()); + + // Check Dtype from tensor (GetDtype) + supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes), + reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opCode + " for output: " + + output->GetName() + " has an unsupported data type: " + + dataTypeCode).c_str()); + + // Check Shape from tensor (GetShape) + supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output), + reasonIfUnsupported, + std::string("Tosa Reference Operator: " + opCode + " for output: " + + output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); + } + + return supported; +} + static bool IsTosaLayerSupported(TosaSerializationOperator* op, const std::vector<TosaSerializationTensor*>& inputs, const std::vector<TosaSerializationTensor*>& outputs, @@ -28,54 +83,26 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, { bool supported = true; - std::array<Attribute, 1> supportedAttributes = + std::vector<Attribute> supportedAttributes = { Attribute_NONE }; - // Check Attribute from operator (GetAttribute) - supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported, - std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str()); - - std::array<DType, 9> supportedTypes = + // Only Int32, Fp32 and Fp16 are currently supported by the TOSA Reference Model. + std::vector<DType> supportedTypes = { - DType_BOOL, - DType_UINT8, - DType_UINT16, - DType_INT4, - DType_INT8, - DType_INT16, DType_INT32, DType_FP16, DType_FP32 }; - for (auto tensor : inputs) - { - // Check Dtype from tensor (GetDtype) - supported &= CheckSupportRule(TosaTypeAnyOf(tensor, supportedTypes), - reasonIfUnsupported, - std::string("TOSA Reference addition: " + tensor->GetName() + - " is not a supported type.").c_str()); - - // Check Shape from tensor (GetShape) - supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(tensor), - reasonIfUnsupported, - std::string("Tosa Reference addition: " + tensor->GetName() + " Shape.Size()" - " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str()); - } - - // Check Dtype from tensor (GetDtype) - supported &= CheckSupportRule(TosaTypeAnyOf(outputs[0], supportedTypes), - reasonIfUnsupported, - std::string("TOSA Reference addition: " + outputs[0]->GetName() + - " is not a supported type.").c_str()); - - // Check Shape from tensor (GetShape) - supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(outputs[0]), - reasonIfUnsupported, - std::string("Tosa Reference addition: " + outputs[0]->GetName() + " Shape.Size()" - " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str()); + // Check the attribute, data types and bounds for inputs and outputs. + supported = RunTosaLayerChecks(op, + inputs, + outputs, + supportedAttributes, + supportedTypes, + reasonIfUnsupported); return supported; } |