Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
 reference_model/src/ops/tensor_ops.cc | 315
 1 file changed, 172 insertions(+), 143 deletions(-)
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 2cd94bb..c617dda 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
#include "tensor_ops.h"
#include "quant_util.h"
#include "template_types.h"
+#include "half.hpp"
using namespace TosaReference;
using namespace Eigen;
@@ -329,8 +330,8 @@ int OpArgMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
-OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
+template <DType Dtype, DType AccDtype>
+OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
@@ -341,15 +342,15 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype>
-OpAvgPool2d<Dtype>::~OpAvgPool2d()
+template <DType Dtype, DType AccDtype>
+OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
-int OpAvgPool2d<Dtype>::checkTensorAttributes()
+template <DType Dtype, DType AccDtype>
+int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -385,8 +386,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
// This calculates the number of padding elements used for each location along an axis
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <DType Dtype, DType AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
ETensor1<int32_t> result(out_size);
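The comments above capture the key rule: average pooling divides by the count of in-bounds input elements only, never by padded positions. A minimal standalone sketch of that per-axis count, assuming the usual start = o * stride - pad_left window placement (hypothetical helper, not part of the patch):

// Hypothetical sketch of the per-output divisor along one axis.
#include <algorithm>
#include <vector>

std::vector<int32_t> div_map_1d_sketch(int in_size, int out_size,
                                       int kernel_size, int stride,
                                       int32_t pad_left)
{
    std::vector<int32_t> div(out_size);
    for (int o = 0; o < out_size; o++)
    {
        int start = o * stride - pad_left;    // may be negative (in left pad)
        int end   = start + kernel_size;      // one past the last tap
        // clamp to the valid input range so padded taps are excluded
        div[o] = std::min(end, in_size) - std::max(start, 0);
    }
    return div;
}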
@@ -414,8 +415,8 @@ ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
-template <DType Dtype>
-int OpAvgPool2d<Dtype>::eval()
+template <DType Dtype, DType AccDtype>
+int OpAvgPool2d<Dtype, AccDtype>::eval()
{
int in_batch = this->in->getShape()[0];
int in_height = this->in->getShape()[1];
@@ -439,11 +440,13 @@ int OpAvgPool2d<Dtype>::eval()
int stride_h = this->attribute->stride()[0];
int stride_w = this->attribute->stride()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP,
"perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
- "stride=[%d,%d], pad=[%d,%d,%d,%d]",
+ "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
- kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right);
+ kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eigen::array<Eigen::Index, 2> im2col_input_dims;
im2col_input_dims[0] = kernel_h * kernel_w;
@@ -509,8 +512,7 @@ int OpAvgPool2d<Dtype>::eval()
.contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
-
-    if (Dtype != DType_FLOAT)
+ if (Dtype != DType_FLOAT && Dtype != DType_FP16)
{
try
{
@@ -531,14 +533,15 @@ int OpAvgPool2d<Dtype>::eval()
}
else
{
+        // Case for floating-point types: divide the sum by the element count
this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV2D, id_)
@@ -549,15 +552,15 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::~OpConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -574,12 +577,12 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpConv2d: Output data type not supported for this configuration of operator");
+ "OpConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -593,8 +596,8 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -630,12 +633,14 @@ int OpConv2d<InDtype, WeightDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP,
"perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
- "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
+ "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
- pad_bottom, pad_left, pad_right);
+ pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
// GEMM-conv2d, left matrix is input, right matrix is weight
Eigen::array<Eigen::Index, 2> im2col_input_dims;
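The hunk above sets up the GEMM formulation: unfolded input patches form the left matrix, the shuffled/reshaped weights the right. A toy contraction showing the same Eigen pattern, with purely illustrative shapes (assumes Eigen's unsupported Tensor module):

#include <unsupported/Eigen/CXX11/Tensor>

int main()
{
    // left: [N*OH*OW, KH*KW*IC] patches; right: [KH*KW*IC, OC] weights
    Eigen::Tensor<float, 2> patches(6, 12);
    Eigen::Tensor<float, 2> weights(12, 4);
    patches.setRandom();
    weights.setRandom();

    // contract dim 1 of patches with dim 0 of weights -> [N*OH*OW, OC]
    Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
    Eigen::Tensor<float, 2> out = patches.contract(weights, dims);
    return out.dimension(1) == 4 ? 0 : 1;
}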
@@ -695,33 +700,33 @@ int OpConv2d<InDtype, WeightDtype>::eval()
// transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
ETensor2<WeightEigenType> im2col_weight =
- weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
+ weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
// don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
// and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
- ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims);
+ ETensor2<OutEigenType> bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
// output matrix is [N * H * W, C]
- ETensor2<AccEigenType> contracted_result =
- im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims);
+ ETensor2<OutEigenType> contracted_result =
+ (im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims)).template cast<OutEigenType>();
// adding bias
- ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>();
+ ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
// reshape back to [N, H, W, C]
this->output->getTensor() = biased_output.reshape(col2im_output_dims);
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV3D, id_)
@@ -732,15 +737,15 @@ OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype>
-OpConv3d<InDtype, WeightDtype>::~OpConv3d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -757,12 +762,12 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpConv3d: Output data type not supported for this configuration of operator");
+ "OpConv3d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -776,8 +781,8 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv3d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_depth = this->input->getShape()[1];
@@ -821,13 +826,15 @@ int OpConv3d<InDtype, WeightDtype>::eval()
int dilation_h = this->attribute->dilation()[1];
int dilation_w = this->attribute->dilation()[2];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(
OP,
"perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
- "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
+ "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
- dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);
+ dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
pad[0] = std::make_pair(0, 0);
@@ -860,7 +867,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
// 2. direct convolution
- AccEigenType acc = 0;
+ AccEigenType acc(0.0);
int d_idx, h_idx, w_idx;
for (int ob = 0; ob < out_batch; ob++)
@@ -874,7 +881,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
for (int oc = 0; oc < out_channels; oc++)
{
// Initialize accumulator with bias value
- acc = this->output->getTensor()(ob, od, oh, ow, oc);
+ acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
for (int fd = 0; fd < f_depth; fd++)
{
d_idx = od * stride_d + fd * dilation_d;
@@ -892,7 +899,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
}
}
}
- this->output->getTensor()(ob, od, oh, ow, oc) = acc;
+ this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
}
}
}
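This loop shows the pattern the patch applies throughout: multiply and accumulate in AccEigenType, then narrow to OutEigenType exactly once per output element. An integer analogue as a minimal sketch (hypothetical function; int8 operands with an int32 accumulator):

#include <cstdint>

// Accumulate in the wide type; narrow once at the end, mirroring the
// (AccEigenType)/(OutEigenType) casts in the convolution loops above.
int16_t mac_then_narrow(const int8_t* x, const int8_t* w, int n, int32_t bias)
{
    int32_t acc = bias;                       // plays the role of AccEigenType
    for (int i = 0; i < n; i++)
        acc += (int32_t)x[i] * (int32_t)w[i]; // widen before multiplying
    return (int16_t)acc;                      // plays the role of OutEigenType
}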
@@ -901,15 +908,15 @@ int OpConv3d<InDtype, WeightDtype>::eval()
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
@@ -920,15 +927,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -945,12 +952,12 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
+ "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -964,8 +971,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1002,12 +1009,14 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP,
"perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
- "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
- pad_bottom, pad_left, pad_right);
+ pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
pad[0] = std::make_pair(0, 0);
@@ -1061,9 +1070,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
{
for (int fw = 0; fw < f_width; fw++)
{
+ // Perform multiplication in AccEigenType then cast to OutEigenType
this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
- ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
- (AccEigenType)weight_val(fh, fw, ic, cm));
+ (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
+ (AccEigenType)weight_val(fh, fw, ic, cm));
}
}
}
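The updated line writes each product to output channel ic * f_multiplier + cm, i.e. every input channel fans out to f_multiplier output channels. A trivial sketch of that index mapping, with illustrative sizes:

#include <cstdio>

int main()
{
    const int in_channels = 3, f_multiplier = 2;   // illustrative values
    for (int ic = 0; ic < in_channels; ic++)
        for (int cm = 0; cm < f_multiplier; cm++)
            std::printf("ic=%d, cm=%d -> oc=%d\n", ic, cm, ic * f_multiplier + cm);
    return 0;   // prints oc = 0..5: each input channel maps to 2 outputs
}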
@@ -1074,15 +1084,15 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
@@ -1093,15 +1103,15 @@ OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_
INIT_ATTRIBUTE(FullyConnected);
}
-template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1128,9 +1138,9 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
@@ -1138,8 +1148,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpFullyConnected<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1163,19 +1173,19 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
}
this->output->getTensor() =
- input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
- this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
+ input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
+ this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType Dtype>
-OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
+template <DType Dtype, DType AccDtype>
+OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_MATMUL, id_)
@@ -1186,15 +1196,15 @@ OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(MatMul);
}
-template <DType Dtype>
-OpMatMul<Dtype>::~OpMatMul()
+template <DType Dtype, DType AccDtype>
+OpMatMul<Dtype, AccDtype>::~OpMatMul()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
-int OpMatMul<Dtype>::checkTensorAttributes()
+template <DType Dtype, DType AccDtype>
+int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1205,11 +1215,11 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpMatMul: Output data type not supported for this configuration of operator");
+ "OpMatMul: Output data type not supported for this configuration of operator");
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(a && b && output);
@@ -1255,8 +1265,8 @@ int OpMatMul<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
-int OpMatMul<Dtype>::eval()
+template <DType Dtype, DType AccDtype>
+int OpMatMul<Dtype, AccDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1289,22 +1299,22 @@ int OpMatMul<Dtype>::eval()
TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
TAccRank2 output_rank2_val =
a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
- TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
+ TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
if (i == 0)
{
this->output->getTensor() = output_rank3_val;
}
else
{
- TAcc temp = this->output->getTensor().concatenate(output_rank3_val, 0);
+ TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
this->output->getTensor() = temp;
}
}
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
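The hunk above realizes a batched matmul as per-batch rank-2 contractions whose rank-3 results are concatenated back along the batch dimension. A compact sketch of the same idea, using Eigen's chip() in place of the slice/reshape bookkeeping (illustrative shapes only):

#include <unsupported/Eigen/CXX11/Tensor>

int main()
{
    const int N = 2, H = 3, C = 4, W = 5;          // illustrative shapes
    Eigen::Tensor<float, 3> a(N, H, C), b(N, C, W), out(N, H, W);
    a.setRandom();
    b.setRandom();

    Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
    for (int i = 0; i < N; i++)
    {
        // slice batch i down to rank 2, contract, write back into batch i
        Eigen::Tensor<float, 2> ai = a.chip(i, 0);
        Eigen::Tensor<float, 2> bi = b.chip(i, 0);
        out.chip(i, 0) = ai.contract(bi, dims);
    }
    return 0;
}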
@@ -1442,8 +1452,8 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
@@ -1454,15 +1464,15 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sg
INIT_ATTRIBUTE(TransposeConv);
}
-template <DType InDtype, DType WeightDtype>
-OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1473,12 +1483,12 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpTransposeConv2d: Output data type not supported for this configuration of operator");
+ "OpTransposeConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
if (attribute->out_pad().size() != 4)
{
@@ -1556,8 +1566,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpTransposeConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1584,6 +1594,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
int stride_h = this->attribute->stride()[0];
int stride_w = this->attribute->stride()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
in_channels);
@@ -1594,10 +1606,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
DEBUG_INFO(OP,
"perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
- "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
- out_pad_bottom, out_pad_left, out_pad_right);
+ out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]);
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
@@ -1645,8 +1657,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
{
this->output->getTensor()(ob, out_y, out_x, oc) +=
- ((AccEigenType)input_val(ob, ih, iw, ic) *
- (AccEigenType)weight_val(oc, fh, fw, ic));
+ (OutEigenType) ((AccEigenType)input_val(ob, ih, iw, ic) *
+ (AccEigenType)weight_val(oc, fh, fw, ic));
}
}
}
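Transpose convolution is evaluated here in scatter form: each input-times-weight product is accumulated into its mapped output position, and the bounds check above drops anything that lands outside the output extent. A minimal sketch of that guarded accumulate (hypothetical helper):

#include <vector>

// Hypothetical helper: accumulate a contribution at a computed output
// coordinate, silently dropping out-of-range positions.
void scatter_add(std::vector<float>& out, int out_w, int y, int x, float v)
{
    if (y >= 0 && x >= 0 && x < out_w && (y + 1) * out_w <= (int)out.size())
        out[y * out_w + x] += v;
}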
@@ -1658,51 +1670,68 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT8)
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
-
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
-
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT8);
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
-
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
+
+// [in_t, weight_t, acc_t]
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);
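For orientation: the DEF_INSTANTIATE_*_ONE_ACCUM macros presumably expand to explicit template instantiations over the listed DTypes. A plausible sketch of the two-type form only; the real definition lives elsewhere in the reference model headers and may differ:

// Plausible sketch, not the actual macro definition.
#define DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OP, T1, T2, ACC)            \
    template class TosaReference::OP<DType_##T1, DType_##T2, DType_##ACC>;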