diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefLayerSupport.cpp | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index 848b7efdce..5cda85af20 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -36,7 +36,7 @@ static bool RunTosaLayerChecksSingleDataType(TosaSerializationOperator* op, for (auto input : inputs) { - std::string dataTypeCode = std::to_string(input->GetDtype()); + std::string dataTypeCode = TosaDTypeToString(input->GetDtype()); // Check Dtype from tensor (GetDtype) supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes), @@ -54,7 +54,7 @@ static bool RunTosaLayerChecksSingleDataType(TosaSerializationOperator* op, for (auto output : outputs) { - std::string dataTypeCode = std::to_string(output->GetDtype()); + std::string dataTypeCode = TosaDTypeToString(output->GetDtype()); // Check Dtype from tensor (GetDtype) supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes), @@ -97,8 +97,8 @@ static bool RunTosaLayerChecksInputOutputDataType(TosaSerializationOperator* op, { auto input = inputs[i]; auto output = outputs[i]; - std::string inputDataTypeCode = std::to_string(input->GetDtype()); - std::string outputDataTypeCode = std::to_string(output->GetDtype()); + std::string inputDataTypeCode = TosaDTypeToString(input->GetDtype()); + std::string outputDataTypeCode = TosaDTypeToString(output->GetDtype()); std::tuple<DType, DType> mappingType(input->GetDtype(), output->GetDtype()); // Check Dtype from tensor (GetDtype) @@ -285,6 +285,24 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, return RunTosaLayerChecksSingleDataType( op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); } + case tosa::Op_RESHAPE: + { + std::vector<Attribute> supportedAttributes = { Attribute_ReshapeAttribute }; + + std::vector<DType> supportedTypes = + { + DType_FP16, + DType_FP32, + DType_INT8, + DType_INT16, + DType_INT32, + DType_BOOL + }; + + // Check the attribute, data types and bounds for inputs and outputs. + return RunTosaLayerChecksSingleDataType( + op, inputs, outputs, supportedAttributes, supportedTypes, reasonIfUnsupported); + } default: SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend."); return false; @@ -332,6 +350,7 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type, break; } case LayerType::Pooling2d: + case LayerType::Reshape: // Setup inputs and outputs inputInfos.push_back(&infos[0]); outputInfos.push_back(&infos[1]); |