aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/SyntheticDataTypeMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/SyntheticDataTypeMutator.cpp')
-rw-r--r--src/graph/mutators/SyntheticDataTypeMutator.cpp72
1 files changed, 33 insertions, 39 deletions
diff --git a/src/graph/mutators/SyntheticDataTypeMutator.cpp b/src/graph/mutators/SyntheticDataTypeMutator.cpp
index 74d040b81d..3dc2480e85 100644
--- a/src/graph/mutators/SyntheticDataTypeMutator.cpp
+++ b/src/graph/mutators/SyntheticDataTypeMutator.cpp
@@ -26,8 +26,8 @@
#include "arm_compute/graph/GraphBuilder.h"
#include "arm_compute/graph/ITensorAccessor.h"
#include "arm_compute/graph/Logger.h"
-#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
+#include "arm_compute/graph/Utils.h"
#include "support/Cast.h"
@@ -62,14 +62,12 @@ public:
*/
bool is_mutation_supported(Graph &g)
{
- const std::set<NodeType> unsupported_node_types = { NodeType::DetectionOutputLayer,
- NodeType::NormalizationLayer,
- NodeType::PriorBoxLayer
- };
+ const std::set<NodeType> unsupported_node_types = {NodeType::DetectionOutputLayer, NodeType::NormalizationLayer,
+ NodeType::PriorBoxLayer};
- for(const auto &utype : unsupported_node_types)
+ for (const auto &utype : unsupported_node_types)
{
- if(!g.nodes(utype).empty())
+ if (!g.nodes(utype).empty())
{
return false;
}
@@ -83,12 +81,12 @@ bool is_mutation_supported(Graph &g)
*/
void remove_optimized_nodes(Graph &g)
{
- const std::set<NodeType> optimized_node_types = { NodeType::BatchNormalizationLayer };
+ const std::set<NodeType> optimized_node_types = {NodeType::BatchNormalizationLayer};
- for(const auto &opt_type : optimized_node_types)
+ for (const auto &opt_type : optimized_node_types)
{
const std::vector<NodeID> opt_nodes_ids = g.nodes(opt_type);
- for(const auto &node_id : opt_nodes_ids)
+ for (const auto &node_id : opt_nodes_ids)
{
INode *node = g.node(node_id);
@@ -108,7 +106,7 @@ void remove_optimized_nodes(Graph &g)
g.remove_node(node->id());
// Update connections
- for(auto &driving_node : driving_nodes)
+ for (auto &driving_node : driving_nodes)
{
g.add_connection(producer->id(), producer_edge_id, driving_node.node_id, driving_node.index);
}
@@ -123,11 +121,11 @@ void remove_optimized_nodes(Graph &g)
void convert_tensors(Graph &g, DataType data_type)
{
auto &tensors = g.tensors();
- for(auto &tensor : tensors)
+ for (auto &tensor : tensors)
{
- if(tensor != nullptr)
+ if (tensor != nullptr)
{
- switch(data_type)
+ switch (data_type)
{
case DataType::QASYMM8:
case DataType::QASYMM8_SIGNED:
@@ -156,7 +154,7 @@ template <typename NT>
void convert_special_node(Graph &g, std::function<bool(INode *, Tensor *)> const &f)
{
const std::vector<NodeID> nodes_ids = g.nodes(NT::node_type);
- for(const auto &nodes_id : nodes_ids)
+ for (const auto &nodes_id : nodes_ids)
{
INode *node = arm_compute::utils::cast::polymorphic_downcast<NT *>(g.node(nodes_id));
ARM_COMPUTE_ERROR_ON(node == nullptr);
@@ -174,41 +172,41 @@ void convert_special_node(Graph &g, std::function<bool(INode *, Tensor *)> const
*/
void convert_special_tensors(Graph &g)
{
- auto softmax_func = [](INode * node, Tensor * tensor)
+ auto softmax_func = [](INode *node, Tensor *tensor)
{
ARM_COMPUTE_UNUSED(node);
- if(tensor->desc().data_type == DataType::QASYMM8)
+ if (tensor->desc().data_type == DataType::QASYMM8)
{
tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0);
}
- else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED)
+ else if (tensor->desc().data_type == DataType::QASYMM8_SIGNED)
{
tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128);
}
return true;
};
- auto act_func = [](INode * node, Tensor * tensor)
+ auto act_func = [](INode *node, Tensor *tensor)
{
auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(node);
- if(tensor->desc().data_type == DataType::QASYMM8)
+ if (tensor->desc().data_type == DataType::QASYMM8)
{
- if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
+ if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
{
tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128);
}
- else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ else if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
{
tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0);
}
}
- else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED)
+ else if (tensor->desc().data_type == DataType::QASYMM8_SIGNED)
{
- if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
+ if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH)
{
tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 0);
}
- else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ else if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
{
tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128);
}
@@ -228,22 +226,19 @@ void convert_special_tensors(Graph &g)
*/
void handle_nodes_with_bias(Graph &g)
{
- const std::set<NodeType> special_node_types = { NodeType::ConvolutionLayer,
- NodeType::DeconvolutionLayer,
- NodeType::DepthwiseConvolutionLayer,
- NodeType::FullyConnectedLayer
- };
+ const std::set<NodeType> special_node_types = {NodeType::ConvolutionLayer, NodeType::DeconvolutionLayer,
+ NodeType::DepthwiseConvolutionLayer, NodeType::FullyConnectedLayer};
- for(const auto &spc_type : special_node_types)
+ for (const auto &spc_type : special_node_types)
{
const std::vector<NodeID> scp_nodes_ids = g.nodes(spc_type);
- for(const auto &node_id : scp_nodes_ids)
+ for (const auto &node_id : scp_nodes_ids)
{
INode *node = g.node(node_id);
- if(node != nullptr)
+ if (node != nullptr)
{
Tensor *tensor = node->input(2);
- if(tensor != nullptr)
+ if (tensor != nullptr)
{
tensor->desc().data_type = DataType::S32;
}
@@ -253,8 +248,8 @@ void handle_nodes_with_bias(Graph &g)
params.name = params.name.empty() ? "" : params.name + "Bias";
TensorDescriptor b_desc = node->input(1)->desc();
- auto depth = b_desc.shape[get_dimension_idx(b_desc.layout, DataLayoutDimension::BATCHES)];
- b_desc.shape = TensorShape(depth);
+ auto depth = b_desc.shape[get_dimension_idx(b_desc.layout, DataLayoutDimension::BATCHES)];
+ b_desc.shape = TensorShape(depth);
auto accessor = std::make_unique<EmptyAccessor>();
auto b_nid = GraphBuilder::add_const_node(g, params, b_desc, std::move(accessor));
@@ -266,8 +261,7 @@ void handle_nodes_with_bias(Graph &g)
}
} // namespace
-SyntheticDataTypeMutator::SyntheticDataTypeMutator(DataType mutate_type)
- : _mutate_type{ mutate_type }
+SyntheticDataTypeMutator::SyntheticDataTypeMutator(DataType mutate_type) : _mutate_type{mutate_type}
{
}
@@ -283,7 +277,7 @@ IGraphMutator::MutationType SyntheticDataTypeMutator::type() const
void SyntheticDataTypeMutator::mutate(Graph &g)
{
- if(is_mutation_supported(g))
+ if (is_mutation_supported(g))
{
// Remove nodes that get optimized out (e.g. BatchNorm)
remove_optimized_nodes(g);