diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py
index fb65071243339d01b5dce5c8893407bcbb116cb8..1ecf4b37f16b0a0661755fcbf2941e3422f84094 100644
--- a/hls4ml/converters/__init__.py
+++ b/hls4ml/converters/__init__.py
@@ -93,7 +93,7 @@ def parse_yaml_config(config_file):
     print('Loading configuration from', config_file)
     with open(config_file, 'r') as file:
-        parsed_config = yaml.load(file, Loader=yaml.SafeLoader)
+        parsed_config = yaml.safe_load(file)
 
     return parsed_config
 
 def convert_from_config(config):
diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py
index 3cfd33cf84ea3db4fa39f874b03c2f4e40b2c0b5..bb5ac5ec97219e8683295337892f0713024b25bf 100644
--- a/hls4ml/converters/keras/core.py
+++ b/hls4ml/converters/keras/core.py
@@ -86,18 +86,18 @@ def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader,
     if layer['class_name'] != 'Activation':
         layer['activation'] = layer['class_name']
     if layer['class_name'] == 'LeakyReLU':
-        layer['activ_param'] = keras_layer["config"].get('alpha', 0.3)
+        layer['activ_param'] = keras_layer['config'].get('alpha', 0.3)
     elif layer['class_name'] == 'ThresholdedReLU':
-        layer['activ_param'] = keras_layer["config"].get('theta', 1.)
+        layer['activ_param'] = keras_layer['config'].get('theta', 1.)
     elif layer['class_name'] == 'ELU':
-        layer['activ_param'] = keras_layer["config"].get('alpha', 1.)
+        layer['activ_param'] = keras_layer['config'].get('alpha', 1.)
     elif layer['class_name'] == 'ReLU':
         layer['class_name'] = 'Activation'
 
     if layer['class_name'] == 'Activation' and layer['activation'] == 'softmax':
         layer['class_name'] = 'Softmax'
-    if layer['class_name'] == 'ReLU':
-        layer['class_name'] = 'Activation'
+    if layer['class_name'] == 'Softmax':
+        layer['axis'] = keras_layer['config'].get('axis', -1)
 
     return layer, [shape for shape in input_shapes[0]]
 
diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py
index 44c8c4b295504d6165dfbd63bb765adc4be7929b..1911545dcb32d9822f4df9fc8a7b8dad0ecd34f4 100644
--- a/hls4ml/converters/keras_to_hls.py
+++ b/hls4ml/converters/keras_to_hls.py
@@ -319,6 +319,7 @@ def keras_to_hls(config):
                 act_layer['class_name'] = layer['activation']
             elif layer['activation'] == 'softmax':
                 act_layer['class_name'] = 'Softmax'
+                act_layer['axis'] = -1
             else:
                 act_layer['class_name'] = 'Activation'
             inputs_map[layer['name']] = act_layer['name']
diff --git a/hls4ml/converters/onnx/core.py b/hls4ml/converters/onnx/core.py
index 208fae7caf7a0962c8cfdaa03756679cfb1c63b3..985d941549ddf33cf51f00981f7d8e4264663434 100644
--- a/hls4ml/converters/onnx/core.py
+++ b/hls4ml/converters/onnx/core.py
@@ -23,7 +23,10 @@ def parse_gemm_layer(reader, node, inputs_map, input_shapes, graph, config):
     return layer, output_shape
 
 #------------------Global paras for activations
-activation_layers = ['Relu', 'Tanh', 'Sigmoid', 'LeakyRelu', 'ThresholdedRelu', 'HardSigmoid', 'Elu', 'Selu', 'PRelu', 'Softmax', 'Softsign', 'Softplus', 'Clip']
+# TODO: repair HardSigmoid support
+# https://github.com/fastmachinelearning/hls4ml/issues/409
+#activation_layers = ['Relu', 'Tanh', 'Sigmoid', 'LeakyRelu', 'ThresholdedRelu', 'HardSigmoid', 'Elu', 'Selu', 'PRelu', 'Softmax', 'Softsign', 'Softplus', 'Clip']
+activation_layers = ['Relu', 'Tanh', 'Sigmoid', 'LeakyRelu', 'ThresholdedRelu', 'Elu', 'Selu', 'PRelu', 'Softmax', 'Softsign', 'Softplus', 'Clip']
 
 activation_map = {'Relu':'ReLU', 'Tanh':'Activation',
                   'Sigmoid':'Activation', 'LeakyRelu':'LeakyReLU',
diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py
index eb5f7f3cea76c10788e8366a0718d6f72b62a982..cafbf1a4636b5942b0d2fc6807aa4f68cdb2f96c 100644
--- a/hls4ml/converters/pytorch/core.py
+++ b/hls4ml/converters/pytorch/core.py
@@ -2,6 +2,8 @@ import numpy as np
 
 from hls4ml.converters.pytorch_to_hls import pytorch_handler
 
+# TODO: propagate use_bias info properly
+# https://github.com/fastmachinelearning/hls4ml/issues/409
 @pytorch_handler('Linear')
 def parse_linear_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
     assert('Linear' in pytorch_layer.__class__.__name__)
@@ -15,6 +17,7 @@ def parse_linear_layer(pytorch_layer, layer_name, input_shapes, data_reader, con
     layer['n_out'] = pytorch_layer.out_features
 
     #Handling whether bias is used or not
+    assert pytorch_layer.bias is not None, "PyTorch Linear with bias=False not yet supported"
     if pytorch_layer.bias is None:
         layer['use_bias'] = False
     else:
@@ -24,8 +27,10 @@ def parse_linear_layer(pytorch_layer, layer_name, input_shapes, data_reader, con
 
     return layer, output_shape
 
-
-activation_layers = ['LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU']
+# TODO: propagate parametrized activation parameters
+# https://github.com/fastmachinelearning/hls4ml/issues/409
+# activation_layers = ['LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU']
+activation_layers = ['Softmax', 'ReLU']
 
 @pytorch_handler(*activation_layers)
 def parse_activation_layer(pytorch_layer, layer_name, input_shapes, data_reader, config):
diff --git a/hls4ml/model/hls_layers.py b/hls4ml/model/hls_layers.py
index c1cca5070018f9dda52f4dca214cb5dc91d9f84d..c730d60ffd55daf2be901f32ad2eb396524e773e 100644
--- a/hls4ml/model/hls_layers.py
+++ b/hls4ml/model/hls_layers.py
@@ -1384,6 +1384,9 @@ class Softmax(Activation):
             self.set_attr('implementation', 'latency')
         else:
             self.set_attr('implementation', self.model.config.get_strategy(self).lower())
+
+        if self.model.config.get_config_value('IOType') == 'io_parallel':
+            assert len(self.get_input_variable().shape) == 1, 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.'
 
 class TernaryTanh(Activation):
     def initialize(self):
diff --git a/hls4ml/model/hls_model.py b/hls4ml/model/hls_model.py
index 22461074509ebcce727b8e329a48210ae0e0f0d9..0f9c11ae3a7d10f9d4d1db2b76fb1e5da277b608 100644
--- a/hls4ml/model/hls_model.py
+++ b/hls4ml/model/hls_model.py
@@ -555,21 +555,21 @@ class HLSModel(object):
         else:
             xlist = x
 
-        for x in xlist:
-            if not isinstance(x, np.ndarray):
+        for xi in xlist:
+            if not isinstance(xi, np.ndarray):
                 raise Exception('Expected numpy.ndarray, but got {}'.format(type(x)))
-            if not x.flags['C_CONTIGUOUS']:
+            if not xi.flags['C_CONTIGUOUS']:
                 raise Exception('Array must be c_contiguous, try using numpy.ascontiguousarray(x)')
 
-        x = xlist[0]
-        if x.dtype in [np.single, np.float32]:
+        x0 = xlist[0]
+        if x0.dtype in [np.single, np.float32]:
             top_function = getattr(self._top_function_lib, self.config.get_project_name() + '_float')
             ctype = ctypes.c_float
-        elif x.dtype in [np.double, np.float64, np.float_]:
+        elif x0.dtype in [np.double, np.float64, np.float_]:
             top_function = getattr(self._top_function_lib, self.config.get_project_name() + '_double')
             ctype = ctypes.c_double
         else:
-            raise Exception('Invalid type ({}) of numpy array. Supported types are: single, float32, double, float64, float_.'.format(x.dtype))
+            raise Exception('Invalid type ({}) of numpy array. Supported types are: single, float32, double, float64, float_.'.format(x0.dtype))
 
         top_function.restype = None
 
@@ -584,9 +584,9 @@ class HLSModel(object):
         else:
             xlist = x
         n_samples = []
-        for i, x in enumerate(xlist):
+        for i, xi in enumerate(xlist):
             expected_size = self.get_input_variables()[i].size()
-            x_size = np.prod(x.shape)
+            x_size = np.prod(xi.shape)
             n_sample, rem = divmod(x_size, expected_size)
             if rem != 0:
                 raise Exception('Input size mismatch, got {}, expected {}'.format(x_size.shape, self.get_input_variables()[i].shape))
@@ -600,23 +600,25 @@ class HLSModel(object):
     def predict(self, x):
         top_function, ctype = self._get_top_function(x)
         n_samples = self._compute_n_samples(x)
+        n_inputs = len(self.get_input_variables())
 
         curr_dir = os.getcwd()
         os.chdir(self.config.get_output_dir() + '/firmware')
 
         output = []
-        if n_samples == 1:
+        if n_samples == 1 and n_inputs == 1:
             x = [x]
 
         try:
             for i in range(n_samples):
                 predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype)
-                if len(self.get_input_variables()) == 1:
+                if n_inputs == 1:
                     top_function(x[i], predictions, ctypes.byref(ctypes.c_ushort()), ctypes.byref(ctypes.c_ushort()))
                 else:
-                    argtuple = [xi for xi in x[i]]
+                    inp = [xj[i] for xj in x]
+                    argtuple = inp
                     argtuple += [predictions]
-                    argtuple += [ctypes.byref(ctypes.c_ushort()) for i in range(len(x[i])+1)]
+                    argtuple += [ctypes.byref(ctypes.c_ushort()) for k in range(len(inp)+1)]
                     argtuple = tuple(argtuple)
                     top_function(*argtuple)
                 output.append(predictions)
@@ -639,6 +641,7 @@ class HLSModel(object):
 
         top_function, ctype = self._get_top_function(x)
         n_samples = self._compute_n_samples(x)
+        n_inputs = len(self.get_input_variables())
 
         class TraceData(ctypes.Structure):
             _fields_ = [('name', ctypes.c_char_p),
@@ -670,7 +673,7 @@ class HLSModel(object):
         os.chdir(self.config.get_output_dir() + '/firmware')
 
         output = []
-        if n_samples == 1:
+        if n_samples == 1 and n_inputs == 1:
             x = [x]
 
         try:
@@ -678,12 +681,13 @@ class HLSModel(object):
             for i in range(n_samples):
                 predictions = np.zeros(self.get_output_variables()[0].size(), dtype=ctype)
 
-                if len(self.get_input_variables()) == 1:
+                if n_inputs == 1:
                     top_function(x[i], predictions, ctypes.byref(ctypes.c_ushort()), ctypes.byref(ctypes.c_ushort()))
                 else:
-                    argtuple = [xi for xi in x[i]]
+                    inp = [xj[i] for xj in x]
+                    argtuple = inp
                     argtuple += [predictions]
-                    argtuple += [ctypes.byref(ctypes.c_ushort()) for i in range(len(x[i])+1)]
+                    argtuple += [ctypes.byref(ctypes.c_ushort()) for k in range(len(inp)+1)]
                     argtuple = tuple(argtuple)
                     top_function(*argtuple)
                 output.append(predictions)
diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index 04bd24b35ab570fd0f833a4c55c52b294eb33a16..19915b553e89dab20024f73deda4d120705a5707 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -12,7 +12,7 @@ from hls4ml.model.optimizer.passes.conv_same_pad import InsertZeroPaddingBeforeC
 from hls4ml.model.optimizer.passes.conv_same_pad import InsertZeroPaddingBeforeConv2D
 from hls4ml.model.optimizer.passes.pointwise import OptimizePointwiseConv
 from hls4ml.model.optimizer.passes.clone import CloneOutput
-from hls4ml.model.optimizer.passes.repack_stream import ReshapeStream, BroadcastStream
+from hls4ml.model.optimizer.passes.repack_stream import ReshapeStream, BroadcastStream, RemoveFinalReshape
 from hls4ml.model.optimizer.passes.transpose_opt import RemoveUselessTranspose
 from hls4ml.model.optimizer.passes.multi_dense import ReplaceMultidimensionalDenseWithConv
 
@@ -40,6 +40,7 @@ register_pass('conv1d_same_pad', InsertZeroPaddingBeforeConv1D)
 register_pass('conv2d_same_pad', InsertZeroPaddingBeforeConv2D)
 register_pass('optimize_pointwise_conv', OptimizePointwiseConv)
 register_pass('clone_output', CloneOutput)
+register_pass('remove_final_reshape', RemoveFinalReshape)
 register_pass('reshape_stream', ReshapeStream)
 register_pass('remove_useless_transpose', RemoveUselessTranspose)
 register_pass('replace_multidense_conv', ReplaceMultidimensionalDenseWithConv)
diff --git a/hls4ml/model/optimizer/passes/repack_stream.py b/hls4ml/model/optimizer/passes/repack_stream.py
index f54a264c9dd97f693a29d8385766ced6ff1361ec..de2e2984191aaa81f2bf0d8f0f18ddc4b16f19f8 100644
--- a/hls4ml/model/optimizer/passes/repack_stream.py
+++ b/hls4ml/model/optimizer/passes/repack_stream.py
@@ -71,7 +71,8 @@ for backend in ['Vivado', 'VivadoAccelerator']:
 class ReshapeStream(OptimizerPass):
     ''' Repacks stream for Reshape layer '''
     def match(self, node):
-        return node.__class__.__name__ == 'Reshape'
+        # do not run optimizer pass for a flatten layer (1 output dimension)
+        return node.__class__.__name__ == 'Reshape' and len(node.get_output_variable().shape) > 1
 
     def transform(self, model, node):
         if model.config.backend.name not in ['Vivado', 'VivadoAccelerator'] or \
@@ -121,3 +122,19 @@ class BroadcastStream(OptimizerPass):
             node.inputs[idx] = brdcst_out
 
         return True
+
+class RemoveFinalReshape(OptimizerPass):
+    ''' Remove reshape if final layer '''
+    def match(self, node):
+        # match if reshape is final node
+        return node.__class__.__name__ == 'Reshape' and not node.get_output_nodes()
+
+    def transform(self, model, node):
+        if model.config.get_config_value('IOType') == 'io_parallel':
+            print('WARNING: Final layer is a Reshape, which does not affect the output for io_parallel; removing it')
+            # remove, but don't rewire because it's the output layer
+            model.remove_node(node, rewire=False)
+            return True
+        elif model.config.get_config_value('IOType') == 'io_stream':
+            print('WARNING: Final layer is a Reshape, which may incur a large resource cost for io_stream; consider removing it')
+            return False
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h
index 3b195cddcfe5ea1a9b025c9874d009e08faabdcd..83f73ad89e9e73d4e2717fe93fd4e22171368360 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h
@@ -243,11 +243,11 @@ void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
 
     // Calculate all the e^x's
     typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
-    #pragma HLS array_partition variable=exp_res complete
+    #pragma HLS array_partition variable=exp_res complete
     typename CONFIG_T::exp_table_t exp_sum(0);
     for(unsigned i = 0; i < CONFIG_T::n_in; i++){
         #pragma HLS unroll
-        unsigned x = softmax_idx_from_real_val<data_T, CONFIG_T>(data[i]);
+        unsigned x = softmax_idx_from_real_val<data_T, CONFIG_T>(data[i]);
         exp_res[i] = exp_table[x];
     }
 
@@ -298,11 +298,11 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
 
     // Calculate all the e^x's
     typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
-    #pragma HLS array_partition variable=exp_res complete
+    #pragma HLS array_partition variable=exp_res complete
     typename CONFIG_T::exp_table_t exp_sum(0);
     for(unsigned i = 0; i < CONFIG_T::n_in; i++){
         #pragma HLS unroll
-        unsigned x = softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i]);
+        unsigned x = softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i]);
         exp_res[i] = exp_table[x];
     }
 
@@ -337,11 +337,11 @@ void init_invert_table_legacy(typename CONFIG_T::table_t table_out[N_TABLE])
     // Inversion function:
     //   result = 1/x
     for (int ii = 0; ii < N_TABLE; ii++) {
-        // First, convert from table index to X-value (signed 8-bit, range 0 to +64)
-        float in_val = 64.0*ii/float(N_TABLE);
+        // First, convert from table index to X-value (signed 8-bit, range 0 to +64)
+        float in_val = 64.0*ii/float(N_TABLE);
         // Next, compute lookup table function
-        if (in_val > 0.0) table_out[ii] = 1.0/in_val;
-        else table_out[ii] = 0.0;
+        if (in_val > 0.0) table_out[ii] = 1.0/in_val;
+        else table_out[ii] = 0.0;
     }
 }
 
@@ -376,33 +376,34 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
     int data_round;
     int index;
     for (int ii=0; ii<CONFIG_T::n_in; ii++) {
-      data_cache[ii] = data[ii];
-      exp_res[ii] = 0;
+        data_cache[ii] = data[ii];
+        exp_res[ii] = 0;
     }
+
     for (int ii=0; ii<CONFIG_T::n_in; ii++) {
-      if (CONFIG_T::io_type == io_serial){
-          #pragma HLS PIPELINE
-      }
-      for (int jj=0; jj<CONFIG_T::n_in; jj++) {
-        if (ii==jj) exp_diff_res = 1;
-        else {
-          data_round = (data_cache[jj]-data_cache[ii])*CONFIG_T::table_size/16;
-          index = data_round + 8*CONFIG_T::table_size/16;
-          if (index < 0) index = 0;
-          if (index > CONFIG_T::table_size-1) index = CONFIG_T::table_size-1;
-          exp_diff_res = exp_table[index];
-        }
-        exp_res[ii] += exp_diff_res;
-      }
+        if (CONFIG_T::io_type == io_serial) {
+            #pragma HLS PIPELINE
+        }
+        for (int jj=0; jj<CONFIG_T::n_in; jj++) {
+            if (ii==jj) exp_diff_res = 1;
+            else {
+                data_round = (data_cache[jj]-data_cache[ii])*CONFIG_T::table_size/16;
+                index = data_round + 8*CONFIG_T::table_size/16;
+                if (index < 0) index = 0;
+                if (index > CONFIG_T::table_size-1) index = CONFIG_T::table_size-1;
+                exp_diff_res = exp_table[index];
+            }
+            exp_res[ii] += exp_diff_res;
+        }
     }
 
     //Second loop to invert
     for (int ii=0; ii<CONFIG_T::n_in; ii++) {
-      int exp_res_index = exp_res[ii]*CONFIG_T::table_size/64;
-      if (exp_res_index < 0) exp_res_index = 0;
-      if (exp_res_index > CONFIG_T::table_size-1) exp_res_index = CONFIG_T::table_size-1;
-      //typename CONFIG_T::table_t exp_res_invert = invert_table[exp_res_index];
-      res[ii] = (res_T) invert_table[exp_res_index];
+        int exp_res_index = exp_res[ii]*CONFIG_T::table_size/64;
+        if (exp_res_index < 0) exp_res_index = 0;
+        if (exp_res_index > CONFIG_T::table_size-1) exp_res_index = CONFIG_T::table_size-1;
+        //typename CONFIG_T::table_t exp_res_invert = invert_table[exp_res_index];
+        res[ii] = (res_T) invert_table[exp_res_index];
     }
 }
 
@@ -420,7 +421,7 @@ void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     case softmax_implementation::legacy:
         softmax_legacy<data_T, res_T, CONFIG_T>(data, res);
         break;
-    }
+    }
 }
 
 // *************************************************
@@ -720,7 +721,7 @@ void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_
 template<class data_T, class res_T, typename CONFIG_T>
 void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 {
-	elu<data_T, res_T, CONFIG_T>(data, 1.0, res);
+    elu<data_T, res_T, CONFIG_T>(data, 1.0, res);
 }
 
 // *************************************************
@@ -810,26 +811,22 @@ void prelu(data_T data[CONFIG_T::n_in], data_T alpha[CONFIG_T::n_in], res_T res
 template<class data_T, class res_T, typename CONFIG_T>
 void binary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 {
+    if (CONFIG_T::io_type == io_parallel){
+        #pragma HLS PIPELINE
+    }
-    if (CONFIG_T::io_type == io_parallel){
-        #pragma HLS PIPELINE
-    }
-
-    data_T datareg;
-    res_T cache;
-    for (int ii=0; ii<CONFIG_T::n_in; ii++) {
-
-        if (CONFIG_T::io_type == io_serial){
-            #pragma HLS PIPELINE
-        }
-        datareg = data[ii];
-        if( datareg > 0 ) cache = 1;
-        else cache = -1;
-
-        res[ii] = (res_T) cache;
-
-    }
-
+    data_T datareg;
+    res_T cache;
+    for (int ii=0; ii<CONFIG_T::n_in; ii++) {
+        if (CONFIG_T::io_type == io_serial){
+            #pragma HLS PIPELINE
+        }
+        datareg = data[ii];
+        if( datareg > 0 ) cache = 1;
+        else cache = -1;
+
+        res[ii] = (res_T) cache;
+    }
 }
 
 // *************************************************
@@ -839,25 +836,23 @@ template<class data_T, class res_T, typename CONFIG_T>
 void ternary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in])
 {
-    if (CONFIG_T::io_type == io_parallel){
-        #pragma HLS PIPELINE
-    }
+    if (CONFIG_T::io_type == io_parallel){
+        #pragma HLS PIPELINE
+    }
 
-    data_T datareg;
-    res_T cache;
-    for (int ii=0; ii<CONFIG_T::n_in; ii++) {
-
-        if (CONFIG_T::io_type == io_serial){
-            #pragma HLS PIPELINE
-        }
-        datareg = 2*data[ii];
-        if( datareg > 1 ) cache = 1;
-        else if( datareg > -1 && datareg <= 1) cache=0;
-        else cache = -1;
+    data_T datareg;
+    res_T cache;
+    for (int ii=0; ii<CONFIG_T::n_in; ii++) {
+        if (CONFIG_T::io_type == io_serial) {
+            #pragma HLS PIPELINE
+        }
+        datareg = 2*data[ii];
+        if( datareg > 1 ) cache = 1;
+        else if( datareg > -1 && datareg <= 1) cache=0;
+        else cache = -1;
 
-        res[ii] = (res_T) cache;
-
-    }
+        res[ii] = (res_T) cache;
+    }
 }
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h
index 09968ad4499ab18d52f134fc31139c7f4ae33519..e2138039a05d712bc0b7beaa2f0f6ec7a68d3046 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h
@@ -139,40 +139,36 @@ void softmax_latency(hls::stream<data_T> &data, hls::stream<res_T> &res){
         initialized = true;
     }
 
-    constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
-    constexpr unsigned ii = CONFIG_T::n_in / multiplier_limit;
+    constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor);
+    constexpr unsigned ii = data_T::size / multiplier_limit;
 
     // Calculate all the e^x's
-    typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
+    typename CONFIG_T::exp_table_t exp_res[data_T::size];
     #pragma HLS array_partition variable=exp_res complete
     typename CONFIG_T::exp_table_t exp_sum(0);
     SoftmaxExpLoop: for(unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++){
-        if (CONFIG_T::n_in / data_T::size > 1) {
-            #pragma HLS PIPELINE
-        }
+        #pragma HLS PIPELINE II=ii
+
         data_T in_pack = data.read();
         SoftmaxExpPackLoop: for(unsigned j = 0; j < data_T::size; j++){
             #pragma HLS UNROLL
             unsigned x = softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(in_pack[j]);
-            exp_res[i * data_T::size + j] = exp_table[x];
+            exp_res[j] = exp_table[x];
         }
-    }
 
-    // Explicitly sum the results with an adder tree.
-    // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing
-    Op_add<typename CONFIG_T::exp_table_t> op_add;
-    exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);
+        // Explicitly sum the results with an adder tree.
+ // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing + Op_add<typename CONFIG_T::exp_table_t> op_add; + exp_sum = reduce<typename CONFIG_T::exp_table_t, data_T::size, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add); - typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)]; - SoftmaxInvLoop: for(unsigned i = 0; i < CONFIG_T::n_in / res_T::size; i++){ - #pragma HLS PIPELINE II=ii + typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)]; res_T out_pack; #pragma HLS DATA_PACK variable=out_pack SoftmaxInvPackLoop: for(unsigned j = 0; j < res_T::size; j++){ #pragma HLS UNROLL #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation - out_pack[i * res_T::size + j] = exp_res[i * res_T::size + j] * inv_exp_sum; + out_pack[j] = exp_res[j] * inv_exp_sum; } res.write(out_pack); } @@ -199,58 +195,54 @@ void softmax_stable(hls::stream<data_T> &data, hls::stream<res_T> &res){ initialized = true; } - constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor); - constexpr unsigned ii = CONFIG_T::n_in / multiplier_limit; + constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor); + constexpr unsigned ii = data_T::size / multiplier_limit; - typename data_T::value_type data_array[CONFIG_T::n_in]; + typename data_T::value_type data_array[data_T::size]; #pragma HLS ARRAY_PARTITION variable=data_array complete SoftmaxArrayLoop: for(unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++){ - if (CONFIG_T::n_in / data_T::size > 1) { - #pragma HLS PIPELINE - } + #pragma HLS PIPELINE II=ii + data_T in_pack = data.read(); SoftmaxArrayPackLoop: for(unsigned j = 0; j < data_T::size; j++){ #pragma HLS UNROLL - data_array[i * data_T::size + j] = in_pack[j]; + data_array[j] = in_pack[j]; } - } - // Find the max and compute all delta(x_i, x_max) - Op_max<typename data_T::value_type> op_max; - typename data_T::value_type x_max = reduce<typename data_T::value_type, CONFIG_T::n_in, Op_max<typename data_T::value_type>>(data_array, op_max); + // Find the max and compute all delta(x_i, x_max) + Op_max<typename data_T::value_type> op_max; + typename data_T::value_type x_max = reduce<typename data_T::value_type, data_T::size, Op_max<typename data_T::value_type>>(data_array, op_max); - // For the diffs, use the same type as the input but force rounding and saturation - ap_fixed<data_T::value_type::width, data_T::value_type::iwidth,AP_RND,AP_SAT> d_xi_xmax[CONFIG_T::n_in]; - for(unsigned i = 0; i < CONFIG_T::n_in; i++){ - #pragma HLS UNROLL - d_xi_xmax[i] = data_array[i] - x_max; - } + // For the diffs, use the same type as the input but force rounding and saturation + ap_fixed<data_T::value_type::width, data_T::value_type::iwidth,AP_RND,AP_SAT> d_xi_xmax[data_T::size]; + for(unsigned j = 0; j < data_T::size; j++){ + #pragma HLS UNROLL + d_xi_xmax[j] = data_array[j] - x_max; + } - // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; - #pragma HLS ARRAY_PARTITION variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); - for(unsigned i = 0; i < CONFIG_T::n_in; i++){ - #pragma HLS UNROLL - unsigned x = softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[i]); - exp_res[i] = exp_table[x]; - } + // Calculate all the e^x's + typename CONFIG_T::exp_table_t exp_res[data_T::size]; + #pragma 
HLS ARRAY_PARTITION variable=exp_res complete + typename CONFIG_T::exp_table_t exp_sum(0); + for(unsigned j = 0; j < data_T::size; j++){ + #pragma HLS UNROLL + unsigned x = softmax_idx_from_real_val<typename data_T::value_type, CONFIG_T>(d_xi_xmax[j]); + exp_res[j] = exp_table[x]; + } - // Explicitly sum the results with an adder tree. - // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add<typename CONFIG_T::exp_table_t> op_add; - exp_sum = reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add); + // Explicitly sum the results with an adder tree. + // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing + Op_add<typename CONFIG_T::exp_table_t> op_add; + exp_sum = reduce<typename CONFIG_T::exp_table_t, data_T::size, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add); - typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)]; - SoftmaxInvLoop: for(unsigned i = 0; i < CONFIG_T::n_in / res_T::size; i++){ - #pragma HLS PIPELINE II=ii + typename CONFIG_T::inv_table_t inv_exp_sum = invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t,CONFIG_T>(exp_sum)]; res_T out_pack; #pragma HLS DATA_PACK variable=out_pack SoftmaxInvPackLoop: for(unsigned j = 0; j < res_T::size; j++){ #pragma HLS UNROLL #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation - out_pack[i * res_T::size + j] = exp_res[i * res_T::size + j] * inv_exp_sum; + out_pack[j] = exp_res[j] * inv_exp_sum; } res.write(out_pack); } @@ -275,52 +267,48 @@ void softmax_legacy(hls::stream<data_T> &data, hls::stream<res_T> &res) { } // Index into the lookup table based on data for exponentials - typename CONFIG_T::table_t exp_res[CONFIG_T::n_in]; + typename CONFIG_T::table_t exp_res[data_T::size]; typename CONFIG_T::table_t exp_diff_res; - typename data_T::value_type data_cache[CONFIG_T::n_in]; + typename data_T::value_type data_cache[data_T::size]; - SoftmaxInitLoop: for(unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++) { + SoftmaxInitLoop: for(unsigned s = 0; s < CONFIG_T::n_in / data_T::size; s++) { #pragma HLS PIPELINE data_T in_pack = data.read(); SoftmaxInitPackLoop: for(unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - data_cache[i * data_T::size + j] = in_pack[j]; - exp_res[i * data_T::size + j] = 0; + data_cache[j] = in_pack[j]; + exp_res[j] = 0; } - } - SoftmaxExpLoop: for (int i = 0; i < CONFIG_T::n_in; i++) { - #pragma HLS PIPELINE - SoftmaxExpInner: for (int j = 0; j < CONFIG_T::n_in; j++) { + SoftmaxExpLoop: for (int i = 0; i < data_T::size; i++) { #pragma HLS UNROLL - - if (i == j) { - exp_diff_res = 1; - } else { - int data_round = (data_cache[j] - data_cache[i]) * CONFIG_T::table_size / 16; - int index = data_round + 8 * CONFIG_T::table_size / 16; - if (index < 0) index = 0; - if (index > CONFIG_T::table_size - 1) index = CONFIG_T::table_size - 1; - exp_diff_res = exp_table[index]; + SoftmaxExpInner: for (int j = 0; j < data_T::size; j++) { + #pragma HLS UNROLL + + if (i == j) { + exp_diff_res = 1; + } else { + int data_round = (data_cache[j] - data_cache[i]) * CONFIG_T::table_size / 16; + int index = data_round + 8 * CONFIG_T::table_size / 16; + if (index < 0) index = 0; + if (index > CONFIG_T::table_size - 1) index = CONFIG_T::table_size - 1; + exp_diff_res = exp_table[index]; + } + + exp_res[i] += exp_diff_res; } - - exp_res[i] += exp_diff_res; } - 
} - - SoftmaxInvLoop: for(unsigned i = 0; i < CONFIG_T::n_in / res_T::size; i++) { - #pragma HLS PIPELINE res_T out_pack; #pragma HLS DATA_PACK variable=out_pack SoftmaxInvPackLoop: for(unsigned j = 0; j < res_T::size; j++) { #pragma HLS UNROLL - - int exp_res_index = exp_res[i * res_T::size + j] * CONFIG_T::table_size / 64; + + int exp_res_index = exp_res[j] * CONFIG_T::table_size / 64; if (exp_res_index < 0) exp_res_index = 0; if (exp_res_index > CONFIG_T::table_size - 1) exp_res_index = CONFIG_T::table_size - 1; - - out_pack[i * res_T::size + j] = (typename res_T::value_type) invert_table[exp_res_index]; + + out_pack[j] = (typename res_T::value_type) invert_table[exp_res_index]; } res.write(out_pack); } @@ -328,6 +316,8 @@ void softmax_legacy(hls::stream<data_T> &data, hls::stream<res_T> &res) { template<class data_T, class res_T, typename CONFIG_T> void softmax(hls::stream<data_T> &data, hls::stream<res_T> &res){ + assert(CONFIG_T::axis == -1); + switch(CONFIG_T::implementation){ case softmax_implementation::latency: softmax_latency<data_T, res_T, CONFIG_T>(data, res); diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h index a146784550e3c7544e63d11c9c78e88bb57c9f84..a251051285202778b3435da5324cee34baa347a9 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h @@ -154,20 +154,18 @@ void im2col_2d_cl( const int col) { int index = 0; - for (int channel = CONFIG_T::n_chan; channel--; data++) { + for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { #pragma HLS UNROLL - for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { - int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height; - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height; + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { if (input_row < 0 || input_row >= CONFIG_T::in_height) { data_col[index++] = 0; } else { int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width + col * CONFIG_T::stride_width; if (input_col >= 0 && input_col < CONFIG_T::in_width) { - //*(data_col++) = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan]; - data_col[index++] = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan]; + data_col[index++] = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan + channel]; } else { - //*(data_col++) = 0; data_col[index++] = 0; } } @@ -209,7 +207,6 @@ void conv_2d_resource_cl( FiltLoop: for (int k = 0; k < CONFIG_T::n_filt; k++) { res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; - //res[k * CONFIG_T::out_height * CONFIG_T::out_width + i * CONFIG_T::out_width + j] = res_col[k]; // Transposed order } } } diff --git a/hls4ml/templates/vivado_accelerator/pynq-z2/python_drivers/axi_stream_driver.py b/hls4ml/templates/vivado_accelerator/pynq-z2/python_drivers/axi_stream_driver.py index 50104202a016ccd77cc2b785fc90b57af47be28c..4adb187ab43a02e0f22d9d06775bd9a3dddf6344 100644 --- a/hls4ml/templates/vivado_accelerator/pynq-z2/python_drivers/axi_stream_driver.py +++ 
@@ -4,32 +4,36 @@ from datetime import datetime
 import pynq.lib.dma
 import numpy as np
 
+
 class NeuralNetworkOverlay(Overlay):
-    def __init__(self, bitfile_name, dtbo=None, download=True, ignore_version=False, device=None):
-
-        super().__init__(bitfile_name, dtbo=dtbo, download=download, ignore_version=ignore_version, device=device)
-
+    def __init__(self, bitfile_name, x_shape, y_shape, dtype=np.float32, dtbo=None, download=True, ignore_version=False,
+                 device=None):
+        super().__init__(bitfile_name, dtbo=dtbo, download=download, ignore_version=ignore_version, device=device)
+        self.sendchannel = self.hier_0.axi_dma_0.sendchannel
+        self.recvchannel = self.hier_0.axi_dma_0.recvchannel
+        self.input_buffer = allocate(shape=x_shape, dtype=dtype)
+        self.output_buffer = allocate(shape=y_shape, dtype=dtype)
+
     def _print_dt(self, timea, timeb, N):
-        dt = (timeb - timea)
-        dts = dt.seconds + dt.microseconds * 10**-6
+        dt = (timeb - timea)
+        dts = dt.seconds + dt.microseconds * 10 ** -6
         rate = N / dts
         print("Classified {} samples in {} seconds ({} inferences / s)".format(N, dts, rate))
         return dts, rate
-    def predict(self, X, y_shape, dtype=np.float32, debug=None, profile=False, encode=None, decode=None):
+
+    def predict(self, X, debug=False, profile=False, encode=None, decode=None):
         """
         Obtain the predictions of the NN implemented in the FPGA.
         Parameters:
         - X : the input vector. Should be numpy ndarray.
-        - y_shape : the shape of the output vector. Needed to the accelerator to set the TLAST bit properly and
-                    for sizing the output vector shape.
-        - dtype : the data type of the elements of the input/output vectors.
-                  Note: it should be set depending on the interface of the accelerator; if it uses 'float'
-                  types for the 'data' AXI-Stream field, 'np.float32' dtype is the correct one to use.
+        - dtype : the data type of the elements of the input/output vectors.
+                  Note: it should be set depending on the interface of the accelerator; if it uses 'float'
+                  types for the 'data' AXI-Stream field, 'np.float32' dtype is the correct one to use.
                   Instead if it uses 'ap_fixed<A,B>', 'np.intA' is the correct one to use (note that A cannot
-                  any integer value, but it can assume {..., 8, 16, 32, ...} values. Check `numpy`
+                  any integer value, but it can assume {..., 8, 16, 32, ...} values. Check `numpy`
                   doc for more info).
-                  In this case the encoding/decoding has to be computed by the PS. For example for
-                  'ap_fixed<16,6>' type the following 2 functions are the correct one to use for encode/decode
+                  In this case the encoding/decoding has to be computed by the PS. For example for
+                  'ap_fixed<16,6>' type the following 2 functions are the correct one to use for encode/decode
                   'float' -> 'ap_fixed<16,6>':
                   ```
                   def encode(xi):
@@ -48,24 +52,24 @@ class NeuralNetworkOverlay(Overlay):
         timea = datetime.now()
         if encode is not None:
             X = encode(X)
-        with allocate(shape=X.shape, dtype=dtype) as input_buffer, \
-             allocate(shape=y_shape, dtype=dtype) as output_buffer:
-            input_buffer[:] = X
-            self.hier_0.axi_dma_0.sendchannel.transfer(input_buffer)
-            self.hier_0.axi_dma_0.recvchannel.transfer(output_buffer)
-            if debug:
-                print("Transfer OK")
-            self.hier_0.axi_dma_0.sendchannel.wait()
-            if debug:
-                print("Send OK")
-            self.hier_0.axi_dma_0.recvchannel.wait()
-            if debug:
-                print("Receive OK")
-            result = output_buffer.copy()
+        self.input_buffer[:] = X
+        self.sendchannel.transfer(self.input_buffer)
+        self.recvchannel.transfer(self.output_buffer)
+        if debug:
+            print("Transfer OK")
+        self.sendchannel.wait()
+        if debug:
+            print("Send OK")
+        self.recvchannel.wait()
+        if debug:
+            print("Receive OK")
+        # result = self.output_buffer.copy()
         if decode is not None:
-            result = decode(result)
+            self.output_buffer = decode(self.output_buffer)
+
         if profile:
             timeb = datetime.now()
             dts, rate = self._print_dt(timea, timeb, len(X))
-            return result, dts, rate
-        return result
\ No newline at end of file
+            return self.output_buffer, dts, rate
+        else:
+            return self.output_buffer
\ No newline at end of file
diff --git a/hls4ml/templates/vivado_accelerator/zcu102/python_drivers/axi_stream_driver.py b/hls4ml/templates/vivado_accelerator/zcu102/python_drivers/axi_stream_driver.py
index 50104202a016ccd77cc2b785fc90b57af47be28c..4adb187ab43a02e0f22d9d06775bd9a3dddf6344 100644
--- a/hls4ml/templates/vivado_accelerator/zcu102/python_drivers/axi_stream_driver.py
+++ b/hls4ml/templates/vivado_accelerator/zcu102/python_drivers/axi_stream_driver.py
@@ -4,32 +4,36 @@ from datetime import datetime
 import pynq.lib.dma
 import numpy as np
 
+
 class NeuralNetworkOverlay(Overlay):
-    def __init__(self, bitfile_name, dtbo=None, download=True, ignore_version=False, device=None):
-
-        super().__init__(bitfile_name, dtbo=dtbo, download=download, ignore_version=ignore_version, device=device)
-
+    def __init__(self, bitfile_name, x_shape, y_shape, dtype=np.float32, dtbo=None, download=True, ignore_version=False,
+                 device=None):
+        super().__init__(bitfile_name, dtbo=dtbo, download=download, ignore_version=ignore_version, device=device)
+        self.sendchannel = self.hier_0.axi_dma_0.sendchannel
+        self.recvchannel = self.hier_0.axi_dma_0.recvchannel
+        self.input_buffer = allocate(shape=x_shape, dtype=dtype)
+        self.output_buffer = allocate(shape=y_shape, dtype=dtype)
+
     def _print_dt(self, timea, timeb, N):
-        dt = (timeb - timea)
-        dts = dt.seconds + dt.microseconds * 10**-6
+        dt = (timeb - timea)
+        dts = dt.seconds + dt.microseconds * 10 ** -6
         rate = N / dts
         print("Classified {} samples in {} seconds ({} inferences / s)".format(N, dts, rate))
         return dts, rate
-    def predict(self, X, y_shape, dtype=np.float32, debug=None, profile=False, encode=None, decode=None):
+
+    def predict(self, X, debug=False, profile=False, encode=None, decode=None):
         """
         Obtain the predictions of the NN implemented in the FPGA.
         Parameters:
         - X : the input vector. Should be numpy ndarray.
-        - y_shape : the shape of the output vector. Needed to the accelerator to set the TLAST bit properly and
-                    for sizing the output vector shape.
-        - dtype : the data type of the elements of the input/output vectors.
-                  Note: it should be set depending on the interface of the accelerator; if it uses 'float'
-                  types for the 'data' AXI-Stream field, 'np.float32' dtype is the correct one to use.
+        - dtype : the data type of the elements of the input/output vectors.
+                  Note: it should be set depending on the interface of the accelerator; if it uses 'float'
+                  types for the 'data' AXI-Stream field, 'np.float32' dtype is the correct one to use.
                   Instead if it uses 'ap_fixed<A,B>', 'np.intA' is the correct one to use (note that A cannot
-                  any integer value, but it can assume {..., 8, 16, 32, ...} values. Check `numpy`
+                  any integer value, but it can assume {..., 8, 16, 32, ...} values. Check `numpy`
                   doc for more info).
-                  In this case the encoding/decoding has to be computed by the PS. For example for
-                  'ap_fixed<16,6>' type the following 2 functions are the correct one to use for encode/decode
+                  In this case the encoding/decoding has to be computed by the PS. For example for
+                  'ap_fixed<16,6>' type the following 2 functions are the correct one to use for encode/decode
                   'float' -> 'ap_fixed<16,6>':
                   ```
                   def encode(xi):
@@ -48,24 +52,24 @@ class NeuralNetworkOverlay(Overlay):
         timea = datetime.now()
         if encode is not None:
             X = encode(X)
-        with allocate(shape=X.shape, dtype=dtype) as input_buffer, \
-             allocate(shape=y_shape, dtype=dtype) as output_buffer:
-            input_buffer[:] = X
-            self.hier_0.axi_dma_0.sendchannel.transfer(input_buffer)
-            self.hier_0.axi_dma_0.recvchannel.transfer(output_buffer)
-            if debug:
-                print("Transfer OK")
-            self.hier_0.axi_dma_0.sendchannel.wait()
-            if debug:
-                print("Send OK")
-            self.hier_0.axi_dma_0.recvchannel.wait()
-            if debug:
-                print("Receive OK")
-            result = output_buffer.copy()
+        self.input_buffer[:] = X
+        self.sendchannel.transfer(self.input_buffer)
+        self.recvchannel.transfer(self.output_buffer)
+        if debug:
+            print("Transfer OK")
+        self.sendchannel.wait()
+        if debug:
+            print("Send OK")
+        self.recvchannel.wait()
+        if debug:
+            print("Receive OK")
+        # result = self.output_buffer.copy()
         if decode is not None:
-            result = decode(result)
+            self.output_buffer = decode(self.output_buffer)
+
         if profile:
             timeb = datetime.now()
             dts, rate = self._print_dt(timea, timeb, len(X))
-            return result, dts, rate
-        return result
\ No newline at end of file
+            return self.output_buffer, dts, rate
+        else:
+            return self.output_buffer
\ No newline at end of file
diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py
index 5b1c3b03c89f98ac9697e276a399f7add52b0c86..149b52f1d5b4291d3fedc96ef6724d47717792a3 100644
--- a/hls4ml/templates/vivado_template.py
+++ b/hls4ml/templates/vivado_template.py
@@ -121,6 +121,7 @@ softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
     static const unsigned table_size = {table_size};
     static const unsigned io_type = nnet::{iotype};
     static const unsigned reuse_factor = {reuse};
+    static const unsigned axis = {axis};
     static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
     typedef {exp_table_t} exp_table_t;
     typedef {inv_table_t} inv_table_t;
diff --git a/hls4ml/utils/example_models.py b/hls4ml/utils/example_models.py
index 10fd5f090e1a91d23abdd1b8023a65b1b12a3a32..fdc405ca3a7ae992f11e549bede65370734afe7d 100644
--- a/hls4ml/utils/example_models.py
+++ b/hls4ml/utils/example_models.py
@@ -88,7 +88,7 @@ def _load_example_config(model_name):
 
     #Load the configuration from local yml file
     with open(config_name, 'r') as ymlfile:
-        config = yaml.load(ymlfile)
+        config = yaml.safe_load(ymlfile)
 
     return config
 
diff --git a/test/pytest/test_cnn_mnist.py b/test/pytest/test_cnn_mnist.py
index 274f25a547c78ec5ad5e4e5529b3ce0ec6994152..2397def2a9a59fe01d3e9b1a499cb0962198e39c 100644
--- a/test/pytest/test_cnn_mnist.py
+++ b/test/pytest/test_cnn_mnist.py
@@ -28,16 +28,15 @@ def mnist_model():
     model.load_weights('../../example-models/keras/qkeras_mnist_cnn_weights.h5')
     return model
 
-# TODO: add ('io_parallel', 'resource') when it can pass
-# https://github.com/fastmachinelearning/hls4ml/issues/375
 @pytest.fixture
 @pytest.mark.parametrize('settings', [('io_parallel', 'latency'),
+                                      ('io_parallel', 'resource'),
                                       ('io_stream', 'latency'),
                                       ('io_stream', 'resource')])
 def hls_model(settings):
     io_type = settings[0]
     strategy = settings[1]
-    config = yaml.load(open('../../example-models/config-files/qkeras_mnist_cnn_config.yml').read())
+    config = yaml.safe_load(open('../../example-models/config-files/qkeras_mnist_cnn_config.yml').read())
     config['KerasJson'] = '../../example-models/keras/qkeras_mnist_cnn.json'
     config['KerasH5'] = '../../example-models/keras/qkeras_mnist_cnn_weights.h5'
     config['OutputDir'] = 'hls4mlprj_cnn_mnist_{}_{}'.format(io_type, strategy)
@@ -49,10 +48,12 @@ def hls_model(settings):
     return hls_model
 
 @pytest.mark.parametrize('settings', [('io_parallel', 'latency'),
+                                      ('io_parallel', 'resource'),
                                       ('io_stream', 'latency'),
                                       ('io_stream', 'resource')])
 def test_accuracy(mnist_data, mnist_model, hls_model):
     x_train, y_train, x_test, y_test = mnist_data
+    x_test, y_test = x_test[:5000], y_test[:5000]
     model = mnist_model
     # model under test predictions and accuracy
     y_keras = model.predict(x_test)
diff --git a/test/pytest/test_graph.py b/test/pytest/test_graph.py
index 1d0dcca9c5bcda3d48f6b85e7464fa168611ecb2..8f9e62f4b83428c1a8f8f832948c94745bc0d55e 100644
--- a/test/pytest/test_graph.py
+++ b/test/pytest/test_graph.py
@@ -1,6 +1,7 @@
 import hls4ml
 import numpy as np
 import pytest
+import tensorflow as tf
 
 class Reader:
     def get_weights_data(self, name, var):
@@ -94,15 +95,47 @@ def test_graph_manipulation(parameters, iotype):
     np.testing.assert_array_equal(expected_layers, actual_layers)
 
 @pytest.mark.parametrize('iotype', ['io_parallel', 'io_stream'])
-def test_graph_branch(iotype):
-    odir = 'hls4mlprj_graph_branch_model'
+@pytest.mark.parametrize('batch', [1, 100])
+def test_graph_branch(iotype, batch):
+    odir = 'hls4mlprj_graph_branch_model_{}_batch{}'.format(iotype, batch)
     model = branch_model(odir, iotype)
     original_layers = np.array([layer.name for layer in list(model.get_layers())])
     model.compile()
     hls4ml.utils.plot_model(model, show_shapes=True, show_precision=True, to_file='{}/model.png'.format(odir))
-    X0 = np.random.rand(1,1)
-    X1 = np.random.rand(1,1)
+    X0 = np.random.rand(batch, 1)
+    X1 = np.random.rand(batch, 1)
     y_expected = 2*(X0+X1)
     y = model.predict([X0, X1]).reshape(y_expected.shape)
     # check the output
     np.testing.assert_allclose(y, y_expected, rtol=1, atol=2**-16)
+
+@pytest.mark.parametrize('iotype', ['io_parallel', 'io_stream'])
+def test_final_reshape(iotype):
+    ''' Test case for a model with a Reshape as the final layer '''
+    inputs = tf.keras.layers.Input(shape=(1,1,1)) # 1 input pixel
+    conv = tf.keras.layers.Conv2D(6,1) # 6 filters, 1x1 kernel
+    x = conv(inputs)
+    conv.set_weights([np.linspace(1,6,6).reshape(1,1,1,6), np.zeros(6)]) # ascending int weights, 0 bias
+    x = tf.keras.layers.Reshape((3,2))(x) # reshape the (1,1,6) output to (3,2)
+    model = tf.keras.models.Model(inputs=inputs, outputs=x)
+
+    # create the HLSModel
+    config = hls4ml.utils.config_from_keras_model(model, granularity='model')
+    hls_model = hls4ml.converters.convert_from_keras_model(model,
+                                                           output_dir=f'hls4mlprj_graph_final_reshape_{iotype}',
+                                                           backend='Vivado',
+                                                           io_type=iotype,
+                                                           hls_config=config)
+    hls_model.compile()
+
+    # Test on ascending integers. The weights mean that each output pixel/neuron has
+    # a different value
+    X = np.linspace(-4,4,9).reshape(9,1,1,1)
+    y = model.predict(X)
+    y_hls = hls_model.predict(X).reshape(y.shape)
+    # because of integer inputs and integer weights, we can expect exact matching
+    np.testing.assert_allclose(y, y_hls, rtol=0)
+
+
+
+
diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py
index 5263a55ba4b3a3524f788ee841f1640ce1c76944..ef98902cc99d76d98f9ccd3f84894230d4d53179 100644
--- a/test/pytest/test_softmax.py
+++ b/test/pytest/test_softmax.py
@@ -4,42 +4,51 @@ import numpy as np
 import pytest
 from sklearn.metrics import accuracy_score
 
-def flat_distribution(N, M):
-    return np.random.rand(N, M)
 
-def high_accuracy_distribution(N, M):
-    '''Start with a flat distribution, then pick a random member of each row to amplify'''
-    x = np.random.rand(N, M)
-    imax = np.random.randint(0,M,size=N)
-    x[:,imax] *= 10
-    return x
+def flat_distribution(shape):
+    return np.random.rand(*shape)
+
+
+def high_accuracy_distribution(shape):
+    '''Start with a flat distribution, then pick a random member of each row to amplify'''
+    x = np.random.rand(*shape)
+    imax = np.random.randint(0, shape[1], size=shape[0])
+    x[:, imax] *= 10
+    return x
+
 
 @pytest.fixture()
-def generate_data(function):
-    return function(1000,8)
+def generate_data(function, input_shape):
+    return function((1000, *input_shape))
+
 
 # TODO: include latency strategy with flat_distribution when it can be made to pass
-#@pytest.mark.parametrize('strategy,function', [('latency', flat_distribution),
-#                                               ('stable', flat_distribution),
-#                                               ('stable', high_accuracy_distribution)])
-@pytest.mark.parametrize('strategy,function', [('stable', flat_distribution),
-                                               ('stable', high_accuracy_distribution)])
-def test_softmax(strategy, generate_data):
-    X = generate_data
-    model = tf.keras.models.Sequential()
-    model.add(tf.keras.layers.Activation(input_shape=(8,), activation='softmax', name='softmax'))
-    model.compile()
-    cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
-    cfg['LayerName']['softmax']['Strategy'] = strategy
-    cfg['LayerName']['softmax']['inv_table_t'] = 'ap_fixed<18,8,AP_RND,AP_SAT>'
-    cfg['LayerName']['softmax']['exp_table_t'] = 'ap_fixed<18,8,AP_RND,AP_SAT>'
-    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=cfg, output_dir='hls4mlprj_softmax_{}'.format(strategy))
-    hls_model.compile()
-    y_keras = model.predict(X)
-    y_hls4ml = hls_model.predict(X)
-
-    acc_hls4ml = accuracy_score(np.argmax(y_keras, axis=1), np.argmax(y_hls4ml, axis=1))
-
-    print('Accuracy hls4ml relative to keras: {}'.format(acc_hls4ml))
-
-    assert acc_hls4ml >= 0.98
+@pytest.mark.parametrize('strategy,function,input_shape,io_type', [#('latency', flat_distribution, (8,), 'io_parallel'),
+                                                                   #('latency', flat_distribution, (8, 8, 3), 'io_stream'),
+                                                                   ('stable', flat_distribution, (8,), 'io_parallel'),
+                                                                   ('stable', high_accuracy_distribution, (8,), 'io_parallel'),
+                                                                   ('stable', flat_distribution, (8,), 'io_stream'),
+                                                                   ('stable', high_accuracy_distribution, (8,), 'io_stream'),
+                                                                   # Multi-dimensional tests, only for io_stream for now
+                                                                   ('stable', flat_distribution, (8, 8, 3), 'io_stream'),
+                                                                   ('stable', high_accuracy_distribution, (8, 8, 3), 'io_stream')])
+def test_softmax(strategy, generate_data, input_shape, io_type):
+    X = generate_data
+    model = tf.keras.models.Sequential()
+    model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax'))
+    model.compile()
+    cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
+    cfg['LayerName']['softmax']['Strategy'] = strategy
+    cfg['LayerName']['softmax']['inv_table_t'] = 'ap_fixed<18,8,AP_RND,AP_SAT>'
+    cfg['LayerName']['softmax']['exp_table_t'] = 'ap_fixed<18,8,AP_RND,AP_SAT>'
+    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=cfg, io_type=io_type,
+                                                           output_dir='hls4mlprj_softmax_{}'.format(strategy))
+    hls_model.compile()
+    y_keras = model.predict(X)
+    y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
+
+    acc_hls4ml = accuracy_score(np.argmax(y_keras, axis=-1).ravel(), np.argmax(y_hls4ml, axis=-1).ravel())
+
+    print('Accuracy hls4ml relative to keras: {}'.format(acc_hls4ml))
+
+    assert acc_hls4ml >= 0.98
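
For reviewers: the patch changes the public API of the accelerator drivers, so a minimal usage sketch of the revised `NeuralNetworkOverlay` may help. Buffer shapes are now fixed at construction, and `predict` no longer takes `y_shape`/`dtype`. The bitfile name, I/O shapes, and sample count below are hypothetical; the constructor and `predict` signatures are taken from this diff:

```python
import numpy as np
from axi_stream_driver import NeuralNetworkOverlay

# Buffers are allocated once at construction for the given I/O shapes.
nn = NeuralNetworkOverlay('hls4ml_nn.bit', x_shape=(1000, 16), y_shape=(1000, 5))

X = np.random.rand(1000, 16).astype(np.float32)
# With profile=True, predict also returns the elapsed time and inference rate.
y, dts, rate = nn.predict(X, profile=True)
# With the default profile=False, only the output buffer is returned.
y = nn.predict(X)
```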