diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/core/Utils.cpp | 3 | ||||
-rw-r--r-- | src/graph/Utils.cpp | 19 | ||||
-rw-r--r-- | src/graph/mutators/SyntheticDataTypeMutator.cpp | 59 |
3 files changed, 67 insertions, 14 deletions
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index babf1c4b91..8deb5979ac 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Arm Limited. + * Copyright (c) 2016-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -330,6 +330,7 @@ DataType data_type_from_name(const std::string &name) { "f16", DataType::F16 }, { "f32", DataType::F32 }, { "qasymm8", DataType::QASYMM8 }, + { "qasymm8_signed", DataType::QASYMM8_SIGNED }, }; #ifndef ARM_COMPUTE_EXCEPTIONS_DISABLED diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp index 2835af311a..8e12689fb9 100644 --- a/src/graph/Utils.cpp +++ b/src/graph/Utils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -81,9 +81,22 @@ PassManager create_default_pass_manager(Target target, const GraphConfig &cfg) const bool is_target_gc = target == Target::GC; // Passes that mutate graph IR - if(cfg.convert_to_uint8) + if(cfg.use_synthetic_type) { - pm.append(std::make_unique<SyntheticDataTypeMutator>(), !is_target_gc); + switch(cfg.synthetic_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + pm.append(std::make_unique<SyntheticDataTypeMutator>(cfg.synthetic_type), !is_target_gc); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported DataType for SyntheticDataTypeMutator"); + break; + } + } } pm.append(std::make_unique<NodeFusionMutator>(), !is_target_gc); pm.append(std::make_unique<GroupedConvolutionMutator>()); diff --git a/src/graph/mutators/SyntheticDataTypeMutator.cpp b/src/graph/mutators/SyntheticDataTypeMutator.cpp index 21bafa61e1..74d040b81d 100644 --- a/src/graph/mutators/SyntheticDataTypeMutator.cpp +++ b/src/graph/mutators/SyntheticDataTypeMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -120,15 +120,28 @@ void remove_optimized_nodes(Graph &g) * * @param[in,out] g Graph to convert tensors of. */ -void convert_tensors(Graph &g) +void convert_tensors(Graph &g, DataType data_type) { auto &tensors = g.tensors(); for(auto &tensor : tensors) { if(tensor != nullptr) { - tensor->desc().data_type = DataType::QASYMM8; - tensor->desc().quant_info = QuantizationInfo(0.125f, -10); + switch(data_type) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + { + tensor->desc().quant_info = QuantizationInfo(0.125f, -10); + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported mutation type"); + break; + } + } + tensor->desc().data_type = data_type; } } } @@ -164,20 +177,41 @@ void convert_special_tensors(Graph &g) auto softmax_func = [](INode * node, Tensor * tensor) { ARM_COMPUTE_UNUSED(node); - tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + if(tensor->desc().data_type == DataType::QASYMM8) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } + else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128); + } return true; }; auto act_func = [](INode * node, Tensor * tensor) { auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(node); - if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + if(tensor->desc().data_type == DataType::QASYMM8) { - tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); + if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 128); + } + else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + } } - else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + else if(tensor->desc().data_type == DataType::QASYMM8_SIGNED) { - tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, 0); + if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::TANH) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 128.f, 0); + } + else if(act_node->activation_info().activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + { + tensor->desc().quant_info = QuantizationInfo(1.f / 256.f, -128); + } } return true; }; @@ -232,6 +266,11 @@ void handle_nodes_with_bias(Graph &g) } } // namespace +SyntheticDataTypeMutator::SyntheticDataTypeMutator(DataType mutate_type) + : _mutate_type{ mutate_type } +{ +} + const char *SyntheticDataTypeMutator::name() { return "SyntheticDataTypeMutator"; @@ -250,7 +289,7 @@ void SyntheticDataTypeMutator::mutate(Graph &g) remove_optimized_nodes(g); // Convert tensor - convert_tensors(g); + convert_tensors(g, _mutate_type); convert_special_tensors(g); // Handle special nodes |