aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc55
1 files changed, 36 insertions, 19 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 190b354..ca12cfe 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -35,14 +35,14 @@ OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Resize);
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::~OpResize()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -64,7 +64,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
+ if (OutDtype != TOSA_REF_TYPE_INT32 && OutDtype != TOSA_REF_TYPE_INT48 && OutDtype != TOSA_REF_TYPE_FP32 &&
+ OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -72,7 +73,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
+ if (OutDtype != TOSA_REF_TYPE_INT8 && OutDtype != TOSA_REF_TYPE_INT16 && OutDtype != TOSA_REF_TYPE_FP32 &&
+ OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -87,7 +89,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::eval()
{
int in_batch = in->getShape()[0];
@@ -157,24 +159,38 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
int32_t y = oy * scale_y_d + offset_y;
int32_t x = ox * scale_x_d + offset_x;
- float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
- float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
-
- int32_t iy = floor(fy);
- int32_t ix = floor(fx);
-
+ int32_t iy;
+ int32_t ix;
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
- (typeid(resize_t) == typeid(half_float::half)))
+ if (std::is_same<resize_t, double>::value)
{
- dy = (resize_t)(fy - iy);
- dx = (resize_t)(fx - ix);
+ const double fy_double = static_cast<double>(y) / static_cast<double>(scale_y_n);
+ const double fx_double = static_cast<double>(x) / static_cast<double>(scale_x_n);
+ iy = floor(fy_double);
+ ix = floor(fx_double);
+
+ dy = (resize_t)(fy_double - iy);
+ dx = (resize_t)(fx_double - ix);
}
else
{
- dy = (resize_t)(y - (iy * scale_y_n));
- dx = (resize_t)(x - (ix * scale_x_n));
+ const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
+ const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
+ iy = floor(fy);
+ ix = floor(fx);
+
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
+ {
+ dy = (resize_t)(fy - iy);
+ dx = (resize_t)(fx - ix);
+ }
+ else
+ {
+ dy = (resize_t)(y - (iy * scale_y_n));
+ dx = (resize_t)(x - (ix * scale_x_n));
+ }
}
int32_t iy0 = MAX(iy, 0);
@@ -248,3 +264,4 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double);