diff options
Diffstat (limited to 'src/backends/tosaCommon/TosaMappings.cpp')
-rw-r--r-- | src/backends/tosaCommon/TosaMappings.cpp | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/src/backends/tosaCommon/TosaMappings.cpp b/src/backends/tosaCommon/TosaMappings.cpp index 71d2012cbc..a37eaf29b3 100644 --- a/src/backends/tosaCommon/TosaMappings.cpp +++ b/src/backends/tosaCommon/TosaMappings.cpp @@ -23,10 +23,18 @@ void SetBasicBlockConstantTensorData(Layer* layer, TosaSerializationBasicBlock* } } +TosaSerializationBasicBlock* CreateEmptyTosaSerializationBasicBlock() +{ + // empty basic block when no tosa mapping implemented/exists + TosaSerializationOperator* op = + new TosaSerializationOperator(Op_UNKNOWN, Attribute_NONE, nullptr, {}, {}); + return new TosaSerializationBasicBlock("", {op}, {}, {}, {}); +} + TosaSerializationBasicBlock* GetTosaMapping(const LayerType type, const std::vector<const TensorInfo*>& inputs, const std::vector<const TensorInfo*>& outputs, - const BaseDescriptor& /*descriptor*/, + const BaseDescriptor& descriptor, bool isMain = false) { switch (type) @@ -35,11 +43,30 @@ TosaSerializationBasicBlock* GetTosaMapping(const LayerType type, { return ConvertAdditionToTosaOperator(inputs, outputs, isMain); } + case LayerType::Pooling2d: + { + auto poolDesc = PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor); + + bool avgPoolIgnoreValue = + (poolDesc->m_PoolType == PoolingAlgorithm::Average) && + (poolDesc->m_PaddingMethod == PaddingMethod::IgnoreValue); + + if (poolDesc->m_PoolType == PoolingAlgorithm::L2) + { + return CreateEmptyTosaSerializationBasicBlock(); + } + else if (avgPoolIgnoreValue) + { + return ConvertAvgPool2DIgnoreValueToTosaOperator(inputs, outputs, isMain, poolDesc); + } + else + { + return ConvertPooling2DToTosaOperator(inputs, outputs, isMain, poolDesc); + } + } default: { - // empty basic block when no tosa mapping implemented/exists - TosaSerializationOperator* op = new TosaSerializationOperator(Op_UNKNOWN, Attribute_NONE, nullptr, {}, {}); - return new TosaSerializationBasicBlock("", {op}, {}, {}, {}); + return CreateEmptyTosaSerializationBasicBlock(); } } } |