diff options
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r-- | reference_model/src/ops/image.cc | 55 |
1 files changed, 36 insertions, 19 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 190b354..ca12cfe 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -35,14 +35,14 @@ OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Resize); } -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> OpResize<InDtype, OutDtype, resize_t>::~OpResize() { if (attribute) delete attribute; } -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -64,7 +64,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) + if (OutDtype != TOSA_REF_TYPE_INT32 && OutDtype != TOSA_REF_TYPE_INT48 && OutDtype != TOSA_REF_TYPE_FP32 && + OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -72,7 +73,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) + if (OutDtype != TOSA_REF_TYPE_INT8 && OutDtype != TOSA_REF_TYPE_INT16 && OutDtype != TOSA_REF_TYPE_FP32 && + OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -87,7 +89,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() return 0; } -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> int OpResize<InDtype, OutDtype, resize_t>::eval() { int in_batch = in->getShape()[0]; @@ -157,24 +159,38 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() int32_t y = oy * scale_y_d + offset_y; int32_t x = ox * scale_x_d + offset_x; - float fy = static_cast<float>(y) / static_cast<float>(scale_y_n); - float fx = static_cast<float>(x) / static_cast<float>(scale_x_n); - - int32_t iy = floor(fy); - int32_t ix = floor(fx); - + int32_t iy; + int32_t ix; resize_t dy; resize_t dx; - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || - (typeid(resize_t) == typeid(half_float::half))) + if (std::is_same<resize_t, double>::value) { - dy = (resize_t)(fy - iy); - dx = (resize_t)(fx - ix); + const double fy_double = static_cast<double>(y) / static_cast<double>(scale_y_n); + const double fx_double = static_cast<double>(x) / static_cast<double>(scale_x_n); + iy = floor(fy_double); + ix = floor(fx_double); + + dy = (resize_t)(fy_double - iy); + dx = (resize_t)(fx_double - ix); } else { - dy = (resize_t)(y - (iy * scale_y_n)); - dx = (resize_t)(x - (ix * scale_x_n)); + const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n); + const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n); + iy = floor(fy); + ix = floor(fx); + + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) + { + dy = (resize_t)(fy - iy); + dx = (resize_t)(fx - ix); + } + else + { + dy = (resize_t)(y - (iy * scale_y_n)); + dx = (resize_t)(x - (ix * scale_x_n)); + } } int32_t iy0 = MAX(iy, 0); @@ -248,3 +264,4 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double); |