aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc29
1 files changed, 21 insertions, 8 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index cf1d9f7..66efee0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -63,7 +63,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -71,7 +71,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -159,15 +159,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value)
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
{
- dy = fy - iy;
- dx = fx - ix;
+ dy = (resize_t)(fy - iy);
+ dx = (resize_t)(fx - ix);
}
else
{
- dy = y - (iy * scale_y_n);
- dx = x - (ix * scale_x_n);
+ dy = (resize_t)(y - (iy * scale_y_n));
+ dx = (resize_t)(x - (ix * scale_x_n));
}
int32_t iy0 = MAX(iy, 0);
@@ -190,6 +190,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
acc += (OutEigenType)v10 * dy * (1.0 - dx);
acc += (OutEigenType)v11 * dy * dx;
}
+ else if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
+ {
+ Eigen::bfloat16 bf16_acc;
+ bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx);
+ bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx;
+ bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx);
+ bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx;
+ acc = (float)bf16_acc;
+ }
else
{
acc = (OutEigenType)v00 * (scale_y_n - dy) * (scale_x_n - dx);
@@ -201,7 +210,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
else
{
ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode");
- if (std::is_floating_point<resize_t>::value)
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
{
iy = (dy >= 0.5) ? iy1 : iy0;
ix = (dx >= 0.5) ? ix1 : ix0;
@@ -213,6 +222,9 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
}
acc = in->getTensor()(b, iy, ix, c);
}
+ if ((typeid(resize_t) == typeid(Eigen::bfloat16))) {
+ ASSERT_MSG(checkValidBFloat(acc), "Resize accumulator float value is not a valid bfloat16 value.");
+ }
out->getTensor()(b, oy, ox, c) = acc;
}
@@ -225,4 +237,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);