diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefLayerSupport.cpp | 162 |
1 files changed, 118 insertions, 44 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index ce4abbf921..848b7efdce 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -102,7 +102,7 @@ static bool RunTosaLayerChecksInputOutputDataType(TosaSerializationOperator* op, std::tuple<DType, DType> mappingType(input->GetDtype(), output->GetDtype()); // Check Dtype from tensor (GetDtype) - supported &= CheckSupportRule(TosaContainerContains(mappingType, supportedMappingTypes), + supported &= CheckSupportRule(TosaContainerContainsTwoTypes(mappingType, supportedMappingTypes), reasonIfUnsupported, std::string("TOSA Reference Operator: " + opString + " for input: " + input->GetName() + " and output: " + output->GetName() + @@ -125,6 +125,58 @@ static bool RunTosaLayerChecksInputOutputDataType(TosaSerializationOperator* op, return supported; } +static bool RunTosaLayerChecksInputWeightsOutputDataType( + TosaSerializationOperator* op, + const std::vector<TosaSerializationTensor*>& inputs, + const std::vector<TosaSerializationTensor*>& outputs, + const std::vector<Attribute>& supportedAttributes, + const std::vector<std::tuple<DType, DType, DType>>& supportedMappingTypes, + Optional<string&> reasonIfUnsupported) +{ + bool supported = true; + + std::string opString = TosaOpToString(op->GetOp()); + + // Check Attribute from operator (GetAttribute) + supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opString + + " has an unsupported attribute.").c_str()); + + // Check combination of input, weights and output types. + // Bias is the same as output type, so it is covered. + std::tuple<DType, DType, DType> mappingTypes(inputs[0]->GetDtype(), inputs[1]->GetDtype(), outputs[0]->GetDtype()); + + // Check Dtype from tensor (GetDtype) + supported &= CheckSupportRule(TosaContainerContainsThreeTypes(mappingTypes, supportedMappingTypes), + reasonIfUnsupported, + std::string("TOSA Reference Operator: " + opString + " for input 0: " + + inputs[0]->GetName() + ", input 1: " + inputs[1]->GetName() + + " and output: " + outputs[0]->GetName() + + " has an unsupported input data type combination.").c_str()); + + for (auto input : inputs) + { + // Check Shape from tensor (GetShape) + supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input), + reasonIfUnsupported, + std::string("Tosa Reference Operator: " + opString + " for input: " + + input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); + } + + for (auto output : outputs) + { + // Check Shape from tensor (GetShape) + supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output), + reasonIfUnsupported, + std::string("Tosa Reference Operator: " + opString + " for output: " + + output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str()); + } + + return supported; +} + + + static bool IsTosaLayerSupported(TosaSerializationOperator* op, const std::vector<TosaSerializationTensor*>& inputs, const std::vector<TosaSerializationTensor*>& outputs, @@ -134,10 +186,7 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, { case tosa::Op_ADD: { - std::vector<Attribute> supportedAttributes = - { - Attribute_NONE - }; + std::vector<Attribute> supportedAttributes = { Attribute_NONE }; // Only Int32, Fp32 and Fp16 are currently supported by the TOSA Reference Model. std::vector<DType> supportedTypes = @@ -148,20 +197,47 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, }; // Check the attribute, data types and bounds for inputs and outputs. - return RunTosaLayerChecksSingleDataType(op, - inputs, - outputs, - supportedAttributes, - supportedTypes, - reasonIfUnsupported); + return RunTosaLayerChecksSingleDataType( + op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); } - case tosa::Op_AVG_POOL2D: + case tosa::Op_CONST: + { + std::vector<Attribute> supportedAttributes = { Attribute_NONE }; + + std::vector<DType> supportedTypes = + { + DType_FP16, + DType_FP32, + DType_UINT8, + DType_INT8, + DType_INT16, + DType_INT32, + DType_BOOL + }; + + // Check the attribute, data types and bounds for inputs and outputs. + return RunTosaLayerChecksSingleDataType( + op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); + } + case tosa::Op_CONV2D: { - std::vector<Attribute> supportedAttributes = + std::vector<Attribute> supportedAttributes = { Attribute_ConvAttribute }; + + std::vector<std::tuple<DType, DType, DType>> supportedTypesMapping = { - Attribute_PoolAttribute + std::tuple<DType, DType, DType>(DType_FP16, DType_FP16, DType_FP16), + std::tuple<DType, DType, DType>(DType_FP16, DType_FP16, DType_FP32), + std::tuple<DType, DType, DType>(DType_FP32, DType_FP32, DType_FP32), + std::tuple<DType, DType, DType>(DType_INT8, DType_INT8, DType_INT32) }; + return RunTosaLayerChecksInputWeightsOutputDataType( + op, inputs, outputs, supportedAttributes, supportedTypesMapping, reasonIfUnsupported); + } + case tosa::Op_AVG_POOL2D: + { + std::vector<Attribute> supportedAttributes = { Attribute_PoolAttribute }; + std::vector<std::tuple<DType, DType>> supportedTypesMapping = { std::tuple<DType, DType>(DType_FP16, DType_FP16), @@ -172,19 +248,12 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, }; // Check the attribute, data types and bounds for inputs and outputs. - return RunTosaLayerChecksInputOutputDataType(op, - inputs, - outputs, - supportedAttributes, - supportedTypesMapping, - reasonIfUnsupported); + return RunTosaLayerChecksInputOutputDataType( + op, inputs, outputs, supportedAttributes, supportedTypesMapping, reasonIfUnsupported); } case tosa::Op_MAX_POOL2D: { - std::vector<Attribute> supportedAttributes = - { - Attribute_PoolAttribute - }; + std::vector<Attribute> supportedAttributes = { Attribute_PoolAttribute }; std::vector<DType> supportedTypes = { @@ -195,19 +264,12 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, }; // Check the attribute, data types and bounds for inputs and outputs. - return RunTosaLayerChecksSingleDataType(op, - inputs, - outputs, - supportedAttributes, - supportedTypes, - reasonIfUnsupported); + return RunTosaLayerChecksSingleDataType( + op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); } case tosa::Op_PAD: { - std::vector<Attribute> supportedAttributes = - { - Attribute_PadAttribute - }; + std::vector<Attribute> supportedAttributes = { Attribute_PadAttribute }; std::vector<DType> supportedTypes = { @@ -220,12 +282,8 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, }; // Check the attribute, data types and bounds for inputs and outputs. - return RunTosaLayerChecksSingleDataType(op, - inputs, - outputs, - supportedAttributes, - supportedTypes, - reasonIfUnsupported); + return RunTosaLayerChecksSingleDataType( + op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); } default: SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend."); @@ -248,15 +306,31 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type, switch (type) { + case LayerType::Input: + case LayerType::Output: + return true; case LayerType::Addition: // Setup inputs and outputs inputInfos.push_back(&infos[0]); inputInfos.push_back(&infos[1]); outputInfos.push_back(&infos[2]); break; - case LayerType::Input: - case LayerType::Output: - return true; + case LayerType::Constant: + outputInfos.push_back(&infos[0]); + break; + case LayerType::Convolution2d: + { + inputInfos.push_back(&infos[0]); // input + outputInfos.push_back(&infos[1]); // output + inputInfos.push_back(&infos[2]); // weights + + auto conv2dDesc = PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor); + if(conv2dDesc->m_BiasEnabled) + { + inputInfos.push_back(&infos[3]); // bias + } + break; + } case LayerType::Pooling2d: // Setup inputs and outputs inputInfos.push_back(&infos[0]); @@ -266,7 +340,7 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type, break; } - auto mappings = GetTosaMapping(type, inputInfos, outputInfos, descriptor, false); + auto mappings = GetTosaMapping(nullptr, type, inputInfos, outputInfos, descriptor); if (mappings->GetName() == "") { // There currently isn't a TOSA mapping for this layer, as the default was returned. |