// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "image.h" #include "arith_util.h" #include "quant_util.h" using namespace TosaReference; using namespace Eigen; using namespace tosa; template OpResize::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_RESIZE, id_) { setRequiredOperands(1, 1); setRequiredRank(4, 4); INIT_ATTRIBUTE(Resize); } template OpResize::~OpResize() { if (attribute) delete attribute; } template int OpResize::checkTensorAttributes() { if (validateRequiredOperands()) return 1; if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) return 1; output_size = this->attribute->output_size(); stride = this->attribute->stride(); offset = this->attribute->offset(); shift = this->attribute->shift(); mode = this->attribute->mode(); int output_height = outputs[0]->getShape()[1]; int output_width = outputs[0]->getShape()[2]; if (this->mode == ResizeMode_BILINEAR) { if (OutDtype != DType_INT32 && OutDtype != DType_INT48) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; } } else { if (OutDtype != DType_INT8 && OutDtype != DType_INT16) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; } } if (output_size[0] != output_height || output_size[1] != output_width) { printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]"); return 1; } if (shift < 1 || shift > 11) { printNodeValidationError("OpResize: attribute shift should be within [1, 11]"); return 1; } if (stride[0] <= 0 || stride[1] <= 0) { printNodeValidationError("OpResize: invalid attribute stride"); return 1; } in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); ASSERT_MEM(in && out); return 0; } template int OpResize::eval() { int in_batch = in->getShape()[0]; int in_height = in->getShape()[1]; int in_width = in->getShape()[2]; int in_channels = in->getShape()[3]; int out_batch = out->getShape()[0]; int out_height = out->getShape()[1]; int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) for (int oy = 0; oy < out_height; oy++) for (int ox = 0; ox < out_width; ox++) { int y = oy * stride[0] + offset[0]; int x = ox * stride[1] + offset[1]; int iy = y >> shift; int dy = y - (iy << shift); int ix = x >> shift; int dx = x - (ix << shift); int iy0 = MAX(iy, 0); int iy1 = MIN(iy + 1, in_height - 1); int ix0 = MAX(ix, 0); int ix1 = MIN(ix + 1, in_width - 1); ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, iy1, ix0, ix1); InEigenType v00 = in->getTensor()(b, iy0, ix0, c); InEigenType v01 = in->getTensor()(b, iy0, ix1, c); InEigenType v10 = in->getTensor()(b, iy1, ix0, c); InEigenType v11 = in->getTensor()(b, iy1, ix1, c); OutEigenType acc; if (mode == ResizeMode_BILINEAR) { acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx); acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx; acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx); acc = acc + (OutEigenType)v11 * dy * dx; } else { iy = (dy >> (shift - 1)) != 0 ? iy1 : iy0; ix = (dx >> (shift - 1)) != 0 ? ix1 : ix0; acc = in->getTensor()(b, iy, ix, c); } out->getTensor()(b, oy, ox, c) = acc; } return GraphNode::eval(); } // template explicit instantiation DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);