aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/TosaRefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r--src/backends/tosaReference/TosaRefLayerSupport.cpp170
1 files changed, 148 insertions, 22 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
index a39bfb6c4d..ce4abbf921 100644
--- a/src/backends/tosaReference/TosaRefLayerSupport.cpp
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -13,24 +13,25 @@
#include <vector>
#include <array>
+#include <tuple>
namespace armnn
{
-static bool RunTosaLayerChecks(TosaSerializationOperator* op,
- const std::vector<TosaSerializationTensor*>& inputs,
- const std::vector<TosaSerializationTensor*>& outputs,
- const std::vector<Attribute>& supportedAttributes,
- const std::vector<DType>& supportedTypes,
- Optional<string&> reasonIfUnsupported)
+static bool RunTosaLayerChecksSingleDataType(TosaSerializationOperator* op,
+ const std::vector<TosaSerializationTensor*>& inputs,
+ const std::vector<TosaSerializationTensor*>& outputs,
+ const std::vector<Attribute>& supportedAttributes,
+ const std::vector<DType>& supportedTypes,
+ Optional<string&> reasonIfUnsupported)
{
bool supported = true;
- std::string opCode = std::to_string(op->GetOp());
+ std::string opString = TosaOpToString(op->GetOp());
// Check Attribute from operator (GetAttribute)
supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
- std::string("TOSA Reference Operator: " + opCode +
+ std::string("TOSA Reference Operator: " + opString +
" has an unsupported attribute.").c_str());
for (auto input : inputs)
@@ -40,14 +41,14 @@ static bool RunTosaLayerChecks(TosaSerializationOperator* op,
// Check Dtype from tensor (GetDtype)
supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes),
reasonIfUnsupported,
- std::string("TOSA Reference Operator: " + opCode + " for input: " +
+ std::string("TOSA Reference Operator: " + opString + " for input: " +
input->GetName() + " has an unsupported data type: " +
dataTypeCode).c_str());
// Check Shape from tensor (GetShape)
supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input),
reasonIfUnsupported,
- std::string("Tosa Reference Operator: " + opCode + " for input: " +
+ std::string("Tosa Reference Operator: " + opString + " for input: " +
input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
}
@@ -58,20 +59,72 @@ static bool RunTosaLayerChecks(TosaSerializationOperator* op,
// Check Dtype from tensor (GetDtype)
supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes),
reasonIfUnsupported,
- std::string("TOSA Reference Operator: " + opCode + " for output: " +
+ std::string("TOSA Reference Operator: " + opString + " for output: " +
output->GetName() + " has an unsupported data type: " +
dataTypeCode).c_str());
// Check Shape from tensor (GetShape)
supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output),
reasonIfUnsupported,
- std::string("Tosa Reference Operator: " + opCode + " for output: " +
+ std::string("Tosa Reference Operator: " + opString + " for output: " +
output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
}
return supported;
}
+static bool RunTosaLayerChecksInputOutputDataType(TosaSerializationOperator* op,
+ const std::vector<TosaSerializationTensor*>& inputs,
+ const std::vector<TosaSerializationTensor*>& outputs,
+ const std::vector<Attribute>& supportedAttributes,
+ const std::vector<std::tuple<DType,DType>>& supportedMappingTypes,
+ Optional<string&> reasonIfUnsupported)
+{
+ bool supported = true;
+
+ std::string opString = TosaOpToString(op->GetOp());
+
+ // Check Attribute from operator (GetAttribute)
+ supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
+ std::string("TOSA Reference Operator: " + opString +
+ " has an unsupported attribute.").c_str());
+
+ supported &= CheckSupportRule(TosaAssertSize(inputs, outputs), reasonIfUnsupported,
+ std::string("TOSA Reference Operator: " + opString +
+ " must have 1-to-1 mapping of inputs-to-outputs.").c_str());
+
+ for (uint32_t i = 0; i < inputs.size(); i++)
+ {
+ auto input = inputs[i];
+ auto output = outputs[i];
+ std::string inputDataTypeCode = std::to_string(input->GetDtype());
+ std::string outputDataTypeCode = std::to_string(output->GetDtype());
+ std::tuple<DType, DType> mappingType(input->GetDtype(), output->GetDtype());
+
+ // Check Dtype from tensor (GetDtype)
+ supported &= CheckSupportRule(TosaContainerContains(mappingType, supportedMappingTypes),
+ reasonIfUnsupported,
+ std::string("TOSA Reference Operator: " + opString + " for input: " +
+ input->GetName() + " and output: " + output->GetName() +
+ " has an unsupported input data type: " + inputDataTypeCode +
+ " to output data type: " + outputDataTypeCode).c_str());
+
+ // Check Shape from tensor (GetShape)
+ supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input),
+ reasonIfUnsupported,
+ std::string("Tosa Reference Operator: " + opString + " for input: " +
+ input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
+
+ // Check Shape from tensor (GetShape)
+ supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output),
+ reasonIfUnsupported,
+ std::string("Tosa Reference Operator: " + opString + " for output: " +
+ output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
+ }
+
+ return supported;
+}
+
static bool IsTosaLayerSupported(TosaSerializationOperator* op,
const std::vector<TosaSerializationTensor*>& inputs,
const std::vector<TosaSerializationTensor*>& outputs,
@@ -81,8 +134,6 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op,
{
case tosa::Op_ADD:
{
- bool supported = true;
-
std::vector<Attribute> supportedAttributes =
{
Attribute_NONE
@@ -97,14 +148,84 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op,
};
// Check the attribute, data types and bounds for inputs and outputs.
- supported = RunTosaLayerChecks(op,
- inputs,
- outputs,
- supportedAttributes,
- supportedTypes,
- reasonIfUnsupported);
-
- return supported;
+ return RunTosaLayerChecksSingleDataType(op,
+ inputs,
+ outputs,
+ supportedAttributes,
+ supportedTypes,
+ reasonIfUnsupported);
+ }
+ case tosa::Op_AVG_POOL2D:
+ {
+ std::vector<Attribute> supportedAttributes =
+ {
+ Attribute_PoolAttribute
+ };
+
+ std::vector<std::tuple<DType, DType>> supportedTypesMapping =
+ {
+ std::tuple<DType, DType>(DType_FP16, DType_FP16),
+ std::tuple<DType, DType>(DType_FP16, DType_FP32),
+ std::tuple<DType, DType>(DType_FP32, DType_FP32),
+ std::tuple<DType, DType>(DType_INT8, DType_INT32),
+ std::tuple<DType, DType>(DType_INT16, DType_INT32)
+ };
+
+ // Check the attribute, data types and bounds for inputs and outputs.
+ return RunTosaLayerChecksInputOutputDataType(op,
+ inputs,
+ outputs,
+ supportedAttributes,
+ supportedTypesMapping,
+ reasonIfUnsupported);
+ }
+ case tosa::Op_MAX_POOL2D:
+ {
+ std::vector<Attribute> supportedAttributes =
+ {
+ Attribute_PoolAttribute
+ };
+
+ std::vector<DType> supportedTypes =
+ {
+ DType_FP16,
+ DType_FP32,
+ DType_INT8,
+ DType_INT16
+ };
+
+ // Check the attribute, data types and bounds for inputs and outputs.
+ return RunTosaLayerChecksSingleDataType(op,
+ inputs,
+ outputs,
+ supportedAttributes,
+ supportedTypes,
+ reasonIfUnsupported);
+ }
+ case tosa::Op_PAD:
+ {
+ std::vector<Attribute> supportedAttributes =
+ {
+ Attribute_PadAttribute
+ };
+
+ std::vector<DType> supportedTypes =
+ {
+ DType_FP16,
+ DType_FP32,
+ DType_INT8,
+ DType_INT16,
+ DType_INT32,
+ DType_BOOL
+ };
+
+ // Check the attribute, data types and bounds for inputs and outputs.
+ return RunTosaLayerChecksSingleDataType(op,
+ inputs,
+ outputs,
+ supportedAttributes,
+ supportedTypes,
+ reasonIfUnsupported);
}
default:
SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend.");
@@ -136,6 +257,11 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
case LayerType::Input:
case LayerType::Output:
return true;
+ case LayerType::Pooling2d:
+ // Setup inputs and outputs
+ inputInfos.push_back(&infos[0]);
+ outputInfos.push_back(&infos[1]);
+ break;
default:
break;
}