Improved parsing of PyTorch models using torch.fx
Created by: JanFSchulte
Current parsing of PyTorch models loops over the named_modules of the model (https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/converters/pytorch_to_hls.py#L163). This has several disadvantages:
- Captures only layers that are defined as members of the model class
- Can't infer the correct order of layers
- Ignores other operations that are part of the forward() method of the model
In this PR, we propose to fix this by first creating a graph representation of the model's forward() function using the symbolic tracing functionality of https://pytorch.org/docs/stable/fx.html. Each operation in the forward() is represented by a node in the graph. Nodes can be of these types:
- placeholder (an input to the model)
- get_attr (retrieval of a parameter or attribute)
- call_module (a call to a torch.nn module, e.g. self.conv)
- call_function (a call to a free function, e.g. torch.relu)
- call_method (a method call on a value, e.g. x.view(-1))
- output (the return value of forward())
For example, for this model
```python
class MyModuleConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y
```
the resulting graph representation is
```
graph():
    %x : [#users=1] = placeholder[target=x]
    %conv : [#users=2] = call_module[target=conv](args = (%x,), kwargs = {})
    %relu : [#users=1] = call_function[target=torch.relu](args = (%conv,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%relu, %conv), kwargs = {})
    %relu_1 : [#users=1] = call_function[target=torch.relu](args = (%add,), kwargs = {})
    return relu_1
```
As the nodes in the graph follow the order of operations of the forward() function, we can simply loop over them and parse each node into one node in the hls4ml model representation. For the parsing of the individual layers, existing code is reused where available, without significant changes. This PR also adds functionality for more types of layers.
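For illustration, here is a minimal sketch of the tracing step, using the MyModuleConvRelu example above (`symbolic_trace` is the standard torch.fx entry point):

```python
import torch
import torch.fx

# Trace the forward() of the example model defined above into a graph.
model = MyModuleConvRelu()
traced = torch.fx.symbolic_trace(model)
print(traced.graph)  # prints the graph representation shown above

# The nodes come out in execution order, so the converter can walk them
# one by one and map each to a node in the hls4ml model representation.
for node in traced.graph.nodes:
    print(node.op, node.target)
```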
The types of layers currently understood by the parser are:
- Linear
- Softmax
- Relu
- LeakyReLU
- ThresholdedReLU
- ELU
- PReLU
- Sigmoid
- BatchNorm2d
- BatchNorm1d
- batch_norm
- MaxPool1d
- MaxPool2d
- AvgPool1d
- AvgPool2d
- Add
- Subtract
- Multiply
- Average
- Maximum
- Minimum
- Concatenate
- Dot
- Conv1d
- Conv2d
- View
- Dropout
- Flatten
- Sequential
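As an illustration of the intended workflow, here is a hedged sketch of converting a small model built from these layers. The exact signatures of `config_from_pytorch_model` and `convert_from_pytorch_model` may differ between hls4ml versions, so treat the arguments below as assumptions rather than the definitive API:

```python
import torch
import hls4ml

# Illustrative sketch only: a small model using layers from the list above.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

# Argument names here are assumptions; check your hls4ml version.
config = hls4ml.utils.config_from_pytorch_model(model)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    input_shape=[None, 16],  # batch dimension left as None
    hls_config=config,
    output_dir='my_hls4ml_prj',
)
hls_model.compile()
```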
This PR also fixes https://github.com/fastmachinelearning/hls4ml/issues/409
Changes are mostly confined to the frontend, but small changes are made to the backend templates for pooling layers to add the option to include zero-padded entries in average pooling operations.
One big difference between PyTorch and Keras is the data format of the input tensors, which is channels_first by default in PyTorch instead of the channels_last used by Keras. The built-in tools in PyTorch to convert a model to channels_last do not work for all dimensions of the input. Therefore, functionality has been added to transpose the inputs within hls4ml so that the existing channels_last implementations of the layers can be used. By default the inputs are transposed for io_parallel, but not for io_stream, since we don't have transpose layers for all dimensions in io_stream. The outputs are not transposed by default, but this can be switched on, again only for io_parallel.
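To make the data-format difference concrete, here is a minimal sketch of the transpose in question (shapes follow the standard PyTorch and Keras conventions):

```python
import torch

# PyTorch default layout is channels_first: (batch, channels, height, width).
x = torch.randn(1, 3, 32, 32)

# hls4ml's layer implementations expect channels_last, so the inputs are
# transposed to (batch, height, width, channels) internally for io_parallel.
x_channels_last = x.permute(0, 2, 3, 1)
print(x_channels_last.shape)  # torch.Size([1, 32, 32, 3])
```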
Limitations:
- Many types of layers not supported yet
- In many cases the same functionality is available in PyTorch both as torch.nn classes and as torch.functional functions. These have to be parsed differently, which I have implemented only sporadically for the functionals so far (see the example below).
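To illustrate why the two forms need separate handling: traced with torch.fx, the same ReLU appears as a call_module node when used via a torch.nn class, but as a call_function node when used via the functional form:

```python
import torch
import torch.fx

class TwoRelus(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(x)                    # traced as a call_module node
        return torch.nn.functional.relu(x)  # traced as a call_function node

# Prints one line per node: placeholder, call_module, call_function, output.
for node in torch.fx.symbolic_trace(TwoRelus()).graph.nodes:
    print(node.op, node.target)
```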
Type of change
- Bug fix (non-breaking change that fixes an issue)
- New feature (non-breaking change which adds functionality)
Tests
The new parsing was tested using 5-6 different PyTorch model examples from around the web. In addition, I verified that the two example models for PyTorch included with hls4ml are parsed successfully. A test for the API was added in the test/pytest folder, in analogy to the test for the Keras parser. All tests pass successfully.
Checklist
- I have read the guidelines for contributing.
- I have commented my code, particularly in hard-to-understand areas.
- I have made corresponding changes to the documentation.
- My changes generate no new warnings.
- I have installed and run pre-commit on the files I edited or added.
- I have added tests that prove my fix is effective or that my feature works.