From bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 27 Sep 2022 13:50:00 +0100 Subject: Rename FLOAT type to FP32 Update tensor operations naming to state input type as TxT in all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, TRANSPOSE_CONV2D. Signed-off-by: Jeremy Johnson Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205 --- reference_model/src/ops/activation_funcs.cc | 12 +-- reference_model/src/ops/comparison.cc | 12 +-- reference_model/src/ops/data_layout.cc | 16 +-- reference_model/src/ops/data_nodes.cc | 2 +- reference_model/src/ops/ewise_binary.cc | 26 ++--- reference_model/src/ops/ewise_binary.h | 4 +- reference_model/src/ops/ewise_ternary.cc | 2 +- reference_model/src/ops/ewise_unary.cc | 32 +++--- reference_model/src/ops/image.cc | 6 +- reference_model/src/ops/op_factory.cc | 116 +++++++++++----------- reference_model/src/ops/op_factory.h | 2 +- reference_model/src/ops/reduction.cc | 8 +- reference_model/src/ops/scatter_gather.cc | 4 +- reference_model/src/ops/template_types.h | 6 +- reference_model/src/ops/tensor_ops.cc | 34 +++---- reference_model/src/ops/type_conversion.cc | 16 +-- reference_model/src/ops/type_conversion.h | 8 +- reference_model/src/subgraph_traverser.cc | 4 +- reference_model/src/tensor.cc | 12 +-- reference_model/src/tensor.h | 2 +- thirdparty/serialization_lib | 2 +- verif/generator/tosa_arg_gen.py | 38 +++---- verif/generator/tosa_error_if.py | 56 ++++++----- verif/generator/tosa_test_gen.py | 148 +++++++++++----------------- verif/generator/tosa_utils.py | 25 +++-- verif/tests/test_tosa_refmodel.py | 10 +- 26 files changed, 297 insertions(+), 306 deletions(-) diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 1c0c23a..61f7df6 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -28,7 +28,7 @@ int OpClamp::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: { InEigenType min = (InEigenType)attribute->min_fp(); InEigenType max = (InEigenType)attribute->max_fp(); @@ -59,7 +59,7 @@ int OpSigmoid::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; break; default: @@ -75,7 +75,7 @@ int OpTanh::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; break; default: @@ -87,12 +87,12 @@ int OpTanh::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index 5930c1a..f240aa5 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -28,7 +28,7 @@ int OpEqual::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; break; @@ -45,7 +45,7 @@ int OpGreater::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; break; @@ -62,7 +62,7 @@ int OpGreaterEqual::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; break; @@ -75,13 +75,13 @@ int OpGreaterEqual::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 1ed0be2..69b6a65 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -191,7 +191,7 @@ int OpPad::eval() pad_value = (InEigenType)attribute->pad_const_int(); break; case DType_FP16: - case DType_FLOAT: + case DType_FP32: pad_value = (InEigenType)attribute->pad_const_fp(); break; default: @@ -639,49 +639,49 @@ int OpTranspose::eval() // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16) -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); -DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP32); DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 4ff08be..5709a92 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -90,7 +90,7 @@ int OpIdentity::eval() // note OpConst is not templated DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 917d56e..098b0ea 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -143,7 +143,7 @@ int OpAdd::register_fcn() }; break; case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; default: @@ -371,7 +371,7 @@ int OpMaximum::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; @@ -388,7 +388,7 @@ int OpMinimum::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; @@ -407,7 +407,7 @@ int OpMul::register_fcn() switch (InDtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; break; case DType_INT32: @@ -457,7 +457,7 @@ int OpPow::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; default: @@ -482,7 +482,7 @@ int OpSub::register_fcn() }; break; case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; default: @@ -581,7 +581,7 @@ int OpTable::eval() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); @@ -617,24 +617,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); @@ -643,5 +643,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); // Instantiation of nodes for comparison operators opEqual, opGreater // and opGreaterEqual DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index b2c92a4..6b0efaf 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ namespace TosaReference // Eigen::Tensor does support some binary element-wise natively (e.g. CWiseMax, or '+', etc.) // which might be faster since it could be implemented with SIMD instructions // the way of registering lambda + .binaryExpr() might sacrifice performance here -// but it can avoid partially specialization for combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...} +// but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...} // needs to revisit if performance becomes a bottleneck here template class BinaryNodeBase : public GraphNode diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index da046a7..d85da1a 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -108,7 +108,7 @@ int OpSelect<0, Dtype>::eval() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 52f5aff..00897cc 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -78,7 +78,7 @@ int OpAbs::register_fcn() { switch (Dtype) { - case DType_FLOAT: + case DType_FP32: case DType_FP16: case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; @@ -113,7 +113,7 @@ int OpCeil::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; default: @@ -161,7 +161,7 @@ int OpExp::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; default: @@ -177,7 +177,7 @@ int OpFloor::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; default: @@ -193,7 +193,7 @@ int OpLog::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; default: @@ -245,7 +245,7 @@ int OpNegate::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); return result; @@ -297,7 +297,7 @@ int OpReciprocal::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; default: @@ -313,7 +313,7 @@ int OpRsqrt::register_fcn() switch (Dtype) { case DType_FP16: - case DType_FLOAT: + case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; default: @@ -325,7 +325,7 @@ int OpRsqrt::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); @@ -333,29 +333,29 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 891261b..cf1d9f7 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -63,7 +63,7 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT && OutDtype != DType_FP16) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -71,7 +71,7 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT && OutDtype != DType_FP16) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -225,4 +225,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); -DEF_INSTANTIATE_THREE_TYPE(OpResize, FLOAT, FLOAT, float); +DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index fd73eb5..1ff8229 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -49,66 +49,66 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // tensor_ops case Op_ARGMAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); break; case Op_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48); break; case Op_CONV3D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48); break; case Op_DEPTHWISE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48); break; case Op_FULLY_CONNECTED: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48); break; case Op_MATMUL: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); - DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); break; case Op_TRANSPOSE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48); @@ -117,23 +117,23 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // activation_funcs case Op_CLAMP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); break; // ewise_binary case Op_ADD: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; case Op_ARITHMETIC_RIGHT_SHIFT: @@ -180,28 +180,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_MAXIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; case Op_TABLE: @@ -212,7 +212,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; case Op_BITWISE_NOT: @@ -222,46 +222,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_CEIL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); break; case Op_EXP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); break; case Op_NEGATE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); break; // ewise_ternary case Op_SELECT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); @@ -271,17 +271,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -294,32 +294,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; // data layout case Op_CONCAT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); @@ -327,7 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); @@ -335,7 +335,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); - DEF_FACTORY_RESHAPE(OpReshape, FLOAT); + DEF_FACTORY_RESHAPE(OpReshape, FP32); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); DEF_FACTORY_RESHAPE(OpReshape, INT32); @@ -343,7 +343,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); - DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); @@ -351,7 +351,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_SLICE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); @@ -359,7 +359,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TILE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); @@ -368,7 +368,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); @@ -380,14 +380,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); DEF_FACTORY_ONE_TYPE(OpGather, FP16); - DEF_FACTORY_ONE_TYPE(OpGather, FLOAT); + DEF_FACTORY_ONE_TYPE(OpGather, FP32); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); DEF_FACTORY_ONE_TYPE(OpScatter, FP16); - DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT); + DEF_FACTORY_ONE_TYPE(OpScatter, FP32); break; // image @@ -397,7 +397,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); - DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); break; // data_nodes @@ -405,7 +405,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, return new OpConst(sgt, id); case Op_IDENTITY: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); @@ -421,23 +421,23 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); - DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 25dfc6e..b525e69 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -111,7 +111,7 @@ return new OP(sgt, attribute, id); \ } -#define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \ +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ return new OP(sgt, attribute, id); \ diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index 03ee660..eccba09 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -159,20 +159,20 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index 25174bd..b6c4043 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -227,10 +227,10 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); -DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); -DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 9511c31..3de4899 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -65,7 +65,7 @@ using Tensor6 = TensorTemplate>; template struct GetEigenType; template <> -struct GetEigenType +struct GetEigenType { using type = float; }; @@ -301,9 +301,9 @@ struct GetAccDType static constexpr DType value = DType_FP16; }; template <> -struct GetAccDType +struct GetAccDType { - static constexpr DType value = DType_FLOAT; + static constexpr DType value = DType_FP32; }; }; // namespace TosaReference diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index c617dda..7db5182 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -512,7 +512,7 @@ int OpAvgPool2d::eval() .contract(div_map_w.reshape(Eigen::array{ 1, out_width }), contract_dims) .reshape(Eigen::array{ 1, out_height, out_width, 1 }) .broadcast(bcast); - if (Dtype != DType_FLOAT && Dtype != DType_FP16) + if (Dtype != DType_FP32 && Dtype != DType_FP16) { try { @@ -1679,41 +1679,41 @@ int OpTransposeConv2d::eval() // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); -DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); // [in_t, weight_t, acc_t] DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48); @@ -1721,17 +1721,17 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); -DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 50e710a..f51c38c 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -312,7 +312,7 @@ CastHelper::CastHelper() } template -CastHelper::CastHelper() +CastHelper::CastHelper() { fcn = [](InEigenType in) -> float { float out = (OutEigenType)in; // default cast to float is round_to_nearest_float() @@ -321,7 +321,7 @@ CastHelper::CastHelper() } template -CastHelper::CastHelper() +CastHelper::CastHelper() { fcn = [](float in) -> OutEigenType { OutEigenType out = std::round(in); @@ -339,23 +339,23 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 5f197cf..b0de30c 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -137,11 +137,11 @@ private: }; template -class CastHelper +class CastHelper { public: using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -154,10 +154,10 @@ private: }; template -class CastHelper +class CastHelper { public: - using InEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using FcnType = std::function; static constexpr int32_t OutMin = GetQMin::value; diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index cbb7001..15d82e6 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -394,7 +394,7 @@ int SubgraphTraverser::allocateTensor() tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); } break; - case DType_FLOAT: + case DType_FP32: { std::vector fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); @@ -702,7 +702,7 @@ int SubgraphTraverser::validateGraph() DType dtype = currTensor->getDtype(); // Float-point disallowed - if (dtype == DType_FLOAT || dtype == DType_FP16) + if (dtype == DType_FP32 || dtype == DType_FP16) { WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point " "disabled, but %s tensor %s found\n", diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index cbe12a9..8e65a27 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -93,7 +93,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) switch (getDtype()) { - case DType_FLOAT: + case DType_FP32: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -161,8 +161,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) for (uint32_t i=0; i < elements; i++) { fdatabuf[i] = half_float::half_cast(f16databuf[i]); } - // Fall through to DType_FLOAT case - case DType_FLOAT: + // Fall through to DType_FP32 case + case DType_FP32: if (setTensorValueFloat(elements, fdatabuf)) { if (f16databuf) @@ -229,7 +229,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const switch (getDtype()) { - case DType_FLOAT: + case DType_FP32: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -429,7 +429,7 @@ int TosaReference::Tensor::readfromVector(const std::vector& vals) uint32_t elements = getElementCount(); switch (getDtype()) { - case DType_FLOAT: + case DType_FP32: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -532,7 +532,7 @@ int TosaReference::Tensor::writeToVector(std::vector& vals) switch (getDtype()) { - case DType_FLOAT: + case DType_FP32: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 78a210e..efd7e62 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -642,7 +642,7 @@ public: { switch (tensorDtype_) { - case DType_FLOAT: + case DType_FP32: case DType_FP16: switch (rank) { diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 485a11d..e1072a9 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 485a11d8cb67c8062c632f0987cd31cedbe93d6d +Subproject commit e1072a9ed871fd474e7b09b7a74ae7be5f0a6f78 diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index e0c6cf0..791fbf7 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -776,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] in (DType.FP16, DType.FLOAT): + if dtypeList[0] in (DType.FP16, DType.FP32): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -1106,10 +1106,10 @@ class TosaArgGen: @staticmethod def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): - if isinstance(dtypes, list) or isinstance(dtypes, tuple): - input_dtype = dtypes[0] - else: - input_dtype = dtypes + assert isinstance(dtypes, list) or isinstance( + dtypes, tuple + ), f"{dtypes} unexpected" + input_dtype = dtypes[0] if error_name == ErrorIf.WrongOutputType: accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) @@ -1129,9 +1129,9 @@ class TosaArgGen: elif dtype == DType.INT16: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: - accum_dtypes = [DType.FP16, DType.FLOAT] - elif dtype == DType.FLOAT: - accum_dtypes = [DType.FLOAT] + accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.FP32: + accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" @@ -1245,7 +1245,7 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype in (DType.FP16, DType.FLOAT): + elif dtype in (DType.FP16, DType.FP32): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: @@ -1303,9 +1303,9 @@ class TosaArgGen: elif dtype == DType.INT8 or dtype == DType.INT16: accum_dtypes = [DType.INT32] elif dtype == DType.FP16: - accum_dtypes = [DType.FP16, DType.FLOAT] - elif dtype == DType.FLOAT: - accum_dtypes = [DType.FLOAT] + accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.FP32: + accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" else: @@ -1408,20 +1408,20 @@ class TosaArgGen: if error_name == ErrorIf.WrongOutputType: dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype) elif inDtype == DType.INT8: - dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32] elif inDtype == DType.INT16: - dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32] elif inDtype == DType.INT32: - dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: dtypeList = [DType.INT8, DType.INT16, DType.INT32] - elif inDtype == DType.FLOAT: + elif inDtype == DType.FP32: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output type for incorrect input type - dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] else: raise Exception("Unexpected input dtype: {}".format(inDtype)) @@ -1826,8 +1826,8 @@ class TosaArgGen: outputDTypeList = [DType.INT48] elif dtype == DType.FP16: outputDTypeList = [DType.FP16] - elif dtype == DType.FLOAT: - outputDTypeList = [DType.FLOAT] + elif dtype == DType.FP32: + outputDTypeList = [DType.FP32] elif error_name == ErrorIf.WrongInputType: # If an incorrect input type is used then we set a 'correct' # output type to avoid other errors diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a766803..abe1a97 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -119,7 +119,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif mode == ResizeMode.NEAREST and dtype == DType.INT16: @@ -128,7 +128,7 @@ class TosaErrorIfArgGen: DType.INT8, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: @@ -137,7 +137,7 @@ class TosaErrorIfArgGen: DType.INT8, DType.INT16, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: @@ -146,7 +146,7 @@ class TosaErrorIfArgGen: DType.INT8, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif dtype == DType.FP16: @@ -156,9 +156,9 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ) - elif dtype == DType.FLOAT: + elif dtype == DType.FP32: incorrect_types = ( DType.INT4, DType.INT8, @@ -299,8 +299,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT] + if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -366,6 +366,16 @@ class TosaErrorValidator: } wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes)) + # Turn the wrong dtypes into required list of types + if op["op"] in [ + Op.FULLY_CONNECTED, + Op.CONV2D, + Op.CONV3D, + Op.DEPTHWISE_CONV2D, + Op.TRANSPOSE_CONV2D, + ]: + wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes] + if op["op"] == Op.CLAMP: wrong_input_dtypes.remove(DType.INT48) @@ -415,7 +425,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) - or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) + or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -430,28 +440,28 @@ class TosaErrorValidator: or (input_dtype == DType.INT16 and output_dtype != DType.INT48) or ( input_dtype == DType.FP16 - and output_dtype not in (DType.FP16, DType.FLOAT) + and output_dtype not in (DType.FP16, DType.FP32) ) - or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) + or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] + input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: if ( - input_dtype not in (DType.FP16, DType.FLOAT) + input_dtype not in (DType.FP16, DType.FP32) and output_dtype != DType.INT32 ): error_result = True elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True - elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: + elif input_dtype == DType.FP32 and output_dtype != DType.FP32: error_result = True elif op["op"] == Op.TABLE: @@ -477,7 +487,7 @@ class TosaErrorValidator: DType.BOOL, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ] ) @@ -488,7 +498,7 @@ class TosaErrorValidator: DType.BOOL, DType.INT8, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ] ) @@ -499,7 +509,7 @@ class TosaErrorValidator: DType.BOOL, DType.INT8, DType.INT16, - DType.FLOAT, + DType.FP32, DType.FP16, ] ) @@ -508,7 +518,7 @@ class TosaErrorValidator: and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( - input_dtype == DType.FLOAT + input_dtype == DType.FP32 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) ): @@ -526,9 +536,9 @@ class TosaErrorValidator: or input_dtype == DType.INT16 and output_dtype != DType.INT48 or input_dtype == DType.FP16 - and output_dtype not in (DType.FP16, DType.FLOAT) - or input_dtype == DType.FLOAT - and output_dtype != DType.FLOAT + and output_dtype not in (DType.FP16, DType.FP32) + or input_dtype == DType.FP32 + and output_dtype != DType.FP32 ): error_result = True # invalid input types are ignored, to avoid reporting multiple errors @@ -2306,12 +2316,12 @@ class TosaInvalidValidator: not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) - and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) + and not (input_dtype == DType.FP32 and output_dtype == DType.FP32) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] + input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] ) else: # Invalid resize mode diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 9ff6ec5..78d86cd 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -13,6 +13,7 @@ from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator +from generator.tosa_utils import DTYPE_ATTRIBUTES from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes from tosa.DType import DType @@ -83,7 +84,7 @@ class TosaTestGen: ) elif dtype == DType.FP16: return np.float16(self.rng.random(size=shape)) - elif dtype == DType.FLOAT: + elif dtype == DType.FP32: return np.float32(self.rng.random(size=shape)) else: raise Exception("Unrecognized Dtype: {}".format(dtype)) @@ -128,7 +129,7 @@ class TosaTestGen: return np.int32(self.rng.integers(low=low, high=high, size=1))[0] def getRandNumberDType(self, dtype): - if dtype == DType.FLOAT: + if dtype == DType.FP32: return self.rng.random() elif dtype == DType.FP16: rand_f32 = self.rng.random() @@ -162,58 +163,26 @@ class TosaTestGen: return "x".join(sStr) - def typeStr(self, t): - if isinstance(t, list): - assert len(t) >= 2 - return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1])) + def typeStr(self, dtype): + if isinstance(dtype, list) or isinstance(dtype, tuple): + assert len(dtype) >= 2 + strs = [self.typeStr(t) for t in dtype] + # Limit types to the first 2 as the 3rd is the accumulator + return "x".join(strs[:2]) else: - if t == DType.BOOL: - return "b" - elif t == DType.INT4: - return "i4" - elif t == DType.INT8: - return "i8" - elif t == DType.UINT8: - return "u8" - elif t == DType.INT16: - return "i16" - elif t == DType.UINT16: - return "u16" - elif t == DType.INT32: - return "i32" - elif t == DType.INT48: - return "i48" - elif t == DType.FP16: - return "f16" - elif t == DType.FLOAT: - return "float" + if dtype in DTYPE_ATTRIBUTES: + return DTYPE_ATTRIBUTES[dtype]["str"] else: - raise Exception("Unknown dtype, cannot convert to string: {}".format(t)) + raise Exception( + "Unknown dtype, cannot convert to string: {}".format(dtype) + ) - def typeWidth(self, t): + def typeWidth(self, dtype): """Get the datatype width for data types""" - if t == DType.INT4: - return 4 - elif t == DType.INT8: - return 8 - elif t == DType.UINT8: - return 8 - elif t == DType.INT16: - return 16 - elif t == DType.UINT16: - return 16 - elif t == DType.INT32: - return 32 - elif t == DType.INT48: - return 48 - elif t == DType.FP16: - return 16 - elif t == DType.FLOAT: - return 32 - elif t == DType.BOOL: - return 1 + if dtype in DTYPE_ATTRIBUTES: + return DTYPE_ATTRIBUTES[dtype]["width"] else: - raise Exception(f"Unknown dtype, cannot determine width: {t}") + raise Exception(f"Unknown dtype, cannot determine width: {dtype}") # Argument generators # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]) @@ -355,7 +324,7 @@ class TosaTestGen: # Special for multiply: # Force the result to INT32 for INT types - if a.dtype not in (DType.FP16, DType.FLOAT): + if a.dtype not in (DType.FP16, DType.FP32): result_tens.setDtype(DType.INT32) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] @@ -1074,7 +1043,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype in (DType.FP16, DType.FLOAT): + if a.dtype in (DType.FP16, DType.FP32): attr.ClampAttribute(0, 0, min_val, max_val) else: attr.ClampAttribute(min_val, max_val, 0, 0) @@ -1086,7 +1055,7 @@ class TosaTestGen: result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) attr = ts.TosaSerializerAttribute() - attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT)) + attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32)) self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr) return result_tens @@ -1890,7 +1859,7 @@ class TosaTestGen: op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) - if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32): + if a.dtype in (DType.FP32, DType.FP16, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT @@ -2001,7 +1970,7 @@ class TosaTestGen: if error_name == ErrorIf.CondGraphOutputNotMatchingBool: cond_tens = self.ser.addOutput( - [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]) + [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32]) ) else: cond_tens = self.ser.addOutput([], DType.BOOL) @@ -2429,7 +2398,7 @@ class TosaTestGen: # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested - TYPE_FP = [DType.FLOAT, DType.FP16] + TYPE_FP = [DType.FP32, DType.FP16] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 TYPE_INT_FP = [ @@ -2437,30 +2406,31 @@ class TosaTestGen: DType.INT16, DType.INT32, DType.FP16, - DType.FLOAT, + DType.FP32, ] # Excludes INT4 TYPE_BOOL = [DType.BOOL] - TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32 + TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32 TYPE_FIB = [ DType.FP16, - DType.FLOAT, + DType.FP32, DType.INT8, DType.INT16, DType.INT32, DType.BOOL, ] - TYPE_FI16 = [DType.FLOAT, DType.INT16] + TYPE_FI16 = [DType.FP32, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] + TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + # List of [Input Type 1, Input Type 2, Accumulator Type] TYPE_CONV = [ [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], [DType.FP16, DType.FP16, DType.FP16], - [DType.FP16, DType.FP16, DType.FLOAT], - DType.FLOAT, + [DType.FP16, DType.FP16, DType.FP32], + [DType.FP32, DType.FP32, DType.FP32], ] DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK) @@ -3478,7 +3448,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), - "types": (DType.FP16, DType.FLOAT, DType.INT32), + "types": (DType.FP16, DType.FP32, DType.INT32), "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -3665,7 +3635,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, None, ), - "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT), + "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3706,7 +3676,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agResize, ), - "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT), + "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32), "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, ), @@ -3742,7 +3712,7 @@ class TosaTestGen: ), "types": ( DType.FP16, - DType.FLOAT, + DType.FP32, DType.INT8, DType.INT16, DType.INT32, @@ -3872,7 +3842,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3901,7 +3871,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3929,7 +3899,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3958,7 +3928,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] outputDType = rng.choice(wrong_dtypes) else: @@ -3984,7 +3954,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4016,7 +3986,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([DType.INT32])) outputDType = rng.choice(wrong_dtypes) @@ -4069,7 +4039,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) @@ -4131,7 +4101,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) @@ -4182,7 +4152,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) @@ -4217,7 +4187,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) @@ -4255,7 +4225,7 @@ class OutputShaper: DType.INT8, DType.INT16, DType.INT48, - DType.FLOAT, + DType.FP32, ) elif a.dtype == DType.INT16: incorrect_types = ( @@ -4263,9 +4233,9 @@ class OutputShaper: DType.INT8, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, ) - elif a.dtype == DType.FLOAT or a.dtype == DType.FP16: + elif a.dtype == DType.FP32 or a.dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, @@ -4307,7 +4277,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, } wrong_dtypes = list(all_dtypes - set([input1.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4334,7 +4304,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) @@ -4358,7 +4328,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4376,7 +4346,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4412,7 +4382,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4440,7 +4410,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4464,7 +4434,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([values.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4491,7 +4461,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4512,7 +4482,7 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, - DType.FLOAT, + DType.FP32, ] wrong_dtypes.remove(output_dtype) output_dtype = rng.choice(wrong_dtypes) @@ -4619,7 +4589,7 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: - excludes = [DType.FP16, DType.FLOAT] + excludes = [DType.FP16, DType.FP32] else: excludes = [out_dtype] wrong_dtypes = list(usableDTypes(excludes=excludes)) diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 7fa31e7..104d9bb 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -5,6 +5,19 @@ from tosa.DType import DType # Maximum dimension size for output and inputs for RESIZE MAX_RESIZE_DIMENSION = 16384 +DTYPE_ATTRIBUTES = { + DType.BOOL: {"str": "b", "width": 1}, + DType.INT4: {"str": "i4", "width": 4}, + DType.INT8: {"str": "i8", "width": 8}, + DType.UINT8: {"str": "u8", "width": 8}, + DType.INT16: {"str": "i16", "width": 16}, + DType.UINT16: {"str": "u16", "width": 16}, + DType.INT32: {"str": "i32", "width": 32}, + DType.INT48: {"str": "i48", "width": 48}, + DType.FP16: {"str": "f16", "width": 16}, + DType.FP32: {"str": "f32", "width": 32}, +} + def valueToName(item, value): """Get the name of an attribute with the given value. @@ -88,10 +101,8 @@ def product(shape): def get_accum_dtype_from_tgTypes(dtypes): # Get accumulate data-type from the test generator's defined types - if isinstance(dtypes, list) or isinstance(dtypes, tuple): - return dtypes[-1] - else: - return dtypes + assert isinstance(dtypes, list) or isinstance(dtypes, tuple) + return dtypes[-1] def get_wrong_output_type(op_name, rng, input_dtype): @@ -102,7 +113,7 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT8, DType.INT16, DType.INT48, - DType.FLOAT, + DType.FP32, DType.FP16, ) elif input_dtype == DType.INT16: @@ -111,10 +122,10 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT8, DType.INT16, DType.INT32, - DType.FLOAT, + DType.FP32, DType.FP16, ) - elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + elif input_dtype == DType.FP32 or input_dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py index fbe3a7d..b608fd8 100644 --- a/verif/tests/test_tosa_refmodel.py +++ b/verif/tests/test_tosa_refmodel.py @@ -45,7 +45,7 @@ REF_MODEL_TYPE_TO_OUT = { "uint8": "u8", "int16": "i16", "int32": "i32", - "float": "float", + "fp32": "f32", "fp16": "f16", } @@ -123,21 +123,21 @@ class BuildTosaTest: # Tests - op_name, ref_model_type, num_expected_tests TEST_PARAMS = [ ("add", "int32", 1), - ("add", "float", 1), + ("add", "fp32", 1), ("abs", "int32", 1), - ("abs", "float", 1), + ("abs", "fp32", 1), ("abs", "fp16", 1), ("negate", "int8", 1), ("negate", "int16", 1), ("negate", "int32", 1), - ("negate", "float", 1), + ("negate", "fp32", 1), ("negate", "fp16", 1), # One test per axis (shape dimensions) ("concat", "bool", SHAPE_DIMS), ("concat", "int8", SHAPE_DIMS), ("concat", "int16", SHAPE_DIMS), ("concat", "int32", SHAPE_DIMS), - ("concat", "float", SHAPE_DIMS), + ("concat", "fp32", SHAPE_DIMS), ("concat", "fp16", SHAPE_DIMS), ] -- cgit v1.2.1