aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/TosaRefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r--src/backends/tosaReference/TosaRefLayerSupport.cpp142
1 files changed, 142 insertions, 0 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
new file mode 100644
index 0000000000..80e982f1c4
--- /dev/null
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -0,0 +1,142 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "TosaRefLayerSupport.hpp"
+#include <tosaCommon/TosaMappings.hpp>
+
+#include <armnn/Types.hpp>
+#include <armnn/utility/IgnoreUnused.hpp>
+#include <tosaCommon/TosaLayerSupportRules.hpp>
+#include <LayerSupportCommon.hpp>
+
+#include <vector>
+#include <array>
+
+namespace armnn
+{
+
+static bool IsTosaLayerSupported(TosaSerializationOperator* op,
+ const std::vector<TosaSerializationTensor*>& inputs,
+ const std::vector<TosaSerializationTensor*>& outputs,
+ Optional<string&> reasonIfUnsupported)
+{
+ switch(op->GetOp())
+ {
+ case tosa::Op_ADD:
+ {
+ bool supported = true;
+
+ std::array<Attribute, 1> supportedAttributes =
+ {
+ Attribute_NONE
+ };
+
+ // Check Attribute from operator (GetAttribute)
+ supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
+ std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str());
+
+ std::array<DType, 8> supportedTypes =
+ {
+ DType_BOOL,
+ DType_UINT8,
+ DType_INT4,
+ DType_INT8,
+ DType_INT16,
+ DType_INT32,
+ DType_FLOAT,
+ DType_UINT16
+ };
+
+ for (auto tensor : inputs)
+ {
+ // Check Dtype from tensor (GetDtype)
+ supported &= CheckSupportRule(TosaTypeAnyOf(tensor, supportedTypes),
+ reasonIfUnsupported,
+ std::string("TOSA Reference addition: " + tensor->GetName() +
+ " is not a supported type.").c_str());
+
+ // Check Shape from tensor (GetShape)
+ supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(tensor),
+ reasonIfUnsupported,
+ std::string("Tosa Reference addition: " + tensor->GetName() + " Shape.Size()"
+ " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
+ }
+
+ // Check Dtype from tensor (GetDtype)
+ supported &= CheckSupportRule(TosaTypeAnyOf(outputs[0], supportedTypes),
+ reasonIfUnsupported,
+ std::string("TOSA Reference addition: " + outputs[0]->GetName() +
+ " is not a supported type.").c_str());
+
+ // Check Shape from tensor (GetShape)
+ supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(outputs[0]),
+ reasonIfUnsupported,
+ std::string("Tosa Reference addition: " + outputs[0]->GetName() + " Shape.Size()"
+ " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
+
+ return supported;
+ }
+ default:
+ SetValueChecked(reasonIfUnsupported, "Operation is currently unsupported by the TOSA Reference Backend.");
+ return false;
+ }
+}
+
+bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
+ const std::vector<TensorInfo>& infos,
+ const BaseDescriptor& descriptor,
+ const Optional<LstmInputParamsInfo>& lstmParamsInfo,
+ const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmInputParamsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(lstmParamsInfo);
+ IgnoreUnused(quantizedLstmInputParamsInfo);
+
+ // Setup Inputs
+ const auto input0 = infos[0];
+ const TensorInfo* ptr0 = &input0;
+ const auto input1 = infos[1];
+ const TensorInfo* ptr1 = &input1;
+ std::vector<const TensorInfo*> inputInfos = {ptr0, ptr1};
+
+ // Setup Outputs
+ const auto output = infos[2];
+ const TensorInfo* ptr2 = &output;
+ std::vector<const TensorInfo*> outputInfos = {ptr2};
+
+ auto mappings = GetTosaMapping(type, inputInfos, outputInfos, descriptor);
+
+ // Loop through block and get each tensor and operator
+ for (long unsigned int i = 0; i < mappings->GetOperators().size(); ++i)
+ {
+ // While looping over operators check for op_UNKNOWN which is unsupported
+ if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false;}
+
+ // Loop over operators and get GetInput/OutputTensorNames, loop over resulting names and
+ // use GetTensorByName to pass pointers to tensors on to the IsTosaLayerSupported()
+ std::vector<TosaSerializationTensor*> inputTensorsVect;
+ for (const auto& name : mappings->GetOperators()[i]->GetInputTensorNames())
+ {
+ inputTensorsVect.push_back(mappings->GetTensorByName(name));
+ }
+
+ std::vector<TosaSerializationTensor*> outputTensorsVect;
+ for (const auto& name : mappings->GetOperators()[i]->GetOutputTensorNames())
+ {
+ outputTensorsVect.push_back(mappings->GetTensorByName(name));
+ }
+
+ if (!IsTosaLayerSupported(mappings->GetOperators()[i],
+ inputTensorsVect,
+ outputTensorsVect,
+ reasonIfUnsupported))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace armnn