aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_utils.h')
-rw-r--r--reference_model/src/verify/verify_utils.h39
1 files changed, 37 insertions, 2 deletions
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
index 341bd90..0d7bf47 100644
--- a/reference_model/src/verify/verify_utils.h
+++ b/reference_model/src/verify/verify_utils.h
@@ -22,6 +22,7 @@
#include <cstdint>
#include <optional>
+#include <string>
#include <vector>
#define TOSA_REF_REQUIRE(COND, MESSAGE, ...) \
@@ -164,9 +165,43 @@ struct AccPrecision<half_float::half>
static constexpr int32_t normal_frac = 7;
};
-/// \brief Error bounds check for ULP and ABS_ERROR modes
+/// \brief Single value error bounds check for ULP, ABS_ERROR and other compliance modes
+///
+/// \param testValue Implementation value
+/// \param referenceValue Reference value
+/// \param errorBound Positive error bound value
+/// \param resultDifference Return: Difference between reference value and implementation value
+/// \param resultWarning Return: Warning message if implementation is outside error bounds
+///
+/// \return True if compliant else false
template <typename OutType>
-bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound);
+bool tosaCheckFloatBound(
+ OutType testValue, double referenceValue, double errorBound, double& resultDifference, std::string& resultWarning);
+
+/// \brief Whole tensor checker for values inside error bounds
+///
+/// \param referenceData Reference output tensor data
+/// \param boundsData Optional reference bounds tensor data
+/// \param implementationData Implementation output tensor data
+/// \param shape Tensor shape - all tensors must be this shape
+/// \param modeStr Short string indicating which compliance mode we are testing
+/// \param cfgPtr Pointer to this mode's configuration data, passed to the calcErrorBound()
+/// \param calcErrorBound Pointer to a function that can calculate the error bound per ref value
+///
+/// \return True if compliant else false
+template <typename OutType>
+bool validateData(const double* referenceData,
+ const double* boundsData,
+ const OutType* implementationData,
+ const std::vector<int32_t>& shape,
+ const std::string& modeStr,
+ const void* cfgPtr,
+ double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr));
+
+// Unused arguments helper function
+template <typename... Args>
+inline void unused(Args&&...)
+{}
}; // namespace TosaReference
#endif // VERIFY_UTILS_H_