aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/TosaRefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/TosaRefLayerSupport.cpp')
-rw-r--r--src/backends/tosaReference/TosaRefLayerSupport.cpp35
1 files changed, 23 insertions, 12 deletions
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
index 18530bb535..f5f34a814b 100644
--- a/src/backends/tosaReference/TosaRefLayerSupport.cpp
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -95,25 +95,36 @@ bool TosaRefLayerSupport::IsLayerSupported(const LayerType& type,
IgnoreUnused(lstmParamsInfo);
IgnoreUnused(quantizedLstmInputParamsInfo);
- // Setup Inputs
- const auto input0 = infos[0];
- const TensorInfo* ptr0 = &input0;
- const auto input1 = infos[1];
- const TensorInfo* ptr1 = &input1;
- std::vector<const TensorInfo*> inputInfos = {ptr0, ptr1};
-
- // Setup Outputs
- const auto output = infos[2];
- const TensorInfo* ptr2 = &output;
- std::vector<const TensorInfo*> outputInfos = {ptr2};
+ std::vector<const TensorInfo*> inputInfos;
+ std::vector<const TensorInfo*> outputInfos;
+
+ switch (type)
+ {
+ case LayerType::Addition:
+ // Setup inputs and outputs
+ inputInfos.push_back(&infos[0]);
+ inputInfos.push_back(&infos[1]);
+ outputInfos.push_back(&infos[2]);
+ break;
+ case LayerType::Input:
+ case LayerType::Output:
+ return true;
+ default:
+ break;
+ }
auto mappings = GetTosaMapping(type, inputInfos, outputInfos, descriptor);
+ if (mappings->GetName() == "")
+ {
+ // There currently isn't a TOSA mapping for this layer, as the default was returned.
+ return false;
+ }
// Loop through block and get each tensor and operator
for (long unsigned int i = 0; i < mappings->GetOperators().size(); ++i)
{
// While looping over operators check for op_UNKNOWN which is unsupported
- if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false;}
+ if (mappings->GetOperators()[i]->GetOp() == tosa::Op_UNKNOWN) { return false; }
// Loop over operators and get GetInput/OutputTensorNames, loop over resulting names and
// use GetTensorByName to pass pointers to tensors on to the IsTosaLayerSupported()