aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-28 22:06:56 +0000
committerTai Ly <tai.ly@arm.com>2023-05-05 19:23:15 +0000
commita4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/tensor_ops.cc
parent0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
downloadreference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc203
1 files changed, 112 insertions, 91 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b3845df..f8fd323 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -116,14 +116,14 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
}
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
- uint32_t conv_dimension,
- std::vector<int32_t> input_shape,
- std::vector<int32_t> output_shape,
- std::vector<int32_t> weights,
- uint32_t offset_kernel,
- DType InDtype,
- DType WeightDtype,
- std::string& msg)
+ uint32_t conv_dimension,
+ std::vector<int32_t> input_shape,
+ std::vector<int32_t> output_shape,
+ std::vector<int32_t> weights,
+ uint32_t offset_kernel,
+ TOSA_REF_TYPE InDtype,
+ TOSA_REF_TYPE WeightDtype,
+ std::string& msg)
{
if (attribute->pad().size() != (2 * conv_dimension))
{
@@ -226,11 +226,13 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute,
return 1;
}
- if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+ if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
+ {
msg = "Input zero point must be zero for non-int8 data";
return 1;
}
- if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+ if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
+ {
msg = "Weight zero point must be zero for non-int8 data";
return 1;
}
@@ -318,7 +320,7 @@ int check_fft_shape(const std::vector<int32_t>& in_real,
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -330,14 +332,14 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -355,7 +357,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- if (outputs[0]->getDtype() != DType_INT32)
+ if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
{
printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
return 1;
@@ -400,7 +402,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
@@ -410,7 +412,7 @@ int OpArgMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -422,14 +424,14 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -449,8 +451,10 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
+ "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
std::string msg;
if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -466,8 +470,9 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
// This calculates the number of padding elements used for each location along an axis
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype, DType AccDtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
+ int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
ETensor1<int32_t> result(out_size);
@@ -495,7 +500,7 @@ ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
int in_batch = this->in->getShape()[0];
@@ -531,7 +536,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+ TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
DEBUG_INFO(OP,
"perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
@@ -556,7 +561,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
pad[3] = std::make_pair(0, 0);
ETensor4<InEigenType> input_val = this->in->getTensor();
- if (Dtype == DType_INT8)
+ if (Dtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
}
@@ -604,7 +609,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
dm2_h.contract(dm2_w, contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
- if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
+ if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
+ Dtype != TOSA_REF_TYPE_FP64)
{
try
{
@@ -632,7 +638,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -644,14 +650,14 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -688,7 +694,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -781,7 +787,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -817,7 +823,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
// reshape back to [N, H, W, C]
this->output->getTensor() = biased_output.reshape(col2im_output_dims);
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -826,7 +832,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -838,14 +844,14 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -882,7 +888,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -959,7 +965,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1020,7 +1026,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1029,10 +1035,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1041,14 +1047,14 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1085,7 +1091,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -1149,7 +1155,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1205,7 +1211,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1214,10 +1220,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
@@ -1226,14 +1232,14 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave
INIT_ATTRIBUTE(FullyConnected);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1265,13 +1271,15 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
+ ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+ "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1289,7 +1297,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1299,7 +1307,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1307,7 +1315,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1319,14 +1327,14 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(MatMul);
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::~OpMatMul()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1382,13 +1390,15 @@ int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
}
W = b->getShape()[2];
- ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
- ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
+ "OpMatMul: A zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
+ "OpMatMul: B zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1396,7 +1406,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
TIn a_val = this->a->getTensor();
TIn b_val = this->b->getTensor();
- if (Dtype == DType_INT8)
+ if (Dtype == TOSA_REF_TYPE_INT8)
{
a_val = a_val - (InEigenType)attribute->a_zp();
b_val = b_val - (InEigenType)attribute->b_zp();
@@ -1434,7 +1444,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1443,7 +1453,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1455,14 +1465,14 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1493,7 +1503,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
int in_batch = this->in->getShape()[0];
@@ -1586,10 +1596,8 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
-OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_FFT2D, id_)
{
setRequiredOperands(2, 2);
@@ -1598,14 +1606,14 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(FFT);
}
-template <DType Dtype>
-OpFFT2d<Dtype>::~OpFFT2d() {
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::~OpFFT2d()
+{
if (attribute)
delete attribute;
}
-
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1643,7 +1651,7 @@ int OpFFT2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
int in_real_batch = this->in_real->getShape()[0];
@@ -1709,7 +1717,7 @@ int OpFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1719,11 +1727,11 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
setRequiredRank(3);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d() {}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1759,7 +1767,7 @@ int OpRFFT2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
int32_t in_batch = in->getShape()[0];
@@ -1815,10 +1823,10 @@ int OpRFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1827,14 +1835,14 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
INIT_ATTRIBUTE(TransposeConv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1923,13 +1931,15 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
+ ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+ "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -1985,7 +1995,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -2040,7 +2050,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -2055,6 +2065,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2062,6 +2073,7 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2071,6 +2083,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2079,6 +2092,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2087,8 +2101,10 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
@@ -2097,6 +2113,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
@@ -2104,14 +2121,17 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
@@ -2120,3 +2140,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);