From 0ae5de9124a0094e656244ad2f807c084966fc04 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Thu, 14 Mar 2019 10:32:11 +0000 Subject: COMPMID-1995: Prepare Graph to support different input/output quantization info - Added support for different input/output qinfo in ActivationLayer and DepthwiseConv - Added support for different input/output qinfo in ConcatenateLayer introducing ConcatDescriptor - Added reshape validate - Allow OutputLayer to return a specific connection index from the input - Not run Inplace and Depth mutator when input/output quantization info are different Change-Id: I03f5e416fc43ddd284e1501887202a3145f76d8a Signed-off-by: Isabella Gottardi Reviewed-on: https://review.mlplatform.org/c/852 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Georgios Pinitas --- src/graph/mutators/DepthConcatSubTensorMutator.cpp | 7 ++++--- src/graph/mutators/InPlaceOperationMutator.cpp | 6 +++--- src/graph/mutators/NodeFusionMutator.cpp | 11 +++++++++-- 3 files changed, 16 insertions(+), 8 deletions(-) (limited to 'src/graph/mutators') diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp index a170c4d899..0e0a26b886 100644 --- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp +++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -69,11 +69,12 @@ void DepthConcatSubTensorMutator::mutate(Graph &g) continue; } - // Check that all tensor have the same target and valid inputs + // Check that all tensor have the same target, valid inputs and same quantization info bool is_valid = std::all_of(node->input_edges().cbegin(), node->input_edges().cend(), [&](const EdgeID & eid) { - return (g.edge(eid) != nullptr) && (g.edge(eid)->tensor() != nullptr) && (g.edge(eid)->tensor()->desc().target == output_tensor->desc().target); + return (g.edge(eid) != nullptr) && (g.edge(eid)->tensor() != nullptr) && (g.edge(eid)->tensor()->desc().target == output_tensor->desc().target) + && (g.edge(eid)->tensor()->desc().quant_info == output_tensor->desc().quant_info); }); // Create subtensors diff --git a/src/graph/mutators/InPlaceOperationMutator.cpp b/src/graph/mutators/InPlaceOperationMutator.cpp index 31921b328e..1c2985dce6 100644 --- a/src/graph/mutators/InPlaceOperationMutator.cpp +++ b/src/graph/mutators/InPlaceOperationMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -56,8 +56,8 @@ void InPlaceOperationMutator::mutate(Graph &g) ARM_COMPUTE_ERROR_ON(current_output_tensor == nullptr || new_output_tensor == nullptr); - // Prevent in-place operation if there is an accessor bound to the in-place tensor - if(new_output_tensor->accessor() == nullptr) + // Prevent in-place operation if there is an accessor bound to the in-place tensor or quantization info are different + if(new_output_tensor->accessor() == nullptr || current_output_tensor->desc().quant_info == new_output_tensor->desc().quant_info) { ARM_COMPUTE_LOG_GRAPH_VERBOSE("Switching to in-place computation for the node with ID : " << node->id() << " and name : " << node->name() << std::endl); diff --git a/src/graph/mutators/NodeFusionMutator.cpp b/src/graph/mutators/NodeFusionMutator.cpp index 5927a597bb..724307e7b7 100644 --- a/src/graph/mutators/NodeFusionMutator.cpp +++ b/src/graph/mutators/NodeFusionMutator.cpp @@ -211,10 +211,17 @@ void NodeFusionMutator::mutate(Graph &g) { return true; }; - auto qs8_prec = [](INode & n) + auto qs8_prec = [&g](INode & n) { ARM_COMPUTE_ERROR_ON(n.output(0) == nullptr); - return n.output(0)->desc().data_type == DataType::QASYMM8; + + const auto output_edge_id = *n.output_edges().begin(); + const auto output_edge = g.edge(output_edge_id); + // To perform fusion the two nodes must have same output quantization information + const bool same_qinfo = n.output(0)->desc().quant_info == output_edge->producer()->output(0)->desc().quant_info; + const bool output_qasymm8 = n.output(0)->desc().data_type == DataType::QASYMM8; + + return output_qasymm8 && same_qinfo; }; // Fusion mutations -- cgit v1.2.1