aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/ops/tensor_ops.cc10
-rw-r--r--reference_model/src/ops/tensor_ops.h10
2 files changed, 10 insertions, 10 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 045c0a5..a150656 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -78,7 +78,7 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
setRequiredOperands(1, 1);
setRequiredRank(4);
- INIT_ATTRIBUTE(Pool2d);
+ INIT_ATTRIBUTE(Pool);
INIT_QINFO(Unary);
}
@@ -299,7 +299,7 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
setRequiredOperands(3, 1);
setRequiredRank(4);
- INIT_ATTRIBUTE(Conv2d);
+ INIT_ATTRIBUTE(Conv);
INIT_QINFO(Conv);
}
@@ -491,7 +491,7 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg
setRequiredOperands(3, 1);
setRequiredRank(4);
- INIT_ATTRIBUTE(Conv2d);
+ INIT_ATTRIBUTE(Conv);
INIT_QINFO(Conv);
}
@@ -892,7 +892,7 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
setRequiredOperands(1, 1);
setRequiredRank(4);
- INIT_ATTRIBUTE(Pool2d);
+ INIT_ATTRIBUTE(Pool);
}
template <DType Dtype>
@@ -1034,7 +1034,7 @@ OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
setRequiredOperands(3, 1);
setRequiredRank(4);
- INIT_ATTRIBUTE(TransposeConv2d);
+ INIT_ATTRIBUTE(TransposeConv);
INIT_QINFO(Conv);
}
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 6ffc27d..eea351d 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -68,7 +68,7 @@ public:
protected:
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
- tosa::TosaPool2dAttribute* attribute;
+ tosa::TosaPoolAttribute* attribute;
tosa::TosaUnaryQuantInfo* qinfo;
protected:
@@ -104,7 +104,7 @@ protected:
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
- tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvAttribute* attribute;
tosa::TosaConvQuantInfo* qinfo;
};
@@ -136,7 +136,7 @@ protected:
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
- tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvAttribute* attribute;
tosa::TosaConvQuantInfo* qinfo;
};
@@ -219,7 +219,7 @@ public:
protected:
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
- tosa::TosaPool2dAttribute* attribute;
+ tosa::TosaPoolAttribute* attribute;
};
template <DType InDtype, DType WeightDtype>
@@ -250,7 +250,7 @@ protected:
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
TosaReference::TensorTemplate<TAcc>* output;
- TosaTransposeConv2dAttribute* attribute;
+ TosaTransposeConvAttribute* attribute;
TosaConvQuantInfo* qinfo;
};