aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.h
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-04-28 16:29:44 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-04-28 17:37:03 -0700
commit989cb050228b47085189b1c5cb0d9b705e1060e7 (patch)
tree317bfa75a3d3b79341f2d0e3c994bcef834ff179 /reference_model/src/tensor.h
parent550ccc52de231621c0bf0c05ae2a398eec37ff51 (diff)
downloadreference_model-989cb050228b47085189b1c5cb0d9b705e1060e7.tar.gz
Support mixed-precision input tensors for TOSA unit test.
Bring CONV2D/DEPTHWISE_CONV2D/TRANSPOSE_CONV2D/FULLY_CONNECTED up running. Other minor fixes: - reference model should bail out if shape is invalid, along with "goto done" cleanup. - cleanup typos/duplicate in tosa_test_gen.py/tosa_serializer.py. - wrong input_zp/output_zp being generated for RESCALE. Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ic1f3fe0090482bdee8a61508be7c738714191e19
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r--reference_model/src/tensor.h45
1 files changed, 24 insertions, 21 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index d39cc7c..6c0622e 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -33,9 +33,7 @@ class GraphNode;
class Tensor
{
public:
- Tensor(std::string tensorName_,
- DType tensorDtype__,
- std::vector<int> shape_);
+ Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_);
virtual ~Tensor();
@@ -240,9 +238,7 @@ template <class T>
class TensorTemplate : public Tensor
{
public:
- TensorTemplate(std::string tensorName_,
- DType tensorDtype_,
- std::vector<int> shape_)
+ TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
: Tensor(tensorName_, tensorDtype_, shape_)
{
tensor = nullptr;
@@ -606,11 +602,15 @@ int Tensor6<bool>::dumpTensor(FILE* out) const;
class TensorFactory
{
public:
- static Tensor* newTensor(std::string tensorName_,
- DType tensorDtype_,
- std::vector<int> shape_,
- const uint32_t rank)
+ static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
{
+ // Bail out if any dimension is invalid.
+ for (auto& dim : shape_)
+ {
+ if (dim <= 0)
+ goto done;
+ }
+
switch (tensorDtype_)
{
case DType_FLOAT:
@@ -630,9 +630,8 @@ public:
return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
case DType_INT32:
case DType_UINT8:
case DType_INT4:
@@ -654,9 +653,8 @@ public:
return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
case DType_INT48:
switch (rank)
{
@@ -674,9 +672,8 @@ public:
return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
case DType_BOOL:
switch (rank)
{
@@ -694,16 +691,22 @@ public:
return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
default:
- goto done;
+ break;
}
done:
- FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_],
- rank);
+ std::string shape_str("[");
+ for (auto& dim : shape_)
+ {
+ shape_str += (std::to_string(dim) + ", ");
+ }
+ shape_str.append("]");
+
+ FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d, shape=%s", tensorName_.c_str(),
+ EnumNamesDType()[tensorDtype_], rank, shape_str.c_str());
}
static Tensor* newTensor(DType type, const std::vector<int> shape);