From 7a2f6d849ade2b96484cde1fdeb1a1a0d9ffa654 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Fri, 7 Jan 2022 17:00:03 -0600
Subject: [PATCH 1/9] Remove intermediate casting in BatchNormalization

---
 .../vivado/nnet_utils/nnet_batchnorm.h        | 10 ++---
 .../vivado/nnet_utils/nnet_batchnorm_stream.h |  4 +-
 .../templates/vivado/nnet_utils/nnet_mult.h   | 27 ++++++++++++
 hls4ml/templates/vivado_template.py           |  4 +-
 test/pytest/test_batchnorm.py                 | 44 +++++++++++++++++++
 5 files changed, 80 insertions(+), 9 deletions(-)
 create mode 100644 test/pytest/test_batchnorm.py

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
index 9a5cff0d3..ce1cfb315 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
@@ -43,8 +43,8 @@ struct batchnorm_config
     static const bool store_weights_in_bram = false;
     static const unsigned n_zeros = 0;
     // partitioning arrays cyclically to go with roll factors?
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::mult<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::mult_nocast<x_T, y_T>;
 };

 template<class data_T, class res_T, typename CONFIG_T>
@@ -71,7 +71,7 @@ void normalize(
         #pragma HLS ARRAY_PARTITION variable=bias complete

         int multiplier_limit = ceil(float(CONFIG_T::n_in) / float(CONFIG_T::reuse_factor));
-        CONFIG_T::template product<data_T, typename CONFIG_T::scale_t, res_T>::limit(multiplier_limit);
+        CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::limit(multiplier_limit);

     } else if (CONFIG_T::io_type == io_serial) {
         #pragma HLS ARRAY_RESHAPE variable=scale complete dim=1
@@ -87,10 +87,10 @@ void normalize(
         }

         if (CONFIG_T::n_filt==-1) {
-            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t, res_T>::product(data[ires], scale[ires]) + bias[ires];
+            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[ires]) + bias[ires];
         } else {
             int norm_index = ires%CONFIG_T::n_filt;
-            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t, res_T>::product(data[ires], scale[norm_index]) + bias[norm_index];
+            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[norm_index]) + bias[norm_index];
         }
     }
 }
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
index 382887fed..826bdafe9 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
@@ -43,7 +43,7 @@ void normalize(
     constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
     constexpr unsigned ii = CONFIG_T::n_in / multiplier_limit;

-    CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t, typename res_T::value_type>::limit(multiplier_limit);
+    CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t>::limit(multiplier_limit);

     BatchNormLoop: for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
         #pragma HLS PIPELINE II=ii
@@ -60,7 +60,7 @@ void normalize(
             } else {
                 norm_index = j % CONFIG_T::n_filt;
             }
-            out_data[j] = CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t, typename res_T::value_type>::product(in_data[j], scale[norm_index]) + bias[norm_index];
+            out_data[j] = CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t>::product(in_data[j], scale[norm_index]) + bias[norm_index];
         }

         res.write(out_data);
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index 3a597f038..29523862c 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -26,6 +26,18 @@ class Product{
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };

+template<class x_T, class w_T>
+class Product_nocast{
+  public:
+    static auto product(x_T a, w_T w) -> decltype(a*w)
+    {
+        // 'Normal' product
+        #pragma HLS INLINE
+        return a * w;
+    }
+    static void limit(unsigned multiplier_limit) {} // Nothing to do here
+};
+
 template<class x_T, class w_T, class y_T>
 class both_binary : public Product<x_T, w_T, y_T>{
   public:
@@ -82,6 +94,21 @@ class mult : public Product<x_T, w_T, y_T>{
     }
 };

+template<class x_T, class w_T>
+class mult_nocast : public Product_nocast<x_T, w_T>{
+  public:
+    static auto product(x_T a, w_T w) -> decltype(a*w)
+    {
+        // 'Normal' product
+        #pragma HLS INLINE
+        return a * w;
+    }
+    static void limit(unsigned multiplier_limit){
+        #pragma HLS INLINE
+        #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class weight_exponential : public Product<x_T, w_T, y_T>{
   public:
diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py
index 149b52f1d..5811d7c12 100644
--- a/hls4ml/templates/vivado_template.py
+++ b/hls4ml/templates/vivado_template.py
@@ -31,8 +31,8 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     static const bool store_weights_in_bram = false;
     typedef {bias_t} bias_t;
     typedef {scale_t} scale_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}_nocast<x_T, y_T>;
 }};\n"""

 conv1d_config_template = """struct config{index} : nnet::conv1d_config {{
diff --git a/test/pytest/test_batchnorm.py b/test/pytest/test_batchnorm.py
new file mode 100644
index 000000000..25744e7f6
--- /dev/null
+++ b/test/pytest/test_batchnorm.py
@@ -0,0 +1,44 @@
+import pytest
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import BatchNormalization
+import numpy as np
+import hls4ml
+
+
+in_shape = 16
+atol = 5e-3
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    X = np.random.rand(100, in_shape)
+    return X
+
+
+@pytest.fixture(scope='module')
+def model():
+    model = Sequential()
+    model.add(BatchNormalization(input_shape=(in_shape,)))
+    model.compile()
+    return model
+
+
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+def test_batchnorm(model, data, io_type):
+
+    config = hls4ml.utils.config_from_keras_model(model,
+                                                  default_precision='ap_fixed<32,1>',
+                                                  granularity='name')
+
+    hls_model = hls4ml.converters.convert_from_keras_model(model,
+                                                           hls_config=config,
+                                                           io_type=io_type,
+                                                           output_dir=f'hls4mlprj_batchnorm_{io_type}',
+                                                           part='xcvu9p-flgb2104-2-i')
+    hls_model.compile()
+
+
+    # Predict
+    y_keras = np.squeeze(model.predict(data))
+    y_hls = hls_model.predict(data)
+    np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)
--
GitLab
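The pattern this first patch introduces is the heart of the whole series: product() deduces its return type from the operands (decltype(a*w)) instead of forcing the result through a res_T template parameter, so the full-precision product survives until the caller decides where to narrow. A minimal sketch of the same idea in plain C++, with int8_t/int32_t standing in for the ap_fixed types (all names here are illustrative, not hls4ml API):

    #include <cstdint>
    #include <iostream>

    // Return type deduced from the operands, as in Product_nocast::product;
    // int8_t * int8_t promotes to int, so nothing is truncated mid-expression.
    template <class x_T, class w_T>
    auto product(x_T a, w_T w) -> decltype(a * w) {
        return a * w;
    }

    int main() {
        int8_t data[4]   = {100, -100, 90, -80};
        int8_t weight[4] = {100,  100, 100, 100};
        int32_t acc = 0;  // the accumulator type makes the one deliberate cast
        for (int i = 0; i < 4; i++)
            acc += static_cast<int32_t>(product(data[i], weight[i]));
        // Forcing each term back into int8_t first (the old res_T-style cast)
        // would truncate 100*100 = 10000 long before it reached the accumulator.
        std::cout << acc << std::endl;  // prints 1000
        return 0;
    }
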
From ffd4bfc060b68fb57a55ee4dc614e9ef6395041f Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Tue, 11 Jan 2022 18:13:39 -0600
Subject: [PATCH 2/9] Add _nocast options for other calculations

---
 .../templates/vivado/nnet_utils/nnet_mult.h   | 128 ++++++++++++++++--
 1 file changed, 119 insertions(+), 9 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index 29523862c..a590511cd 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -5,6 +5,7 @@
 #include "nnet_helpers.h"
 #include "hls_stream.h"
 #include <math.h>
+#include <iostream>

 namespace nnet {

@@ -12,7 +13,7 @@ namespace product{

 /* ---
  * 5 different methods to perform the product of input and weight, depending on the
- * types of each.
+ * types of each. 
  * --- */

 template<class x_T, class w_T, class y_T>
@@ -26,15 +27,8 @@ class Product{
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };

-template<class x_T, class w_T>
 class Product_nocast{
   public:
-    static auto product(x_T a, w_T w) -> decltype(a*w)
-    {
-        // 'Normal' product
-        #pragma HLS INLINE
-        return a * w;
-    }
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };

@@ -48,6 +42,16 @@ class both_binary : public Product<x_T, w_T, y_T>{
     }
 };

+template<class x_T, class w_T>
+class both_binary_nocast : public Product_nocast{
+  public:
+    static x_T product(x_T a, w_T w){
+        // specialisation for 1-bit weights and incoming data
+        #pragma HLS INLINE
+        return a == w;
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class weight_binary : public Product<x_T, w_T, y_T>{
   public:
@@ -58,6 +62,16 @@ class weight_binary : public Product<x_T, w_T, y_T>{
     }
 };

+template<class x_T, class w_T>
+class weight_binary_nocast : public Product_nocast{
+  public:
+    static x_T product(x_T a, w_T w){
+        // Specialisation for 1-bit weights, arbitrary data
+        #pragma HLS INLINE
+        return w == 0 ? (x_T) -a : a;
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class data_binary : public Product<x_T, w_T, y_T>{
   public:
@@ -68,6 +82,16 @@ class data_binary : public Product<x_T, w_T, y_T>{
     }
 };

+template<class x_T, class w_T>
+class data_binary_nocast : public Product_nocast{
+  public:
+    static w_T product(x_T a, w_T w){
+        // Specialisation for 1-bit data, arbitrary weight
+        #pragma HLS INLINE
+        return a == 0 ? (w_T) -w : w;
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class weight_ternary : public Product<x_T, w_T, y_T>{
   public:
@@ -80,6 +104,18 @@ class weight_ternary : public Product<x_T, w_T, y_T>{
     }
 };

+template<class x_T, class w_T>
+class weight_ternary_nocast : public Product_nocast{
+  public:
+    static x_T product(x_T a, w_T w){
+        // Specialisation for 2-bit weights, arbitrary data
+        #pragma HLS INLINE
+        if (w == 0) return (x_T) 0;
+        else if(w == -1) return (x_T) -a;
+        else return (x_T) a; // if(w == 1)
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class mult : public Product<x_T, w_T, y_T>{
   public:
@@ -95,7 +131,7 @@ class mult : public Product<x_T, w_T, y_T>{
 };

 template<class x_T, class w_T>
-class mult_nocast : public Product_nocast<x_T, w_T>{
+class mult_nocast : public Product_nocast{
   public:
     static auto product(x_T a, w_T w) -> decltype(a*w)
     {
@@ -122,6 +158,80 @@ class weight_exponential : public Product<x_T, w_T, y_T>{
     }
 };

+template<class x_T, class w_T>
+class weight_exponential_nocast : public Product_nocast{
+  public:
+    using rt = x_T;
+    static rt product(x_T a, w_T w){
+        std::cerr << "Should not match to this function" << std::endl;
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W>
+class weight_exponential_nocast<ap_int<_AP_W>, w_T> : public Product_nocast{
+  public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_W + decltype(w_T::weight)::width>;
+    static rt product(ap_int<_AP_W> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W>
+class weight_exponential_nocast<ap_uint<_AP_W>, w_T> : public Product_nocast{
+  public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_W + decltype(w_T::weight)::width + 1>;
+    static rt product(ap_uint<_AP_W> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
+class weight_exponential_nocast<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+  public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_I + decltype(w_T::weight)::width,
+                        _AP_Q, _AP_O, _AP_N>;
+    static rt product(ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
+class weight_exponential_nocast<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+  public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_I + decltype(w_T::weight)::width + 1,
+                        _AP_Q, _AP_O, _AP_N>;
+    static rt product(ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
 } // namespace product_type

 template<class data_T, class res_T, typename CONFIG_T>
--
GitLab
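Note that the binary and ternary _nocast variants above contain no multiply at all: a 1-bit weight selects between a and -a, and a 2-bit weight adds the zero case, so these products cost a comparator and a sign flip rather than a DSP. A rough standalone illustration with plain integer stand-ins for the ap_ types (hypothetical names, not the hls4ml classes):

    #include <cassert>
    #include <cstdint>

    // 1-bit weight, arbitrary data: w encodes {0 -> -1, 1 -> +1}
    template <class x_T>
    x_T weight_binary_product(x_T a, bool w) {
        return w ? a : static_cast<x_T>(-a);   // sign select, no multiplier
    }

    // 2-bit weight, arbitrary data: w is one of {-1, 0, +1}
    template <class x_T>
    x_T weight_ternary_product(x_T a, int w) {
        if (w == 0) return 0;
        return w == -1 ? static_cast<x_T>(-a) : a;
    }

    int main() {
        assert(weight_binary_product<int16_t>(25, false) == -25);
        assert(weight_ternary_product<int16_t>(25, -1) == -25);
        assert(weight_ternary_product<int16_t>(25, 0) == 0);
        return 0;
    }

One subtlety the sketch glosses over: negating an N-bit fixed-point value needs N+1 bits to hold the result, which is why patch 4 below switches these return types from x_T to decltype(-a).
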
From c2d15c6335c8e9d47d4cdc0d47ede856c4372593 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 17 Jan 2022 15:31:44 -0600
Subject: [PATCH 3/9] Make all the nnet::product not cast the results

---
 .../vivado/nnet_utils/nnet_batchnorm.h        |   2 +-
 .../templates/vivado/nnet_utils/nnet_dense.h  |   4 +-
 .../vivado/nnet_utils/nnet_dense_compressed.h |   2 +-
 .../vivado/nnet_utils/nnet_dense_latency.h    |   6 +-
 .../vivado/nnet_utils/nnet_dense_resource.h   |   6 +-
 .../templates/vivado/nnet_utils/nnet_merge.h  |   8 +-
 .../templates/vivado/nnet_utils/nnet_mult.h   | 102 ++----------------
 hls4ml/templates/vivado_template.py           |  14 +--
 8 files changed, 32 insertions(+), 112 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
index ce1cfb315..edc6ff320 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
@@ -44,7 +44,7 @@ struct batchnorm_config
     static const unsigned n_zeros = 0;
     // partitioning arrays cyclically to go with roll factors?
     template<class x_T, class y_T>
-    using product = nnet::product::mult_nocast<x_T, y_T>;
+    using product = nnet::product::mult<x_T, y_T>;
 };

 template<class data_T, class res_T, typename CONFIG_T>
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h
index deb1c042d..c9785335a 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h
@@ -30,8 +30,8 @@ struct dense_config
     static const unsigned n_zeros = 0;
     // partitioning arrays cyclically to go with roll factors?
     // Product function to use
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::mult<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::mult<x_T, y_T>;
 };

 template<class data_T, class res_T, typename CONFIG_T>
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h
index adfaa0e1b..dc803ff2b 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h
@@ -86,7 +86,7 @@ void dense_compressed(
             auto weight_cache = weights[w].weight;
             data_T data_cache = data[row];
             //mult[col] += weight_cache * data_cache;
-            typename CONFIG_T::accum_t prod = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data_cache, weight_cache);
+            typename CONFIG_T::accum_t prod = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data_cache, weight_cache);
             fill_mult<CONFIG_T>(col, mult, prod);
         }

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h
index 4a04671fd..2bbab0496 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h
@@ -54,7 +54,7 @@ void dense_latency(
         #pragma HLS ARRAY_PARTITION variable=acc complete

         int multiplier_limit = ceil(float(CONFIG_T::n_in*CONFIG_T::n_out) / float(CONFIG_T::reuse_factor)) - floor(float(CONFIG_T::n_zeros) / float(CONFIG_T::reuse_factor));
-        CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::limit(multiplier_limit);
+        CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::limit(multiplier_limit);

     } else if (CONFIG_T::io_type == io_serial){
         // Only reduce cycle_factor if n_out is evenly divisible by reuse_factor
@@ -90,10 +90,10 @@ void dense_latency(
         Product2: for(int jj = 0; jj < CONFIG_T::n_out; jj++) {
             if (CONFIG_T::io_type == io_serial) {
                 int multiplier_limit = ceil(float(CONFIG_T::n_out) / float(CONFIG_T::reuse_factor));
-                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::limit(multiplier_limit);
+                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::limit(multiplier_limit);
             }
             int index = ii*CONFIG_T::n_out+jj;
-            mult[index] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(cache, weights[index]);
+            mult[index] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(cache, weights[index]);
         }
     }

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
index 756a62743..c77059aa7 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
@@ -73,7 +73,7 @@ void dense_resource_rf_leq_nin(
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL

-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);

             // Increment w_index
             w_index += rufactor;
@@ -157,7 +157,7 @@ void dense_resource_rf_gt_nin_rem0(
         MultLoop:
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL
-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);

             w_index += rufactor;
             if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds
@@ -223,7 +223,7 @@ void dense_resource_rf_gt_nin(
             int w_index = ir + rufactor * im;
             int in_index = w_index % nin;
             if (w_index >= CONFIG_T::n_in*CONFIG_T::n_out) continue; // check out of bounds
-            tmpmult[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data[in_index], weights[w_index]);
+            tmpmult[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
         }

         typename CONFIG_T::accum_t mult[multiplier_limit];
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_merge.h b/hls4ml/templates/vivado/nnet_utils/nnet_merge.h
index 48a5e172d..b103533e0 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_merge.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_merge.h
@@ -38,8 +38,8 @@ struct dot_config {
     static const unsigned reuse_factor = 1;
     typedef float accum_t;
     // Product function to use
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::mult<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::mult<x_T, y_T>;
 };

 struct concat_config {
@@ -129,7 +129,7 @@ void dot1d(
     #pragma HLS PIPELINE II=CONFIG_T::reuse_factor

     constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
-    CONFIG_T::template product<input1_T, input2_T, typename CONFIG_T::accum_t>::limit(multiplier_limit);
+    CONFIG_T::template product<input1_T, input2_T>::limit(multiplier_limit);

     typename CONFIG_T::accum_t mult[CONFIG_T::n_in];
     #pragma HLS ARRAY_PARTITION variable=mult complete
@@ -137,7 +137,7 @@ void dot1d(
     Product: for(int i_mult=0; i_mult < CONFIG_T::n_in; i_mult++) {
         #pragma HLS UNROLL
-        mult[i_mult] = CONFIG_T::template product<input1_T, input2_T, typename CONFIG_T::accum_t>::product(data1[i_mult], data2[i_mult]);
+        mult[i_mult] = CONFIG_T::template product<input1_T, input2_T>::product(data1[i_mult], data2[i_mult]);
     }

     Accum: for(int i_acc = 0; i_acc < CONFIG_T::n_in; i_acc++) {
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index a590511cd..5eb7103a7 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -12,38 +12,17 @@ namespace product{

 /* ---
- * 5 different methods to perform the product of input and weight, depending on the
+ * different methods to perform the product of input and weight, depending on the
  * types of each. 
  * --- */

-template<class x_T, class w_T, class y_T>
 class Product{
-  public:
-    static y_T product(x_T a, w_T w){
-        // 'Normal' product
-        #pragma HLS INLINE
-        return a * w;
-    }
-    static void limit(unsigned multiplier_limit) {} // Nothing to do here
-};
-
-class Product_nocast{
   public:
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };

-template<class x_T, class w_T, class y_T>
-class both_binary : public Product<x_T, w_T, y_T>{
-  public:
-    static y_T product(x_T a, w_T w){
-        // specialisation for 1-bit weights and incoming data
-        #pragma HLS INLINE
-        return a == w;
-    }
-};
-
 template<class x_T, class w_T>
-class both_binary_nocast : public Product_nocast{
+class both_binary : public Product{
   public:
     static x_T product(x_T a, w_T w){
         // specialisation for 1-bit weights and incoming data
         #pragma HLS INLINE
         return a == w;
     }
 };

-template<class x_T, class w_T, class y_T>
-class weight_binary : public Product<x_T, w_T, y_T>{
-  public:
-    static y_T product(x_T a, w_T w){
-        // Specialisation for 1-bit weights, arbitrary data
-        #pragma HLS INLINE
-        return w == 0 ? (x_T) -a : a;
-    }
-};
-
 template<class x_T, class w_T>
-class weight_binary_nocast : public Product_nocast{
+class weight_binary : public Product{
   public:
     static x_T product(x_T a, w_T w){
         // Specialisation for 1-bit weights, arbitrary data
         #pragma HLS INLINE
         return w == 0 ? (x_T) -a : a;
     }
 };

-template<class x_T, class w_T, class y_T>
-class data_binary : public Product<x_T, w_T, y_T>{
-  public:
-    static y_T product(x_T a, w_T w){
-        // Specialisation for 1-bit data, arbitrary weight
-        #pragma HLS INLINE
-        return a == 0 ? (w_T) -w : w;
-    }
-};
-
 template<class x_T, class w_T>
-class data_binary_nocast : public Product_nocast{
+class data_binary : public Product{
   public:
     static w_T product(x_T a, w_T w){
         // Specialisation for 1-bit data, arbitrary weight
         #pragma HLS INLINE
         return a == 0 ? (w_T) -w : w;
     }
 };

-template<class x_T, class w_T, class y_T>
-class weight_ternary : public Product<x_T, w_T, y_T>{
-  public:
-    static y_T product(x_T a, w_T w){
-        // Specialisation for 2-bit weights, arbitrary data
-        #pragma HLS INLINE
-        if (w == 0) return (x_T) 0;
-        else if(w == -1) return (x_T) -a;
-        else return (x_T) a; // if(w == 1)
-    }
-};
-
 template<class x_T, class w_T>
-class weight_ternary_nocast : public Product_nocast{
+class weight_ternary : public Product{
   public:
     static x_T product(x_T a, w_T w){
         // Specialisation for 2-bit weights, arbitrary data
         #pragma HLS INLINE
         if (w == 0) return (x_T) 0;
         else if(w == -1) return (x_T) -a;
         else return (x_T) a; // if(w == 1)
     }
 };

-template<class x_T, class w_T, class y_T>
-class mult : public Product<x_T, w_T, y_T>{
-  public:
-    static y_T product(x_T a, w_T w){
-        // 'Normal' product
-        #pragma HLS INLINE
-        return a * w;
-    }
-    static void limit(unsigned multiplier_limit){
-        #pragma HLS INLINE
-        #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation
-    }
-};
-
 template<class x_T, class w_T>
-class mult_nocast : public Product_nocast{
+class mult : public Product{
   public:
     static auto product(x_T a, w_T w) -> decltype(a*w)
     {
         // 'Normal' product
         #pragma HLS INLINE
         return a * w;
     }
     static void limit(unsigned multiplier_limit){
         #pragma HLS INLINE
         #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation
     }
 };

-template<class x_T, class w_T, class y_T>
-class weight_exponential : public Product<x_T, w_T, y_T>{
-  public:
-    static y_T product(x_T a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        y_T y = a << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? (y_T) y : (y_T) -y;
-    }
-};
-
 template<class x_T, class w_T>
-class weight_exponential_nocast : public Product_nocast{
+class weight_exponential : public Product{
   public:
     using rt = x_T;
     static rt product(x_T a, w_T w){
         std::cerr << "Should not match to this function" << std::endl;

 template<class w_T, int _AP_W>
-class weight_exponential_nocast<ap_int<_AP_W>, w_T> : public Product_nocast{
+class weight_exponential<ap_int<_AP_W>, w_T> : public Product{
   public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_W + decltype(w_T::weight)::width>;
     static rt product(ap_int<_AP_W> a, w_T w){

 template<class w_T, int _AP_W>
-class weight_exponential_nocast<ap_uint<_AP_W>, w_T> : public Product_nocast{
+class weight_exponential<ap_uint<_AP_W>, w_T> : public Product{
   public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_W + decltype(w_T::weight)::width + 1>;
     static rt product(ap_uint<_AP_W> a, w_T w){

 template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential_nocast<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+class weight_exponential<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
   public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_I + decltype(w_T::weight)::width,
                         _AP_Q, _AP_O, _AP_N>;

 template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential_nocast<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+class weight_exponential<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
   public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_I + decltype(w_T::weight)::width + 1,
                         _AP_Q, _AP_O, _AP_N>;
diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py
index 5811d7c12..5040fda0e 100644
--- a/hls4ml/templates/vivado_template.py
+++ b/hls4ml/templates/vivado_template.py
@@ -19,8 +19,8 @@ dense_config_template = """struct config{index} : nnet::dense_config {{
     typedef {bias_t} bias_t;
     typedef {weight_t} weight_t;
     typedef {index_t} index_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
@@ -32,7 +32,7 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     typedef {bias_t} bias_t;
     typedef {scale_t} scale_t;
     template<class x_T, class y_T>
-    using product = nnet::product::{product_type}_nocast<x_T, y_T>;
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 conv1d_config_template = """struct config{index} : nnet::conv1d_config {{
@@ -68,8 +68,8 @@ conv_mult_config_template = """struct config{index}_mult : nnet::dense_config {{
     typedef {accum_t} accum_t;
     typedef {bias_t} bias_t;
     typedef {weight_t} weight_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 conv2d_config_template = """struct config{index} : nnet::conv2d_config {{
@@ -214,8 +214,8 @@ dot_config_template = """struct config{index} : nnet::dot_config {{
     static const unsigned n_out = {n_out};
     static const unsigned reuse_factor = {reuse};
     typedef {accum_t} accum_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 concat_config_template = """struct config{index} : nnet::concat_config {{
--
GitLab
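With every product now taking just the two operand types, the layer configs expose the multiply as a template-template member and the kernels reach it through CONFIG_T::template product<...>. A stripped-down sketch of that plumbing in plain C++ (my_dense_config and dot here are hypothetical stand-ins for the generated config{index} structs and the nnet kernels):

    #include <cstdint>

    namespace product {
    template <class x_T, class w_T>
    struct mult {
        static auto product(x_T a, w_T w) -> decltype(a * w) { return a * w; }
    };
    } // namespace product

    struct my_dense_config {
        typedef int32_t accum_t;
        template <class x_T, class w_T>
        using product = ::product::mult<x_T, w_T>;
    };

    template <class data_T, class weight_T, typename CONFIG_T>
    typename CONFIG_T::accum_t dot(const data_T* d, const weight_T* w, int n) {
        typename CONFIG_T::accum_t acc = 0;
        // the 'template' keyword is required because product is a dependent name
        for (int i = 0; i < n; i++)
            acc += CONFIG_T::template product<data_T, weight_T>::product(d[i], w[i]);
        return acc;
    }
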
From ed94850c0cf7a8f596c5bb3aa3551291cf062673 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Tue, 18 Jan 2022 11:06:45 -0600
Subject: [PATCH 4/9] update product based on Sioni's suggestions

---
 .../templates/vivado/nnet_utils/nnet_mult.h   | 92 ++++---------------
 1 file changed, 20 insertions(+), 72 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index 5eb7103a7..586bc65ae 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -34,32 +34,37 @@ class both_binary : public Product{
 template<class x_T, class w_T>
 class weight_binary : public Product{
   public:
-    static x_T product(x_T a, w_T w){
+    static auto product(x_T a, w_T w) -> decltype(-a)
+    {
         // Specialisation for 1-bit weights, arbitrary data
         #pragma HLS INLINE
-        return w == 0 ? (x_T) -a : a;
+        if (w == 0) return -a;
+        else return a;
     }
 };

 template<class x_T, class w_T>
 class data_binary : public Product{
   public:
-    static w_T product(x_T a, w_T w){
+    static auto product(x_T a, w_T w) -> decltype(-w)
+    {
         // Specialisation for 1-bit data, arbitrary weight
         #pragma HLS INLINE
-        return a == 0 ? (w_T) -w : w;
+        if (a == 0) return -w;
+        else return w;
     }
 };

 template<class x_T, class w_T>
 class weight_ternary : public Product{
   public:
-    static x_T product(x_T a, w_T w){
+    static auto product(x_T a, w_T w) -> decltype(-a)
+    {
         // Specialisation for 2-bit weights, arbitrary data
         #pragma HLS INLINE
-        if (w == 0) return (x_T) 0;
-        else if(w == -1) return (x_T) -a;
-        else return (x_T) a; // if(w == 1)
+        if (w == 0) return 0;
+        else if(w == -1) return -a;
+        else return a; // if(w == 1)
     }
 };

@@ -81,74 +86,17 @@ class mult : public Product{
 template<class x_T, class w_T>
 class weight_exponential : public Product{
   public:
-    using rt = x_T;
-    static rt product(x_T a, w_T w){
-        std::cerr << "Should not match to this function" << std::endl;
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W>
-class weight_exponential<ap_int<_AP_W>, w_T> : public Product{
-  public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_W + decltype(w_T::weight)::width>;
-    static rt product(ap_int<_AP_W> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W>
-class weight_exponential<ap_uint<_AP_W>, w_T> : public Product{
-  public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_W + decltype(w_T::weight)::width + 1>;
-    static rt product(ap_uint<_AP_W> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
-  public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_I + decltype(w_T::weight)::width,
-                        _AP_Q, _AP_O, _AP_N>;
-    static rt product(ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
-  public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_I + decltype(w_T::weight)::width + 1,
-                        _AP_Q, _AP_O, _AP_N>;
-    static rt product(ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
+    // Construct the return type from the multiplication equivalent to the largest shifts
+    // ap_int<pow2(decltype(w_T::weight)::width-1)-1> is the type if the multiplicand equivalent to the largest lshift <<
+    // ap_fixed<pow2(decltype(w_T::weight)::width-1)-1,0> is the type of the multiplicand equivalent to the largest rshift >>
+    using r_T = decltype(x_T(0) * (ap_int<pow2(decltype(w_T::weight)::width-1)-1>(1)+ap_fixed<pow2(decltype(w_T::weight)::width-1)-1,0>(1)));
+    static r_T product(x_T a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        r_T y = static_cast<r_T>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<r_T>(-y);
     }
 };
--
GitLab
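The reworked weight_exponential builds its return type r_T from the two extreme shifts a signed exponent field can encode: for an E-bit weight.weight, roughly 2^(E-1)-1 positions left and 2^(E-1) positions right, so multiplying x_T by a value that carries both the corresponding integer and fractional bits yields a type wide enough for every case. For a 3-bit signed exponent, for instance, shifts span [-4, 3], i.e. scale factors from 2^-4 to 2^3. A floating-point stand-in for the shift-and-negate logic (double absorbs the widening that the ap_fixed r_T tracks bit by bit; exp_weight is a hypothetical model of w_T's sign/weight fields):

    #include <cassert>
    #include <cmath>

    struct exp_weight {
        int sign;    // 1 -> positive, anything else -> negative
        int weight;  // the exponent: scale by 2^weight
    };

    double product(double a, exp_weight w) {
        double y = std::ldexp(a, w.weight);  // a * 2^weight; negative weights shift right
        return w.sign == 1 ? y : -y;
    }

    int main() {
        assert(product(3.0, exp_weight{1,  2}) == 12.0);  // 3 << 2
        assert(product(3.0, exp_weight{0, -1}) == -1.5);  // -(3 >> 1), fraction kept
        return 0;
    }
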
From 810618d90330c56e9b41ad7b726800b4d150cc00 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 24 Jan 2022 15:36:20 -0600
Subject: [PATCH 5/9] break up accumulations in nnet_dense_resource with a cast

---
 hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
index c77059aa7..c0e5d1759 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
@@ -73,7 +73,8 @@ void dense_resource_rf_leq_nin(
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL

-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
+                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]));

             // Increment w_index
             w_index += rufactor;
@@ -157,7 +158,8 @@ void dense_resource_rf_gt_nin_rem0(
         MultLoop:
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL
-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
+                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]));

             w_index += rufactor;
             if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds
--
GitLab

From 5456eb758e692e1116091d972c6ff19a9665b54b Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 24 Jan 2022 15:58:27 -0600
Subject: [PATCH 6/9] add auto_po2 qkeras test that revealed problem earlier

---
 test/pytest/test_qkeras.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py
index cdfc312a8..be061f1ea 100644
--- a/test/pytest/test_qkeras.py
+++ b/test/pytest/test_qkeras.py
@@ -106,8 +106,8 @@ def randX_100_16():
 # Note 4-bit test can still fail sometimes depending on random seed
 # https://github.com/fastmachinelearning/hls4ml/issues/381
 #@pytest.mark.parametrize('bits', [4, 6, 8])
-@pytest.mark.parametrize('bits', [4])
-def test_single_dense_activation_exact(randX_100_16, bits):
+@pytest.mark.parametrize('bits,alpha', [(4, 1), (4, 'auto_po2')])
+def test_single_dense_activation_exact(randX_100_16, bits, alpha):
     '''
     Test a single Dense -> Activation layer topology for
     bit exactness with number of bits parameter
@@ -115,7 +115,7 @@ def test_single_dense_activation_exact(randX_100_16, bits):
     X = randX_100_16
     model = Sequential()
     model.add(QDense(16, input_shape=(16,), name='fc1',
-                     kernel_quantizer=quantized_bits(bits,0,alpha=1), bias_quantizer=quantized_bits(bits,0,alpha=1),
+                     kernel_quantizer=quantized_bits(bits,0,alpha=alpha), bias_quantizer=quantized_bits(bits,0,alpha=1),
                      kernel_initializer='lecun_uniform'))
     model.add(QActivation(activation=quantized_relu(bits,0), name='relu1'))
     model.compile()
--
GitLab

From 14120601d21ff472fec1d16f82e948d98446a0ff Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Thu, 10 Feb 2022 17:18:40 -0600
Subject: [PATCH 7/9] update vivado passes

---
 hls4ml/backends/vivado/passes/convolution_templates.py | 4 ++--
 hls4ml/backends/vivado/passes/core_templates.py        | 8 ++++----
 hls4ml/backends/vivado/passes/merge_templates.py       | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py
index c3e35bf1d..22aa5837e 100644
--- a/hls4ml/backends/vivado/passes/convolution_templates.py
+++ b/hls4ml/backends/vivado/passes/convolution_templates.py
@@ -13,8 +13,8 @@ conv_mult_config_template = """struct config{index}_mult : nnet::dense_config {{
     typedef {accum_t.name} accum_t;
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 # Conv1D templates
diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py
index 686956735..201562f7f 100644
--- a/hls4ml/backends/vivado/passes/core_templates.py
+++ b/hls4ml/backends/vivado/passes/core_templates.py
@@ -18,8 +18,8 @@ dense_config_template = """struct config{index} : nnet::dense_config {{
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
    typedef {index_t.name} index_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
@@ -62,8 +62,8 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     static const bool store_weights_in_bram = false;
     typedef {bias_t.name} bias_t;
     typedef {scale_t.name} scale_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
diff --git a/hls4ml/backends/vivado/passes/merge_templates.py b/hls4ml/backends/vivado/passes/merge_templates.py
index 219a7f4e2..863512c4c 100644
--- a/hls4ml/backends/vivado/passes/merge_templates.py
+++ b/hls4ml/backends/vivado/passes/merge_templates.py
@@ -50,8 +50,8 @@ dot_config_template = """struct config{index} : nnet::dot_config {{
     static const unsigned n_out = {n_out};
     static const unsigned reuse_factor = {reuse};
     typedef {accum_t.name} accum_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""

 class DotConfigTemplate(LayerConfigTemplate):
--
GitLab
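The explicit static_cast that patch 5 adds at the accumulation is the flip side of the no-cast products: once product() returns its natural width, the += would otherwise keep widening the expression type term by term, and the designer-chosen accum_t would only apply at the very end. Casting each term once, at the point of accumulation, keeps the adder at a fixed width. A plain-C++ rendering of the resulting shape (in native integer arithmetic the cast is a no-op; with ap_fixed it is what bounds the accumulator):

    #include <cstdint>

    typedef int32_t accum_t;  // stands in for CONFIG_T::accum_t

    template <class x_T, class w_T>
    auto product(x_T a, w_T w) -> decltype(a * w) { return a * w; }

    accum_t mac(const int8_t* data, const int8_t* weights, int n) {
        accum_t acc = 0;
        for (int i = 0; i < n; i++) {
            // cast once per term, exactly where the precision decision belongs
            acc += static_cast<accum_t>(product(data[i], weights[i]));
        }
        return acc;
    }
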
From cb210210a9d28329c83f14260c6c1f633235542d Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 21 Feb 2022 16:23:02 -0600
Subject: [PATCH 8/9] update product on quartus side

---
 .../backends/quartus/passes/core_templates.py |  2 -
 .../quartus/firmware/nnet_utils/nnet_dense.h  | 44 +++++++++----------
 2 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/hls4ml/backends/quartus/passes/core_templates.py b/hls4ml/backends/quartus/passes/core_templates.py
index edbfcc56b..63c3693b0 100644
--- a/hls4ml/backends/quartus/passes/core_templates.py
+++ b/hls4ml/backends/quartus/passes/core_templates.py
@@ -71,8 +71,6 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     static const bool store_weights_in_bram = false;
     typedef {bias_t.name} bias_t;
     typedef {scale_t.name} scale_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
 }};\n"""

 batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
index 18645042f..8075a9671 100644
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
@@ -51,38 +51,34 @@ struct dense_config
     // partitioning arrays cyclically to go with roll factors?
 };

-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<std::is_same<data_T, ac_int<1, false>>::value
-        and std::is_same<weight_T, ac_int<1, false>>::value, ac_int<1, false>>::type
-product(ac_int<1, false> a, ac_int<1, false> w){
+inline ac_int<1, false> product(ac_int<1, false> a, ac_int<1, false> w)
+{
     // specialisation for 1-bit weights and incoming data
-    return (ret_T) (a == w);
+    return (a == w);
 }

-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<(not std::is_same<data_T, ac_int<1, false>>::value)
-        and std::is_same<weight_T, ac_int<1, false>>::value, ret_T>::type
-product(data_T a, ac_int<1, false> w){
+template<class data_T>
+auto product(data_T a, ac_int<1, false> w) -> decltype(-a)
+{
     // Specialisation for 1-bit weights, arbitrary data
-    return w == 0 ? (ret_T) -a : a;
+    if (w == 0) return -a;
+    else return a;
 }

-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<(not std::is_same<data_T, ac_int<2, false>>::value)
-        and std::is_same<weight_T, ac_int<2, true>>::value, ret_T>::type
-product(data_T a, ac_int<2, true> w){
+template<class data_T>
+auto product(data_T a, ac_int<2, true> w) -> decltype(-a)
+{
     // Specialisation for 2-bit weights, arbitrary data
-    if (w == 0) return (ret_T) 0;
-    else if(w == -1) return (ret_T) -a;
-    else return (ret_T) a; // if(w == 1)
+    if (w == 0) return 0;
+    else if(w == -1) return -a;
+    else return a; // if(w == 1)
 }

-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<(not std::is_same<data_T, ac_int<1, false>>::value)
-        and (not std::is_same<weight_T, ac_int<1, false>>::value), ret_T>::type
-product(data_T a, weight_T w){
+template<class data_T, class weight_T>
+auto product(data_T a, weight_T w) -> decltype(a*w)
+{
     // 'Normal' product
-    return (ret_T)(a * w);
+    return a * w;
 }

 template<class data_T, class res_T, typename CONFIG_T>
@@ -138,7 +134,7 @@ void dense_rf_gt(
                     uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im;
                     if (w_index >= CONFIG_T::reuse_factor_rounded*CONFIG_T::block_factor_rounded) continue;
                     int data_index = d_index[ir][im];
-                    tmp_acc[im] = product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>(data[data_index], weights[w_index]);
+                    tmp_acc[im] = product(data[data_index], weights[w_index]);
                 }
                 hls_register typename CONFIG_T::accum_t mult[CONFIG_T::multiplier_limit];
                 ResetMult:
@@ -192,7 +188,7 @@ void dense_rf_lt(
             for (int im = 0, in_index = ir; im < CONFIG_T::block_factor; im++) {
                 uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im;
                 if (ir + CONFIG_T::reuse_factor * im >= CONFIG_T::n_in*CONFIG_T::n_out) continue;
-                mult[im] = product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>(data[in_index], weights[w_index]);
+                mult[im] = product(data[in_index], weights[w_index]);
                 in_index += CONFIG_T::reuse_factor;
                 if (in_index >= CONFIG_T::n_in) in_index = ir;
             }
--
GitLab

From aa8e4e29743f36ae34f477896eb102e66e845f38 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 21 Feb 2022 16:25:39 -0600
Subject: [PATCH 9/9] make different backends not overwrite projects

---
 test/pytest/test_keras_api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py
index 0e424f64e..820cb431d 100644
--- a/test/pytest/test_keras_api.py
+++ b/test/pytest/test_keras_api.py
@@ -37,7 +37,7 @@ def test_dense(backend):
     keras_prediction = model.predict(X_input)

     config = hls4ml.utils.config_from_keras_model(model)
-    output_dir = str(test_root_path / 'hls4mlprj_keras_api_dense')
+    output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}')

     hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)

@@ -78,7 +78,7 @@ def test_activations(activation_function, backend):
     X_input = np.random.rand(100,1)
     keras_prediction = model.predict(X_input)
     config = hls4ml.utils.config_from_keras_model(model)
-    output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}'.format(activation_function.__class__.__name__))
+    output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}_{}'.format(activation_function.__class__.__name__, backend))
     hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
     hls_model.compile()
     hls_prediction = hls_model.predict(X_input)
--
GitLab
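
The Quartus backend reaches the same two-argument shape by a different route: instead of enable_if-gated templates keyed on a ret_T, patch 8 relies on ordinary overload resolution, where exact-type overloads for the 1-bit and 2-bit cases beat the generic template. A compact stand-in with plain C++ types in place of ac_int (bit1 and crumb are invented for the sketch):

    #include <cassert>

    struct bit1  { bool v; };  // stands in for ac_int<1, false>
    struct crumb { int v; };   // stands in for ac_int<2, true>, v in {-1, 0, +1}

    inline bool product(bit1 a, bit1 w) { return a.v == w.v; }  // XNOR-style match

    template <class data_T>
    auto product(data_T a, bit1 w) -> decltype(-a) {
        return w.v ? a : -a;           // binary weight: sign select
    }

    template <class data_T>
    auto product(data_T a, crumb w) -> decltype(-a) {
        if (w.v == 0) return 0;
        return w.v == -1 ? -a : a;     // ternary weight
    }

    template <class data_T, class weight_T>
    auto product(data_T a, weight_T w) -> decltype(a * w) { return a * w; }

    int main() {
        assert(product(bit1{true}, bit1{true}) == true);  // exact non-template match wins
        assert(product(5, crumb{-1}) == -5);              // more specialized template wins
        assert(product(5, 7) == 35);                      // generic product
        return 0;
    }
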