diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefLayerSupport.cpp | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index 18530bb535..f5f34a814b 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -95,25 +95,36 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type, IgnoreUnused(lstmParamsInfo); IgnoreUnused(quantizedLstmInputParamsInfo); - // Setup Inputs - const auto input0 = infos[0]; - const TensorInfo* ptr0 = &input0; - const auto input1 = infos[1]; - const TensorInfo* ptr1 = &input1; - std::vector<const TensorInfo*> inputInfos = {ptr0, ptr1}; - - // Setup Outputs - const auto output = infos[2]; - const TensorInfo* ptr2 = &output; - std::vector<const TensorInfo*> outputInfos = {ptr2}; + std::vector<const TensorInfo*> inputInfos; + std::vector<const TensorInfo*> outputInfos; + + switch (type) + { + case LayerType::Addition: + // Setup inputs and outputs + inputInfos.push_back(&infos[0]); + inputInfos.push_back(&infos[1]); + outputInfos.push_back(&infos[2]); + break; + case LayerType::Input: + case LayerType::Output: + return true; + default: + break; + } auto mappings = GetTosaMapping(type, inputInfos, outputInfos, descriptor); + if (mappings->GetName() == "") + { + // There currently isn't a TOSA mapping for this layer, as the default was returned. + return false; + } // Loop through block and get each tensor and operator for (long unsigned int i = 0; i < mappings->GetOperators().size(); ++i) { // While looping over operators check for op_UNKNOWN which is unsupported - if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false;} + if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false; } // Loop over operators and get GetInput/OutputTensorNames, loop over resulting names and // use GetTensorByName to pass pointers to tensors on to the IsTosaLayerSupported() |