aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-10-19 12:20:31 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-11-09 12:19:51 +0000
commit24dbc420aae556649f50e645bd94489dab2cc75a (patch)
tree490345da43e9c5bae0f450ba05ffe85874077e0a /reference_model/src/ops/image.cc
parent3b0544c1e7463295c49a48a162ebb9a546326829 (diff)
downloadreference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward <james.ward@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc29
1 files changed, 21 insertions, 8 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index cf1d9f7..66efee0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -63,7 +63,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -71,7 +71,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -159,15 +159,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value)
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
{
- dy = fy - iy;
- dx = fx - ix;
+ dy = (resize_t)(fy - iy);
+ dx = (resize_t)(fx - ix);
}
else
{
- dy = y - (iy * scale_y_n);
- dx = x - (ix * scale_x_n);
+ dy = (resize_t)(y - (iy * scale_y_n));
+ dx = (resize_t)(x - (ix * scale_x_n));
}
int32_t iy0 = MAX(iy, 0);
@@ -190,6 +190,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
acc += (OutEigenType)v10 * dy * (1.0 - dx);
acc += (OutEigenType)v11 * dy * dx;
}
+ else if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
+ {
+ Eigen::bfloat16 bf16_acc;
+ bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx);
+ bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx;
+ bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx);
+ bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx;
+ acc = (float)bf16_acc;
+ }
else
{
acc = (OutEigenType)v00 * (scale_y_n - dy) * (scale_x_n - dx);
@@ -201,7 +210,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
else
{
ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode");
- if (std::is_floating_point<resize_t>::value)
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
{
iy = (dy >= 0.5) ? iy1 : iy0;
ix = (dx >= 0.5) ? ix1 : ix0;
@@ -213,6 +222,9 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
}
acc = in->getTensor()(b, iy, ix, c);
}
+ if ((typeid(resize_t) == typeid(Eigen::bfloat16))) {
+ ASSERT_MSG(checkValidBFloat(acc), "Resize accumulator float value is not a valid bfloat16 value.");
+ }
out->getTensor()(b, oy, ox, c) = acc;
}
@@ -225,4 +237,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);