Multidimensional input support for Dense layer, Resource vs resource in Conv2D layer, lower ReuseFactor results in near-zero resource usage
Created by: vandenBergArthur
TL;DR at the bottom
Hi all, as mentioned in my other post #747, I am trying to implement a graph convolution. For that I need a matrix multiplication A * B = C, where A is my input tensor and B is an adjacency matrix. To realize this, I have created two alternatives that use only Keras layers supported by hls4ml, so that I can deploy the model with hls4ml. (We are also trying to implement the whole model with the extension API.)
Alternative 1
In the first alternative, I simply use Dense layers to mimic the matrix multiplication.
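import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Permute
from tensorflow.keras.models import Model

in_channels = 32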
out_channels = 32
nodes = 25
input_x = Input(shape=(nodes,in_channels), name = 'input_x')
# 1x1 Convolution of incoming frame
dense1 = Dense(units=out_channels, use_bias=True, name='dense1')(input_x)
# Swap the nodes and out_channels dimensions to set up the correct matmul operation
perm1 = Permute((2,1), name='perm1')(dense1)
# Use Dense layer to perform matrix multiplication with the adjacency matrix
# Units = number of columns of adj matrix
dense2 = Dense(units=nodes, use_bias=False, kernel_initializer=tf.keras.initializers.Constant(adj1),name='dense2')(perm1)
model = Model(inputs=input_x, outputs=dense2)
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_x (InputLayer)        [(None, 25, 32)]          0
 dense1 (Dense)              (None, 25, 96)            3168
 perm1 (Permute)             (None, 96, 25)            0
 dense2 (Dense)              (None, 96, 25)            625
=================================================================
Total params: 3,793
Trainable params: 3,793
Non-trainable params: 0
where adj1 is the adjacency matrix, generated as follows:
import numpy as np
import tensorflow as tf

# Create a 25x25 tensor of uniform random values in [0, 1)
adj1 = tf.random.uniform(shape=(25, 25), minval=0, maxval=1)
# Round the values to the nearest integer (0 or 1)
adj1 = tf.math.round(adj1)
# Set the diagonal elements to zero (no self-loops)
adj1 = tf.linalg.set_diag(adj1, tf.zeros(25))
# Convert to NumPy
adj1 = adj1.numpy()
# Make the matrix symmetric: keep the upper triangle and mirror its strict upper part below the diagonal
adj1 = np.triu(adj1) + np.triu(adj1, 1).T
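As a sanity check of what Alternative 1 computes, here is a minimal NumPy sketch with plain arrays standing in for the Keras layers (W, b, x and h are hypothetical names, not taken from the model). dense1 computes H = X W + b for each of the 25 rows, perm1 gives H transposed (out_channels x nodes), and dense2, whose kernel is the adjacency matrix, multiplies that by adj1. Because adj1 is symmetric, the result equals (adj1 H) transposed, i.e. the graph convolution with the two axes swapped.

```python
import numpy as np

rng = np.random.default_rng(0)
nodes, in_channels, out_channels = 25, 32, 32

# Adjacency matrix built with the same recipe as adj1, but in NumPy only:
# random 0/1 entries, zero diagonal, strict upper triangle mirrored below.
adj = np.round(rng.uniform(0.0, 1.0, size=(nodes, nodes)))
np.fill_diagonal(adj, 0.0)
adj = np.triu(adj) + np.triu(adj, 1).T

# Hypothetical weights standing in for dense1's kernel and bias.
W = rng.standard_normal((in_channels, out_channels))
b = rng.standard_normal(out_channels)
x = rng.standard_normal((nodes, in_channels))  # one sample, batch axis dropped

h = x @ W + b        # dense1: same weights applied to each of the 25 rows
h_perm = h.T         # perm1: shape (out_channels, nodes)
y = h_perm @ adj     # dense2: kernel = adj, no bias -> (out_channels, nodes)

print(y.shape)                        # (32, 25)
print(np.allclose(adj, adj.T))        # True: adj is symmetric
print(np.allclose(y, (adj @ h).T))    # True: output equals (A H) transposed
```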
As a starting configuration, I used the default precision and a ReuseFactor (RF) of 64, as in tutorial 7, where a model with Dense layers is deployed to the PYNQ-Z2 board:
import hls4ml

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
config['Model']['Strategy'] = 'Resource'
for layer in config['LayerName'].keys():
    config['LayerName'][layer]['Strategy'] = 'Resource'
    config['LayerName'][layer]['ReuseFactor'] = 64
print("-----------------------------------")
plotting.print_dict(config)
print("-----------------------------------")
hls_model = hls4ml.converters.convert_from_keras_model(model,
                                                       hls_config=config,
                                                       output_dir='/home/arthur/Documents/Testing/configs/ourModel_32_resource/',
                                                       backend='VivadoAccelerator',
                                                       board='pynq-z2')
hls_model.compile()
However, when I build this model with hls_model.build(csim=False, export=True), I get a rather odd output:
================================================================
== Utilization Estimates
================================================================
* Summary:
+-----------------+---------+-------+--------+-------+-----+
|       Name      | BRAM_18K| DSP48E|   FF   |  LUT  | URAM|
+-----------------+---------+-------+--------+-------+-----+
|DSP              |        -|      -|       -|      -|    -|
|Expression       |        -|      -|      40|    661|    -|
|FIFO             |        -|      -|       -|      -|    -|
|Instance         |        -|      -|    9118|   2949|    -|
|Memory           |        -|      -|       -|      -|    -|
|Multiplexer      |        -|      -|       -|    105|    -|
|Register         |        0|      -|     969|    192|    -|
+-----------------+---------+-------+--------+-------+-----+
|Total            |        0|      0|   10127|   3907|    0|
+-----------------+---------+-------+--------+-------+-----+
|Available        |      280|    220|  106400|  53200|    0|
+-----------------+---------+-------+--------+-------+-----+
|Utilization (%)  |        0|      0|       9|      7|    0|
+-----------------+---------+-------+--------+-------+-----+
I compared these results with those of the untrained model from the tutorial, and my resource usage is suspiciously low.
Model from tutorial:
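from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation
from qkeras.qlayers import QDense, QActivation
from qkeras.quantizers import quantized_bits, quantized_relu

model = Sequential()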
model.add(QDense(64, input_shape=(16,), name='fc1',
kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),
kernel_initializer='glorot_uniform'))
model.add(QActivation(activation=quantized_relu(6), name='relu1'))
model.add(QDense(32, name='fc2',
kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),
kernel_initializer='glorot_uniform'))
model.add(QActivation(activation=quantized_relu(6), name='relu2'))
model.add(QDense(32, name='fc3',
kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),
kernel_initializer='glorot_uniform'))
model.add(QActivation(activation=quantized_relu(6), name='relu3'))
model.add(QDense(5, name='output',
kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),
kernel_initializer='glorot_uniform'))
model.add(Activation(activation='softmax', name='softmax'))
model.summary()
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 fc1 (QDense)                (None, 64)                1088
 relu1 (QActivation)         (None, 64)                0
 fc2 (QDense)                (None, 32)                2080
 relu2 (QActivation)         (None, 32)                0
 fc3 (QDense)                (None, 32)                1056
 relu3 (QActivation)         (None, 32)                0
 output (QDense)             (None, 5)                 165
 softmax (Activation)        (None, 5)                 0
=================================================================
Total params: 4,389
Trainable params: 4,389
Non-trainable params: 0
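To put "suspiciously low" into numbers, a rough count of multiplications per inference (ignoring biases and activations) can help. This is a back-of-the-envelope sketch using nodes = 25 and in_channels = out_channels = 32 from the Alternative 1 code; dense_mults is a hypothetical helper, and the counts say nothing about how hls4ml actually schedules the work. A Dense layer on a 3D input performs one n_in x n_out product per row, so Alternative 1 needs roughly ten times more multiplications than the tutorial model.

```python
def dense_mults(rows, n_in, n_out):
    """Multiplications for a Dense layer applied to `rows` vectors of length n_in."""
    return rows * n_in * n_out

# Alternative 1 (values from the code above)
nodes, in_ch, out_ch = 25, 32, 32
alt1 = (dense_mults(nodes, in_ch, out_ch)      # dense1: one 32->32 product per node
        + dense_mults(out_ch, nodes, nodes))   # dense2: one 25->25 product per channel

# Tutorial model: a single 16-element input vector
tutorial = (dense_mults(1, 16, 64) + dense_mults(1, 64, 32)
            + dense_mults(1, 32, 32) + dense_mults(1, 32, 5))

print(alt1)      # 45600
print(tutorial)  # 4256
```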
I suspect the 3D input to the Dense layer is the cause: when I tested a similar model with a 2D input shape, the resource usage looked much more reasonable.
So, does hls4ml not support Dense layers with 3D inputs? Keras itself allows them; the Dense documentation describes the input as:
"N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim)."
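For reference, the Keras documentation describes Dense on an input of rank greater than 2 as a dot product along the last axis of the input and axis 0 of the kernel. The NumPy sketch below (illustrative names, not Keras or hls4ml code) shows that this is the same as one 2D Dense applied to every row of a flattened (batch * nodes, in_channels) matrix, which is what a 3D-input Dense layer has to reproduce in hardware, 25 times per sample here.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, nodes, in_channels, out_channels = 2, 25, 32, 32

kernel = rng.standard_normal((in_channels, out_channels))  # Dense kernel: (input_dim, units)
bias = rng.standard_normal(out_channels)
x = rng.standard_normal((batch, nodes, in_channels))       # 3D input (batch, nodes, input_dim)

# Dense on an N-D input: contract the last axis of x with axis 0 of the kernel.
y_nd = np.tensordot(x, kernel, axes=1) + bias

# Same result as a plain 2D Dense: flatten leading axes into rows, matmul, reshape back.
y_2d = (x.reshape(-1, in_channels) @ kernel + bias).reshape(batch, nodes, out_channels)

print(y_nd.shape)               # (2, 25, 32)
print(np.allclose(y_nd, y_2d))  # True
```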
Alternative 2
In the second alternative, I use Conv2D layers to mimic the matmul operation. The model looks like this:
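import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Reshape, Permute
from tensorflow.keras.models import Model

in_channels = 32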
out_channels = 32
nodes = 25
a = Input(shape=(1,nodes,in_channels), name = 'input_x')
b = Conv2D(filters=out_channels, kernel_size=1, strides=1, padding='valid', data_format='channels_last',
use_bias=True, name ='conv2d_1x1')(a)
b = Reshape(target_shape=(nodes,out_channels), name = 'reshape1')(b)
c = Permute((2,1), name = 'permute1')(b)
c = Reshape(target_shape=(out_channels,nodes,1), name = 'reshape2')(c)
d = Conv2D(filters=nodes, kernel_size=(1,nodes), strides=1, padding='valid', data_format='channels_last',
use_bias=False, kernel_initializer=tf.keras.initializers.Constant(adj1), name = 'matmul')(c)
model = Model(inputs=a, outputs=d)
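The 'matmul' Conv2D behaves like a matrix multiplication because a (1, nodes) kernel with 'valid' padding covers each full row of the (out_channels, nodes, 1) input in a single position, so filter f computes the dot product of that row with column f of the kernel. A minimal NumPy sketch of this argument, with an explicit loop standing in for Conv2D; it assumes the Constant initializer fills the (1, nodes, 1, nodes) kernel with adj1 in row-major order, and adj and c are hypothetical stand-ins for adj1 and the reshape2 output.

```python
import numpy as np

rng = np.random.default_rng(2)
out_channels, nodes = 32, 25

adj = rng.integers(0, 2, size=(nodes, nodes)).astype(float)  # stand-in adjacency
c = rng.standard_normal((out_channels, nodes, 1))             # reshape2 output, one sample

# Keras Conv2D kernel layout: (kernel_h, kernel_w, in_channels, filters).
# Assumption: the (nodes, nodes) matrix is laid out row-major into (1, nodes, 1, nodes).
kernel = adj.reshape(1, nodes, 1, nodes)

# 'valid' cross-correlation with a (1, nodes) kernel on a (out_channels, nodes) image:
# output height = out_channels, output width = nodes - nodes + 1 = 1.
out = np.zeros((out_channels, 1, nodes))
for r in range(out_channels):              # slide over rows (kernel height 1)
    patch = c[r:r + 1, 0:nodes, :]         # shape (1, nodes, 1)
    for f in range(nodes):                 # one output channel per filter
        out[r, 0, f] = np.sum(patch * kernel[:, :, :, f])

print(out.shape)                                    # (32, 1, 25)
print(np.allclose(out[:, 0, :], c[:, :, 0] @ adj))  # True
```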
If I configure this model with the same settings as in Alternative 1, I get this error:
In file included from firmware/myproject.cpp:4:
firmware/parameters.h:28:44: error: ‘Resource’ is not a member of ‘nnet’; did you mean ‘resource’?
28 | static const unsigned strategy = nnet::Resource;
| ^~~~~~~~
| resource
firmware/parameters.h:59:44: error: ‘Resource’ is not a member of ‘nnet’; did you mean ‘resource’?
59 | static const unsigned strategy = nnet::Resource;
| ^~~~~~~~
| resource
g++: error: myproject.o: No such file or directory
Changing Resource to resource does fix the problem, but I thought it was worth pointing out.
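One workaround consistent with that observation is to lowercase every Strategy entry in the config dict before calling convert_from_keras_model. A minimal sketch, assuming the lowercase 'resource' spelling is accepted by all layer types in the hls4ml version used here; normalize_strategy is a hypothetical helper, not part of hls4ml, and the example dict is a trimmed-down config shaped like the one built above.

```python
def normalize_strategy(config):
    """Lowercase every 'Strategy' entry in an hls4ml-style config dict (in place).

    Hypothetical helper: assumes the templates accept 'resource'/'latency'.
    """
    model_cfg = config.get('Model', {})
    if 'Strategy' in model_cfg:
        model_cfg['Strategy'] = model_cfg['Strategy'].lower()
    for layer_cfg in config.get('LayerName', {}).values():
        if isinstance(layer_cfg, dict) and 'Strategy' in layer_cfg:
            layer_cfg['Strategy'] = layer_cfg['Strategy'].lower()
    return config

# Example on a trimmed-down config with the layer names of Alternative 2
config = {
    'Model': {'Strategy': 'Resource'},
    'LayerName': {
        'conv2d_1x1': {'Strategy': 'Resource', 'ReuseFactor': 64},
        'matmul': {'Strategy': 'Resource', 'ReuseFactor': 64},
    },
}
normalize_strategy(config)
print(config['LayerName']['matmul']['Strategy'])  # resource
```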
Then, in the Utilization Estimates, I noticed that the DSP usage was rather low, so I tried decreasing the ReuseFactor to 25 (instead of 64). But again, the resource usage was suspiciously low, as in Alternative 1.
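For intuition about which ReuseFactor values each layer accepts, here is a sketch of the validity rule as I understand hls4ml's Vivado backend to apply it to Dense and Conv layers, with n_in the inputs per multiplication block (kernel size times input channels for a convolution) and n_out the outputs. Treat the exact rule as an assumption and compare it with the backend source of your hls4ml version; my understanding is that hls4ml warns and substitutes a nearby valid value when the requested one is invalid. The function names are hypothetical.

```python
import math

def is_valid_reuse_factor(n_in, n_out, rf):
    """Assumed hls4ml-style check that rf evenly splits the n_in * n_out multiplications."""
    mult_factor = min(n_in, rf)
    multiplier_limit = int(math.ceil((n_in * n_out) / float(mult_factor)))
    ok = (multiplier_limit % n_out == 0) or (rf >= n_in)
    ok = ok and ((rf % n_in == 0) or (rf < n_in))
    ok = ok and ((n_in * n_out) % rf == 0)
    return ok

def valid_reuse_factors(n_in, n_out):
    return [rf for rf in range(1, n_in * n_out + 1) if is_valid_reuse_factor(n_in, n_out, rf)]

# conv2d_1x1 / dense1: 32 inputs -> 32 outputs; 'matmul' conv: 1*25*1 = 25 inputs -> 25 filters
print(valid_reuse_factors(32, 32))  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
print(valid_reuse_factors(25, 25))  # [1, 5, 25, 125, 625]
```

Under this assumed rule, none of 25, 32 or 64 is valid for both the 32-input and the 25-input layers, so setting the ReuseFactor per layer may be worth trying.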
TL;DR
- Is a 3D input to a Dense layer unsupported by hls4ml?
- When using Conv2D layers with the Strategy set to 'Resource', compilation fails with the error shown above (‘Resource’ is not a member of ‘nnet’; did you mean ‘resource’?).
- Changing the ReuseFactor from 64 to 25 results in almost zero resource utilization, while changing it to 32 works fine. Do the Latency and Resource strategies accept different reuse factors?
I attached a Jupyter notebook that contains all the code: github_issue.zip