aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.cpp
diff options
context:
space:
mode:
authorSimon Obute <simon.obute@arm.com>2021-09-03 15:50:13 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2021-09-24 16:06:30 +0100
commit51f67776a695c217a32596af806afeeb080f5528 (patch)
tree33ccfd87ba365bcc6fc86d5a2181991a130b3061 /src/armnnDeserializer/Deserializer.cpp
parentf10b15a8946f39bdf3f60cebc59d2963069eedca (diff)
downloadarmnn-51f67776a695c217a32596af806afeeb080f5528.tar.gz
IVGCVSW-3705 Add Channel Shuffle Front end and Ref Implementation
* Add front end * Add reference workload * Add unit tests * Add Serializer and Deserializer * Update ArmNN Versioning Signed-off-by: Simon Obute <simon.obute@arm.com> Change-Id: I9ac1f953af3974382eac8e8d62d794d2344e8f47
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp28
1 files changed, 27 insertions, 1 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 074429b73c..13415814a2 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -215,8 +215,9 @@ m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupporte
m_ParserFunctions[Layer_ArgMinMaxLayer] = &DeserializerImpl::ParseArgMinMax;
m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &DeserializerImpl::ParseBatchToSpaceNd;
m_ParserFunctions[Layer_BatchNormalizationLayer] = &DeserializerImpl::ParseBatchNormalization;
- m_ParserFunctions[Layer_ComparisonLayer] = &DeserializerImpl::ParseComparison;
m_ParserFunctions[Layer_CastLayer] = &DeserializerImpl::ParseCast;
+ m_ParserFunctions[Layer_ChannelShuffleLayer] = &DeserializerImpl::ParseChannelShuffle;
+ m_ParserFunctions[Layer_ComparisonLayer] = &DeserializerImpl::ParseComparison;
m_ParserFunctions[Layer_ConcatLayer] = &DeserializerImpl::ParseConcat;
m_ParserFunctions[Layer_ConstantLayer] = &DeserializerImpl::ParseConstant;
m_ParserFunctions[Layer_Convolution2dLayer] = &DeserializerImpl::ParseConvolution2d;
@@ -293,6 +294,8 @@ LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& gr
return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base();
case Layer::Layer_CastLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_CastLayer()->base();
+ case Layer::Layer_ChannelShuffleLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->base();
case Layer::Layer_ComparisonLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_ComparisonLayer()->base();
case Layer::Layer_ConcatLayer:
@@ -1780,7 +1783,30 @@ const armnnSerializer::OriginsDescriptor* GetOriginsDescriptor(const armnnSerial
throw armnn::Exception("unknown layer type, should be concat or merger");
}
}
+void IDeserializer::DeserializerImpl::ParseChannelShuffle(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 1);
+
+ TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+ armnn::ChannelShuffleDescriptor descriptor;
+ descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->descriptor()->axis();
+ descriptor.m_NumGroups =
+ graph->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->descriptor()->numGroups();
+
+ auto layerName = GetLayerName(graph, layerIndex);
+ IConnectableLayer* layer = m_Network->AddChannelShuffleLayer(descriptor, layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
void IDeserializer::DeserializerImpl::ParseComparison(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);