aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/TosaRefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r--src/backends/tosaReference/TosaRefLayerSupport.cpp103
1 files changed, 65 insertions, 38 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
index c2b0b1b0b9..a39bfb6c4d 100644
--- a/src/backends/tosaReference/TosaRefLayerSupport.cpp
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -17,6 +17,61 @@
namespace armnn
{
+static bool RunTosaLayerChecks(TosaSerializationOperator* op,
+ const std::vector<TosaSerializationTensor*>& inputs,
+ const std::vector<TosaSerializationTensor*>& outputs,
+ const std::vector<Attribute>& supportedAttributes,
+ const std::vector<DType>& supportedTypes,
+ Optional<string&> reasonIfUnsupported)
+{
+ bool supported = true;
+
+ std::string opCode = std::to_string(op->GetOp());
+
+ // Check Attribute from operator (GetAttribute)
+ supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
+ std::string("TOSA Reference Operator: " + opCode +
+ " has an unsupported attribute.").c_str());
+
+ for (auto input : inputs)
+ {
+ std::string dataTypeCode = std::to_string(input->GetDtype());
+
+ // Check Dtype from tensor (GetDtype)
+ supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes),
+ reasonIfUnsupported,
+ std::string("TOSA Reference Operator: " + opCode + " for input: " +
+ input->GetName() + " has an unsupported data type: " +
+ dataTypeCode).c_str());
+
+ // Check Shape from tensor (GetShape)
+ supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input),
+ reasonIfUnsupported,
+ std::string("Tosa Reference Operator: " + opCode + " for input: " +
+ input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
+ }
+
+ for (auto output : outputs)
+ {
+ std::string dataTypeCode = std::to_string(output->GetDtype());
+
+ // Check Dtype from tensor (GetDtype)
+ supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes),
+ reasonIfUnsupported,
+ std::string("TOSA Reference Operator: " + opCode + " for output: " +
+ output->GetName() + " has an unsupported data type: " +
+ dataTypeCode).c_str());
+
+ // Check Shape from tensor (GetShape)
+ supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output),
+ reasonIfUnsupported,
+ std::string("Tosa Reference Operator: " + opCode + " for output: " +
+ output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
+ }
+
+ return supported;
+}
+
static bool IsTosaLayerSupported(TosaSerializationOperator* op,
const std::vector<TosaSerializationTensor*>& inputs,
const std::vector<TosaSerializationTensor*>& outputs,
@@ -28,54 +83,26 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op,
{
bool supported = true;
- std::array<Attribute, 1> supportedAttributes =
+ std::vector<Attribute> supportedAttributes =
{
Attribute_NONE
};
- // Check Attribute from operator (GetAttribute)
- supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
- std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str());
-
- std::array<DType, 9> supportedTypes =
+ // Only Int32, Fp32 and Fp16 are currently supported by the TOSA Reference Model.
+ std::vector<DType> supportedTypes =
{
- DType_BOOL,
- DType_UINT8,
- DType_UINT16,
- DType_INT4,
- DType_INT8,
- DType_INT16,
DType_INT32,
DType_FP16,
DType_FP32
};
- for (auto tensor : inputs)
- {
- // Check Dtype from tensor (GetDtype)
- supported &= CheckSupportRule(TosaTypeAnyOf(tensor, supportedTypes),
- reasonIfUnsupported,
- std::string("TOSA Reference addition: " + tensor->GetName() +
- " is not a supported type.").c_str());
-
- // Check Shape from tensor (GetShape)
- supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(tensor),
- reasonIfUnsupported,
- std::string("Tosa Reference addition: " + tensor->GetName() + " Shape.Size()"
- " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
- }
-
- // Check Dtype from tensor (GetDtype)
- supported &= CheckSupportRule(TosaTypeAnyOf(outputs[0], supportedTypes),
- reasonIfUnsupported,
- std::string("TOSA Reference addition: " + outputs[0]->GetName() +
- " is not a supported type.").c_str());
-
- // Check Shape from tensor (GetShape)
- supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(outputs[0]),
- reasonIfUnsupported,
- std::string("Tosa Reference addition: " + outputs[0]->GetName() + " Shape.Size()"
- " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
+ // Check the attribute, data types and bounds for inputs and outputs.
+ supported = RunTosaLayerChecks(op,
+ inputs,
+ outputs,
+ supportedAttributes,
+ supportedTypes,
+ reasonIfUnsupported);
return supported;
}