Training the Covariance model

Tested with the TensorFlow kernel on GPU, as defined in the image landerlini/lhcbaf:v0p8

This notebook is part of a pipeline. It requires the preprocessing step defined in the GAN preprocessing notebook, and the trained model is validated in the Covariance-validation notebook.

Environment and libraries

As in the other trainings, we handle the GPU through TensorFlow. To make sure the GPU is found, we print the system name of the accelerator below.
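A minimal sketch of such a check (the actual cell is not shown here):

```python
import tensorflow as tf

# tf.test.gpu_device_name() returns the system name of the first visible
# GPU, or an empty string if no GPU is found.
device_name = tf.test.gpu_device_name()
assert device_name != "", "No GPU found"
print(f"GPU: {device_name}")
```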

GPU: /device:GPU:0

Loading the data

The data are loaded with our custom FeatherReader helper class, defined in the local module feather_io.

In this notebook, we are using:

  • training data: to optimize the weights of the network
  • validation data: to evaluate overtraining effects

A chunk of data is loaded to ease the construction of the model, for example by defining the shapes of the input and output tensors.
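For illustration only, a chunk with the shape printed below could be inspected as follows; the real notebook goes through the FeatherReader class, and the file name here is a placeholder.

```python
import pandas as pd
import tensorflow as tf

chunk = pd.read_feather("training_chunk.feather")      # placeholder file name
X_chunk = tf.constant(chunk.values, dtype=tf.float32)  # conditions as a tensor
print(X_chunk.shape)
```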

TensorShape([352612, 15])

Visualization of the correlations of the input features (conditions)

Strong correlations among the input features may result in large distances, with minimal overlap, between the distributions of the training and generated datasets. In these conditions, the discriminator would be unable to drive the training of the generator towards a successful end (unless more advanced loss functions are used).

Since we adopted a rather sophisticated preprocessing step, it is worth verifying that it behaves as expected, removing strong correlations between the variables.
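A minimal sketch of this check, using the chunk of conditions loaded above:

```python
import numpy as np
import matplotlib.pyplot as plt

# Pearson correlation matrix of the preprocessed conditions (15 x 15)
corr = np.corrcoef(X_chunk.numpy(), rowvar=False)
plt.imshow(corr, vmin=-1, vmax=1, cmap="RdBu_r")
plt.colorbar(label="Pearson correlation coefficient")
plt.title("Correlations of the input features")
plt.show()
```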

Definition of the model

The GANs used in this module are composed of three different neural networks trained simultaneously, namely:

  • a generator neural network that takes as input the conditions (such as the generator-level features) and some random noise, and formulates predictions for the output features;
  • a discriminator neural network trained to identify whether a sample was part of the reference dataset or was produced by the generator;
  • a referee network that mimics the configuration of the discriminator but is trained with a much larger learning rate.

For a discussion on the techniques used to describe the GAN, please refer to the resolution notebook.

Generator architecture

The generator architecture is very similar to the one adopted for the resolution GAN.
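The following sketch reproduces the topology reported in the summary below, with the skip connections implemented as element-wise additions; the hidden activation function is an assumption, since it does not appear in the summary.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_generator(n_cond=15, noise_dim=128, n_out=15, width=128, n_blocks=8):
    conditions = layers.Input(shape=(n_cond,))
    noise = layers.Input(shape=(noise_dim,))
    h = layers.Concatenate()([conditions, noise])          # (None, 143)
    h = layers.Dense(width, activation="relu")(h)          # activation is an assumption
    for _ in range(n_blocks):
        h = h + layers.Dense(width, activation="relu")(h)  # skip connection
    outputs = layers.Dense(n_out)(h)                       # 15 output features
    return tf.keras.Model([conditions, noise], outputs)

generator = build_generator()
generator.summary()
```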

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 15)]         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 143)          0           input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          18432       concatenate[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          16512       dense[0][0]                      
__________________________________________________________________________________________________
tf.__operators__.add (TFOpLambd (None, 128)          0           dense[0][0]                      
                                                                 dense_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 128)          16512       tf.__operators__.add[0][0]       
__________________________________________________________________________________________________
tf.__operators__.add_1 (TFOpLam (None, 128)          0           tf.__operators__.add[0][0]       
                                                                 dense_2[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 128)          16512       tf.__operators__.add_1[0][0]     
__________________________________________________________________________________________________
tf.__operators__.add_2 (TFOpLam (None, 128)          0           tf.__operators__.add_1[0][0]     
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 128)          16512       tf.__operators__.add_2[0][0]     
__________________________________________________________________________________________________
tf.__operators__.add_3 (TFOpLam (None, 128)          0           tf.__operators__.add_2[0][0]     
                                                                 dense_4[0][0]                    
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 128)          16512       tf.__operators__.add_3[0][0]     
__________________________________________________________________________________________________
tf.__operators__.add_4 (TFOpLam (None, 128)          0           tf.__operators__.add_3[0][0]     
                                                                 dense_5[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 128)          16512       tf.__operators__.add_4[0][0]     
__________________________________________________________________________________________________
tf.__operators__.add_5 (TFOpLam (None, 128)          0           tf.__operators__.add_4[0][0]     
                                                                 dense_6[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 128)          16512       tf.__operators__.add_5[0][0]     
__________________________________________________________________________________________________
tf.__operators__.add_6 (TFOpLam (None, 128)          0           tf.__operators__.add_5[0][0]     
                                                                 dense_7[0][0]                    
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 128)          16512       tf.__operators__.add_6[0][0]     
__________________________________________________________________________________________________
tf.__operators__.add_7 (TFOpLam (None, 128)          0           tf.__operators__.add_6[0][0]     
                                                                 dense_8[0][0]                    
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 15)           1935        tf.__operators__.add_7[0][0]     
==================================================================================================
Total params: 152,463
Trainable params: 152,463
Non-trainable params: 0
__________________________________________________________________________________________________

Discriminator architecture

Please note that we observe better performance using a shallower neural network and removing the skip connections for the discriminator. It is not clear why, but it is possible that the classification problem differs so much from a logistic regression (for example, because of the large number of flags) that propagating the input to the output is not beneficial and limits the network's ability to perform the classification.

As for the resolution GAN, the input tensor is as follows.

| $X$ | $y$ |
| --- | --- |
| Input conditions (gen.-level features) + Reference target features | 1 |
| Input conditions (gen.-level features) + Generated target features | 0 |

The input conditions are repeated twice: in the first half of the batch they are completed with the output features of the reference samples and labeled as $1$; in the second half of the batch they are completed with randomly generated features and labeled with $0$.
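A sketch consistent with the summary below: the conditions are duplicated along the batch axis (axis 0) and paired with the reference and generated targets, while conditions and targets are joined along the feature axis (axis 1). The hidden activation is again an assumption.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_discriminator(n_cond=15, n_out=15, width=128, n_layers=7):
    x_ref = layers.Input(shape=(n_cond,), name="X_ref")
    y_ref = layers.Input(shape=(n_out,), name="Y_ref")
    y_gen = layers.Input(shape=(n_out,), name="Y_gen")
    x = layers.Concatenate(axis=0, name="X")([x_ref, x_ref])  # conditions repeated twice
    y = layers.Concatenate(axis=0, name="Y")([y_ref, y_gen])  # reference half, generated half
    h = layers.Concatenate(axis=1, name="XY")([x, y])         # (None, 30)
    for _ in range(n_layers):
        h = layers.Dense(width, activation="relu")(h)         # shallower, no skip connections
    output = layers.Dense(1)(h)                               # logits (see from_logits below)
    return tf.keras.Model([x_ref, y_ref, y_gen], output)

discriminator = build_discriminator()
```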

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
X_ref (InputLayer)              [(None, 15)]         0                                            
__________________________________________________________________________________________________
Y_ref (InputLayer)              [(None, 15)]         0                                            
__________________________________________________________________________________________________
Y_gen (InputLayer)              [(None, 15)]         0                                            
__________________________________________________________________________________________________
X (Concatenate)                 (None, 15)           0           X_ref[0][0]                      
                                                                 X_ref[0][0]                      
__________________________________________________________________________________________________
Y (Concatenate)                 (None, 15)           0           Y_ref[0][0]                      
                                                                 Y_gen[0][0]                      
__________________________________________________________________________________________________
XY (Concatenate)                (None, 30)           0           X[0][0]                          
                                                                 Y[0][0]                          
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 128)          3968        XY[0][0]                         
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 128)          16512       dense_10[0][0]                   
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 128)          16512       dense_11[0][0]                   
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 128)          16512       dense_12[0][0]                   
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 128)          16512       dense_13[0][0]                   
__________________________________________________________________________________________________
dense_15 (Dense)                (None, 128)          16512       dense_14[0][0]                   
__________________________________________________________________________________________________
dense_16 (Dense)                (None, 128)          16512       dense_15[0][0]                   
__________________________________________________________________________________________________
dense_17 (Dense)                (None, 1)            129         dense_16[0][0]                   
==================================================================================================
Total params: 103,169
Trainable params: 103,169
Non-trainable params: 0
__________________________________________________________________________________________________

Referee architecture

The referee is kept as similar as possible to the discriminator, but trained with a larger learning rate.
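Accordingly, the referee can be obtained by instantiating the same topology with fresh weights; the optimizers and learning rates below are illustrative assumptions, chosen only to reflect the larger referee learning rate.

```python
# Fresh weights, same topology as the discriminator (see the sketch above)
referee = build_discriminator()

# Illustrative optimizers: the actual learning rates are not reported here
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)  # assumed
referee_optimizer = tf.keras.optimizers.Adam(1e-3)        # assumed to be larger
```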

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 15)]         0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 15)]         0                                            
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 15)]         0                                            
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 15)           0           input_3[0][0]                    
                                                                 input_3[0][0]                    
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 15)           0           input_4[0][0]                    
                                                                 input_5[0][0]                    
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 30)           0           concatenate_2[0][0]              
                                                                 concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 128)          3968        concatenate_3[0][0]              
__________________________________________________________________________________________________
dense_19 (Dense)                (None, 128)          16512       dense_18[0][0]                   
__________________________________________________________________________________________________
dense_20 (Dense)                (None, 128)          16512       dense_19[0][0]                   
__________________________________________________________________________________________________
dense_21 (Dense)                (None, 128)          16512       dense_20[0][0]                   
__________________________________________________________________________________________________
dense_22 (Dense)                (None, 128)          16512       dense_21[0][0]                   
__________________________________________________________________________________________________
dense_23 (Dense)                (None, 128)          16512       dense_22[0][0]                   
__________________________________________________________________________________________________
dense_24 (Dense)                (None, 128)          16512       dense_23[0][0]                   
__________________________________________________________________________________________________
dense_25 (Dense)                (None, 1)            129         dense_24[0][0]                   
==================================================================================================
Total params: 103,169
Trainable params: 103,169
Non-trainable params: 0
__________________________________________________________________________________________________

Training step

The training step is defined with the lower-level TensorFlow API because we need to carefully control which weights are updated by each evaluation of a loss function.

Technically, we use the TensorFlow GradientTape to keep track of the gradients while we describe the computation of the loss function. We use a different tape for each neural network, recording the derivatives of its loss function with respect to the weights of that particular network.
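Building on the sketches above, a minimal version of such a training step could look as follows; the BCE options are the ones discussed in the next subsection, and the generator optimizer is an assumption.

```python
# BCE on logits with label smoothing (see the notes below)
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)  # assumed

@tf.function
def training_step(x, y_ref):
    n = tf.shape(x)[0]
    noise = tf.random.normal((n, 128))
    # First half of the stacked batch: reference pairs (label 1);
    # second half: generated pairs (label 0)
    labels = tf.concat([tf.ones((n, 1)), tf.zeros((n, 1))], axis=0)
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape, tf.GradientTape() as r_tape:
        y_gen = generator([x, noise], training=True)
        d_logits = discriminator([x, y_ref, y_gen], training=True)
        r_logits = referee([x, y_ref, y_gen], training=True)
        d_loss = bce(labels, d_logits)         # discriminator: separate the two halves
        r_loss = bce(labels, r_logits)         # referee: same task, larger learning rate
        g_loss = bce(1.0 - labels, d_logits)   # generator: fool the discriminator
    # Each gradient is taken with respect to the weights of one network only
    generator_optimizer.apply_gradients(zip(
        g_tape.gradient(g_loss, generator.trainable_weights), generator.trainable_weights))
    discriminator_optimizer.apply_gradients(zip(
        d_tape.gradient(d_loss, discriminator.trainable_weights), discriminator.trainable_weights))
    referee_optimizer.apply_gradients(zip(
        r_tape.gradient(r_loss, referee.trainable_weights), referee.trainable_weights))
    return d_loss, r_loss
```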

Notes on the chosen loss function

The loss function for the classification task is clearly the Binary Cross Entropy (BCE). However, we adopt two non-default options for its computation:

  • from_logits=True, to improve the numerical stability of the gradient computation, which is of particular importance for GANs because the very long training procedure may inflate the errors accumulated over many subsequent iterations;
  • label_smoothing=0.1, to introduce a penalty against overconfident classification, which corresponds to the plateaux of the sigmoid function, where the gradient is null and provides no useful information for the generator's training.
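A small numerical illustration of the second point (values are approximate):

```python
import tensorflow as tf

logits = tf.constant([[10.0]])   # an overconfident "reference" prediction
labels = tf.constant([[1.0]])
plain  = tf.keras.losses.BinaryCrossentropy(from_logits=True)
smooth = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
print(plain(labels, logits).numpy())    # ~4.5e-05: overconfidence is essentially free
print(smooth(labels, logits).numpy())   # ~0.50: overconfidence is now penalized
```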

Training

  • Batch size: 10k
  • Number of epochs: 3000
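
Schematically, the training loop iterates the step defined above over the whole dataset at each epoch; `train_batches` is a hypothetical iterator over (conditions, targets) batches of 10k rows.

```python
from tqdm import tqdm

N_EPOCHS = 3000
for epoch in tqdm(range(N_EPOCHS)):
    for x_batch, y_batch in train_batches():  # hypothetical batch iterator
        d_loss, r_loss = training_step(x_batch, y_batch)
```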
Training (validation) loss: 0.357 (0.454): 100%|████████████████████████████████████████████████████████████████████| 3000/3000 [57:45<00:00,  1.16s/it]

The evolution of the loss function as evaluated by the referee network is reported below. A dashed line represents the ideal value of the BCE when evaluated on two identical datasets with an ideal classifier.
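A sketch of such a monitoring plot, assuming the ideal value is the BCE of a classifier outputting 0.5 everywhere (ln 2 ≈ 0.693) and that a per-epoch record `referee_history` was collected during training:

```python
import numpy as np
import matplotlib.pyplot as plt

plt.plot(referee_history, label="referee BCE")                       # hypothetical record
plt.axhline(np.log(2), linestyle="--", label=r"ideal BCE: $\ln 2$")  # indistinguishable datasets
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()
```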

Sanity check

In the following plot we represent the correlations between the output features as they appear in the reference dataset and as they are reproduced by the GAN.

The original dataset is represented in blue, while the generated dataset is shown in orange.

Exporting the model

The model is exported to the same directory where the preprocessing steps tX and tY were stored.
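A minimal sketch of the export, assuming the SavedModel format and taking the directory from the log below:

```python
export_dir = "/workarea/local/private/cache/models/covariance"
generator.save(export_dir)  # writes the SavedModel, including the assets/ directory
```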

WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: /workarea/local/private/cache/models/covariance/assets

Conclusion

In this notebook we discussed the training procedure of the GAN model used to parametrize the covariance. The model is very similar to the one adopted for the resolution, with some differences in the architecture of the discriminator and referee networks.

In particular, we discussed:

  • the overall structure of the DNN system;
  • the architecture of the generator, the discriminator, and the referee network we introduced to ease monitoring, debugging, and hyperparameter optimization;
  • the procedure for optimizing the weights of the three networks based on three different computations of the gradients;
  • the outcome of the training procedure, as visualized through the evolution of the loss of the referee network.

Finally, we exported the model for deployment and further validation.