// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "type_conversion.h" #include "quant_util.h" #include "template_types.h" #include using namespace TosaReference; using namespace Eigen; using namespace tosa; template OpRescale::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_RESCALE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); INIT_ATTRIBUTE(Rescale); } template OpRescale::~OpRescale() { if (attribute) delete attribute; } template int OpRescale::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) { return 1; } // output and input must be the same rank and size if (inputs[0]->matchRankSize(*outputs[0])) { printNodeValidationError("OpRescale: input and output rank/size must match"); return 1; } in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); ASSERT_MEM(in && out); return 0; } template int OpRescale::eval() { int32_t input_zp = attribute->input_zp(); int32_t output_zp = attribute->output_zp(); std::vector multiplier = attribute->multiplier(); std::vector shift = attribute->shift(); //bool scale32 = attribute->scale32(); bool double_round = attribute->double_round(); bool per_channel = attribute->per_channel(); if (TosaReference::TypeChecker::is_symmetric(InDtype)) { if (input_zp != 0) { FATAL_ERROR_NODE("input tensor is symmetric type %s but zeropoint is %d instead of 0", EnumNamesDType()[InDtype], input_zp); } } if (TosaReference::TypeChecker::is_symmetric(OutDtype)) { if (output_zp != 0) { FATAL_ERROR_NODE("output tensor is symmetric type %s but zeropoint is %d instead of 0", EnumNamesDType()[OutDtype], output_zp); } } // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn] Eigen::array shape_2d; shape_2d[0] = 1; if (Rank > 0) { for (int i = 0; i < Rank - 1; i++) { shape_2d[0] *= this->in->getShape()[i]; } shape_2d[1] = this->in->getShape()[Rank - 1]; } else { shape_2d[1] = 1; } ETensor2 input_reshaped = this->in->getTensor().reshape(shape_2d); ETensor2 output_2d(shape_2d); // TODO: pass scale32 in when 16-bit mode implemented if (per_channel) { ETensor2 curr_channel_slice_prescaled; ETensor2 curr_channel_slice_postscaled; int32_t channel_multiplier, channel_shift; Eigen::array begin, size; size = Eigen::array({ shape_2d[0], 1 }); for (int32_t i = 0; i < shape_2d[1]; i++) { begin = Eigen::array({ 0, i }); curr_channel_slice_prescaled = input_reshaped.slice(begin, size); channel_multiplier = multiplier[i]; channel_shift = shift[i]; curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; int32_t scaled = TosaReference::QuantUtil::apply_scale( input_zp_shifted, channel_multiplier, channel_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max(out_val, QMin); out_val = std::min(out_val, QMax); return out_val; }); for (int32_t j = 0; j < shape_2d[0]; j++) { output_2d(j, i) = curr_channel_slice_postscaled(j, 0); } } } else { int32_t tensor_multiplier = multiplier[0]; int32_t tensor_shift = shift[0]; output_2d = input_reshaped.unaryExpr( [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; int32_t scaled = TosaReference::QuantUtil::apply_scale(input_zp_shifted, tensor_multiplier, tensor_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max(out_val, QMin); out_val = std::min(out_val, QMax); return out_val; }); } // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn] Eigen::array output_shape; for (int i = 0; i < Rank; i++) { output_shape[i] = this->out->getShape()[i]; } this->out->getTensor() = output_2d.reshape(output_shape); return GraphNode::eval(); } template OpCast::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_CAST, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); } template OpCast::~OpCast() {} template int OpCast::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) { return 1; } // output and input must be the same rank and size if (inputs[0]->matchRankSize(*outputs[0])) { printNodeValidationError("OpCast: input and output rank/size must match"); return 1; } in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); ASSERT_MEM(in && out); return 0; } template int OpCast::eval() { this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn()); return GraphNode::eval(); } template CastHelper::CastHelper() { fcn = [](InEigenType in) -> OutEigenType { OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t) int64_t mask = (1L << OutBits) - 1; out = out & mask; return out; }; } template CastHelper::CastHelper() { fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; }; } template CastHelper::CastHelper() { fcn = [](bool in) -> OutEigenType { OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0; return out; }; } template CastHelper::CastHelper() { fcn = [](InEigenType in) -> float { float out = (OutEigenType)in; // default cast to float is round_to_nearest_float() return out; }; } template CastHelper::CastHelper() { fcn = [](float in) -> OutEigenType { OutEigenType out = std::round(in); out = std::max(out, OutMin); out = std::min(out, OutMax); return out; }; } // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);