gracesmith96 asked on 2021-06-28
How can I improve generalization for my Neural Network?
I have a network that trains to a very low error, but it does not perform well on new data sets. Is there anything more I can do to get a network with low error that also generalizes to new situations?
Tags: network, generalization, matlab, programming
Neeta Dsouza answered on 2024-12-22 14:55:35
When training a neural network, generalization is an important property to maintain, and the main threat to it is overfitting. Overfitting can occur when the error on the training set is driven to a very small value: the network performs very well on that particular training set because it has memorized the training examples, but it has not learned to adapt to new situations. In other words, it does not generalize.
There are several ways to improve the generalization of a neural network while still achieving an acceptably low error.
It is highly recommended to use a network that is just large enough to provide an adequate fit. A smaller network has less capacity to memorize the training examples, so this improves generalization, and it also speeds up training. The drawback is that you have to know beforehand how many neurons are adequate for a particular application, which can be quite difficult to determine.
Two other methods are implemented in the Neural Network Toolbox (now called Deep Learning Toolbox).
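In practice, an adequate size can be approximated by trial: train several candidate hidden-layer sizes and compare their validation errors. The sketch below assumes a toolbox release that provides fitnet and the training-record field best_vperf; the noisy sine data and the list of candidate sizes are hypothetical choices made only for illustration.

```matlab
% Hypothetical toy problem: noisy samples of a sine wave (not from the question).
rng(0);                                  % make the random split and noise repeatable
x = linspace(-1, 1, 200);
t = sin(2*pi*x) + 0.1*randn(size(x));

sizes = [1 2 4 8 16];                    % candidate hidden-layer sizes (arbitrary)
vperf = zeros(size(sizes));
for k = 1:numel(sizes)
    net = fitnet(sizes(k));              % one hidden layer with sizes(k) neurons
    net.trainParam.showWindow = false;   % suppress the training GUI
    [net, tr] = train(net, x, t);        % default random train/validation/test split
    vperf(k) = tr.best_vperf;            % validation MSE at the best epoch
end

[~, idx] = min(vperf);                   % size with the lowest validation error
bestH = sizes(idx);
fprintf('Lowest validation error with %d hidden neurons\n', bestH);
```

Because each size is trained only once from a random initialization, the comparison is noisy. Averaging a few runs per size, and preferring the smallest size whose validation error is close to the minimum, is closer to the "just large enough" advice.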
1) The first method is known as regularization. This involves modifying the performance function, which by default is the mean of the squared network errors (MSE). Generalization can be improved by modifying this performance function as follows:
MSEREG = g*MSE + (1-g)*MSW
where g is the performance ratio (a value between 0 and 1) and MSW is the mean of the squared network weights and biases. Using this performance function causes the network to have smaller weights and biases, which forces the network response to be smoother and less likely to overfit. The following MATLAB example sets this up:
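p=[-1 -1 2 2;0 5 0 5];   % inputs: 2 features x 4 samples
t=[-1 -1 1 1];           % targets: one value per sample
% input ranges [-1 2] and [0 5]; 3 tansig hidden neurons, 1 linear output; BFGS training
net=newff([-1 2;0 5],[3 1],{'tansig','purelin'},'trainbfg');
net.performFcn='msereg'; % use the regularized performance function
net.performParam.ratio=0.5; % g = 0.5: equal weight on MSE and MSW
net=train(net,p,t);      % train using the regularized performance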
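To make the formula concrete, the regularized performance can be computed by hand in plain MATLAB. The error and weight values below are made up for illustration; in a real network the toolbox collects all weights and biases for MSW. Note that a smaller g puts more weight on MSW and therefore pushes harder toward small weights.

```matlab
% Hand computation of MSEREG = g*MSE + (1-g)*MSW.
% The error and weight vectors are made-up numbers for illustration only.
meanSq = @(v) mean(v(:).^2);             % mean of the squared entries

e = [0.1 -0.2 0.3 0];                    % hypothetical errors (targets - outputs)
w = [1 -1 2 0];                          % hypothetical weights and biases, flattened
g = 0.5;                                 % performance ratio, as in net.performParam.ratio

mseVal = meanSq(e);                      % (0.01 + 0.04 + 0.09 + 0)/4 = 0.035
mswVal = meanSq(w);                      % (1 + 1 + 4 + 0)/4 = 1.5
perf = g*mseVal + (1 - g)*mswVal;        % 0.5*0.035 + 0.5*1.5 = 0.7675
fprintf('MSE = %.4f, MSW = %.4f, MSEREG = %.4f\n', mseVal, mswVal, perf);
```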
The difficulty here is that you may not know the best value of the performance ratio in advance. The training function TRAINBR (Bayesian regularization) avoids this by determining the regularization parameters automatically during training. You can open the documentation for TRAINBR by running this at the command line:
web([docroot '/toolbox/nnet/ref/trainbr.html'])
In recent releases, running doc trainbr opens the reference page directly.
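In recent releases newff is obsolete and feedforwardnet replaces it. The sketch below shows both regularization approaches with feedforwardnet, assuming a release in which the default mse performance function accepts the regularization parameter (used here instead of msereg). The noisy sine data and the hidden-layer size of 10 are hypothetical choices for illustration.

```matlab
% Hypothetical toy data: noisy sine wave, not taken from the question.
rng(0);
x = linspace(-1, 1, 200);
t = sin(2*pi*x) + 0.1*randn(size(x));

% (a) Fixed regularization ratio with the default 'mse' performance function.
netReg = feedforwardnet(10, 'trainbfg');
netReg.performParam.regularization = 0.5;  % equal weight on error and weight terms
netReg.trainParam.showWindow = false;      % suppress the training GUI
netReg = train(netReg, x, t);

% (b) Bayesian regularization: trainbr sets the regularization itself.
netBR = feedforwardnet(10, 'trainbr');
netBR.trainParam.showWindow = false;
netBR = train(netBR, x, t);

yBR = netBR(x);                            % simulate the trained network
fprintf('trainbr training-set MSE: %.4f\n', mean((t - yBR).^2));
```

Option (b) is usually the more convenient starting point, since it removes the need to guess the ratio by hand.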
2) Another method is known as early stopping. It uses a validation set to stop training when the network begins to overfit the data. The validation set is not used to adjust the weights; instead, its error is monitored during training to check how the network responds to inputs it has not been trained on. If the validation error begins to rise, this generally indicates overfitting; once it has risen for a specified number of iterations, training stops and the weights and biases at the minimum of the validation error are returned. The validation set is presented in the following structure format:
VV.PD - Validation delayed inputs.
VV.Tl - Validation layer targets.
VV.Ai - Validation initial layer delay conditions.
VV.Q - Validation batch size.
VV.TS - Validation time steps.
This structure is then passed to the training function.
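In newer releases the toolbox builds the validation set itself through its data division functions, so you normally do not assemble this structure by hand. The sketch below assumes feedforwardnet with the dividerand division function and the max_fail training parameter; the data, the 70/15/15 split and the hidden-layer size are hypothetical choices for illustration.

```matlab
% Hypothetical toy data: noisy sine wave, not taken from the question.
rng(0);
x = linspace(-1, 1, 200);
t = sin(2*pi*x) + 0.1*randn(size(x));

net = feedforwardnet(20, 'trainlm');   % deliberately generous hidden layer
net.divideFcn = 'dividerand';          % random train/validation/test split
net.divideParam.trainRatio = 0.70;
net.divideParam.valRatio   = 0.15;     % validation set monitored for early stopping
net.divideParam.testRatio  = 0.15;
net.trainParam.max_fail = 6;           % stop after 6 consecutive epochs without validation improvement
net.trainParam.showWindow = false;     % suppress the training GUI

[net, tr] = train(net, x, t);
fprintf('Stop reason: %s\n', tr.stop);
fprintf('Best epoch %d, validation MSE %.4f, test MSE %.4f\n', ...
        tr.best_epoch, tr.best_vperf, tr.best_tperf);
```

The test error reported at the end comes from samples used for neither weight updates nor stopping, so it is a fairer estimate of performance on new data.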