diff options
Diffstat (limited to 'src/graph/mutators/SyntheticDataTypeMutator.cpp')
-rw-r--r-- | src/graph/mutators/SyntheticDataTypeMutator.cpp | 111 |
1 files changed, 72 insertions, 39 deletions
diff --git a/src/graph/mutators/SyntheticDataTypeMutator.cpp b/src/graph/mutators/SyntheticDataTypeMutator.cpp index 532c0e821b..3dc2480e85 100644 --- a/src/graph/mutators/SyntheticDataTypeMutator.cpp +++ b/src/graph/mutators/SyntheticDataTypeMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -26,8 +26,8 @@ #include "arm_compute/graph/GraphBuilder.h" #include "arm_compute/graph/ITensorAccessor.h" #include "arm_compute/graph/Logger.h" -#include "arm_compute/graph/Utils.h" #include "arm_compute/graph/nodes/Nodes.h" +#include "arm_compute/graph/Utils.h" #include "support/Cast.h" @@ -62,14 +62,12 @@ public: */ bool is_mutation_supported(Graph &g) { - const std::set<NodeType> unsupported_node_types = { NodeType::DetectionOutputLayer, - NodeType::NormalizationLayer, - NodeType::PriorBoxLayer - }; + const std::set<NodeType> unsupported_node_types = {NodeType::DetectionOutputLayer, NodeType::NormalizationLayer, + NodeType::PriorBoxLayer}; - for(const auto &utype : unsupported_node_types) + for (const auto &utype : unsupported_node_types) { - if(!g.nodes(utype).empty()) + if (!g.nodes(utype).empty()) { return false; } @@ -83,12 +81,12 @@ bool is_mutation_supported(Graph &g) */ void remove_optimized_nodes(Graph &g) { - const std::set<NodeType> optimized_node_types = { NodeType::BatchNormalizationLayer }; + const std::set<NodeType> optimized_node_types = {NodeType::BatchNormalizationLayer}; - for(const auto &opt_type : optimized_node_types) + for (const auto &opt_type : optimized_node_types) { const std::vector<NodeID> opt_nodes_ids = g.nodes(opt_type); - for(const auto &node_id : opt_nodes_ids) + for (const auto &node_id : opt_nodes_ids) { INode *node = g.node(node_id); @@ -108,7 +106,7 @@ void remove_optimized_nodes(Graph &g) g.remove_node(node->id()); // Update connections - for(auto &driving_node : driving_nodes) + for (auto &driving_node : driving_nodes) { g.add_connection(producer->id(), producer_edge_id, driving_node.node_id, driving_node.index); } @@ -120,15 +118,28 @@ void remove_optimized_nodes(Graph &g) * * @param[in,out] g Graph to convert tensors of. */ -void convert_tensors(Graph &g) +void convert_tensors(Graph &g, DataType data_type) { auto &tensors = g.tensors(); - for(auto &tensor : tensors) + for (auto &tensor : tensors) { - if(tensor != nullptr) + if (tensor != nullptr) { - tensor->desc().data_type = DataType::QASYMM8; - tensor->desc().quant_info = QuantizationInfo(0.125f, -10); + switch (data_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + tensor->desc().quant_info = QuantizationInfo(0.125f, -10); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported mutation type"); + break; + } + } + tensor->desc().data_type = data_type; } } } @@ -143,7 +154,7 @@ template <typename NT> void convert_special_node(Graph &g, std::function<bool(INode *, Tensor *)> const &f) { const std::vector<NodeID> nodes_ids = g.nodes(NT::node_type); - for(const auto &nodes_id : nodes_ids) + for (const auto &nodes_id : nodes_ids) { INode *node = arm_compute::utils::cast::polymorphic_downcast<NT *>(g.node(nodes_id)); ARM_COMPUTE_ERROR_ON(node == nullptr); @@ -161,23 +172,44 @@ void convert_special_node(Graph &g, std::function<bool(INode *, Tensor *)> const */ void convert_special_tensors(Graph &g) { - auto softmax_func = [](INode * node, Tensor * tensor) + auto softmax_func = [](INode *node, Tensor *tensor) { ARM_COMPUTE_UNUSED(node); - tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + if (tensor->desc().data_type == DataType::QASYMM8) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } + else if (tensor->desc().data_type == DataType::QASYMM8_SIGNED) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128); + } return true; }; - auto act_func = [](INode * node, Tensor * tensor) + auto act_func = [](INode *node, Tensor *tensor) { auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(node); - if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + if (tensor->desc().data_type == DataType::QASYMM8) { - tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); + if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); + } + else if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } } - else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + else if (tensor->desc().data_type == DataType::QASYMM8_SIGNED) { - tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 0); + } + else if (act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128); + } } return true; }; @@ -194,22 +226,19 @@ void convert_special_tensors(Graph &g) */ void handle_nodes_with_bias(Graph &g) { - const std::set<NodeType> special_node_types = { NodeType::ConvolutionLayer, - NodeType::DeconvolutionLayer, - NodeType::DepthwiseConvolutionLayer, - NodeType::FullyConnectedLayer - }; + const std::set<NodeType> special_node_types = {NodeType::ConvolutionLayer, NodeType::DeconvolutionLayer, + NodeType::DepthwiseConvolutionLayer, NodeType::FullyConnectedLayer}; - for(const auto &spc_type : special_node_types) + for (const auto &spc_type : special_node_types) { const std::vector<NodeID> scp_nodes_ids = g.nodes(spc_type); - for(const auto &node_id : scp_nodes_ids) + for (const auto &node_id : scp_nodes_ids) { INode *node = g.node(node_id); - if(node != nullptr) + if (node != nullptr) { Tensor *tensor = node->input(2); - if(tensor != nullptr) + if (tensor != nullptr) { tensor->desc().data_type = DataType::S32; } @@ -219,10 +248,10 @@ void handle_nodes_with_bias(Graph &g) params.name = params.name.empty() ? "" : params.name + "Bias"; TensorDescriptor b_desc = node->input(1)->desc(); - auto depth = b_desc.shape[get_dimension_idx(b_desc.layout, DataLayoutDimension::BATCHES)]; - b_desc.shape = TensorShape(depth); + auto depth = b_desc.shape[get_dimension_idx(b_desc.layout, DataLayoutDimension::BATCHES)]; + b_desc.shape = TensorShape(depth); - auto accessor = support::cpp14::make_unique<EmptyAccessor>(); + auto accessor = std::make_unique<EmptyAccessor>(); auto b_nid = GraphBuilder::add_const_node(g, params, b_desc, std::move(accessor)); g.add_connection(b_nid, 0, node_id, 2); } @@ -232,6 +261,10 @@ void handle_nodes_with_bias(Graph &g) } } // namespace +SyntheticDataTypeMutator::SyntheticDataTypeMutator(DataType mutate_type) : _mutate_type{mutate_type} +{ +} + const char *SyntheticDataTypeMutator::name() { return "SyntheticDataTypeMutator"; @@ -244,13 +277,13 @@ IGraphMutator::MutationType SyntheticDataTypeMutator::type() const void SyntheticDataTypeMutator::mutate(Graph &g) { - if(is_mutation_supported(g)) + if (is_mutation_supported(g)) { // Remove nodes that get optimized out (e.g. BatchNorm) remove_optimized_nodes(g); // Convert tensor - convert_tensors(g); + convert_tensors(g, _mutate_type); convert_special_tensors(g); // Handle special nodes |