ESRGAN: incomplete loss_GAN
Hi,
I think the loss_GAN in esrgan.py is missing a term.
In the ESRGAN paper, the generator's adversarial loss, like the discriminator loss, is a sum of two terms. However, as far as I understand the code, only the second term is implemented in esrgan.py#L135.
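For reference, these are the relativistic average losses as I read them in the ESRGAN paper. Here C(x) is the raw (pre-sigmoid) discriminator output, σ is the sigmoid, x_r is a real image and x_f is a generated one:

```math
\begin{aligned}
D_{Ra}(x_r, x_f) &= \sigma\big(C(x_r) - \mathbb{E}_{x_f}[C(x_f)]\big) \\
L_D^{Ra} &= -\mathbb{E}_{x_r}\big[\log D_{Ra}(x_r, x_f)\big] - \mathbb{E}_{x_f}\big[\log\big(1 - D_{Ra}(x_f, x_r)\big)\big] \\
L_G^{Ra} &= -\mathbb{E}_{x_r}\big[\log\big(1 - D_{Ra}(x_r, x_f)\big)\big] - \mathbb{E}_{x_f}\big[\log D_{Ra}(x_f, x_r)\big]
\end{aligned}
```

Assuming criterion_GAN is a BCE-with-logits loss, the second term of L_G^Ra corresponds to criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid), which is what the code computes now. The first term, with target 0 on the real images, is the one that is missing.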
In my opinion, the following line would fix it:
loss_GAN = (criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid) +
            criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), fake)) / 2
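To sanity-check the proposed line, here is a minimal dependency-free sketch. Plain Python lists stand in for the discriminator outputs pred_real and pred_fake, and the helper bce_with_logits is a hypothetical stand-in that mimics a mean-reduced BCE-with-logits loss (this assumes criterion_GAN is torch.nn.BCEWithLogitsLoss). The sketch compares the current one-term loss and the proposed two-term loss against L_G^Ra written out directly from the paper:

```python
import math


def mean(values):
    return sum(values) / len(values)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logits, target):
    """Hypothetical helper: mean BCE on raw logits with a constant target."""
    total = 0.0
    for x in logits:
        # Numerically stable form of -[t*log(sig(x)) + (1-t)*log(1-sig(x))]
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def generator_loss_current(pred_real, pred_fake):
    # Only the fake-sample term, as computed at esrgan.py#L135
    return bce_with_logits([f - mean(pred_real) for f in pred_fake], 1.0)


def generator_loss_proposed(pred_real, pred_fake):
    # Both relativistic-average terms, averaged as in the proposed line
    fake_term = bce_with_logits([f - mean(pred_real) for f in pred_fake], 1.0)
    real_term = bce_with_logits([r - mean(pred_fake) for r in pred_real], 0.0)
    return (fake_term + real_term) / 2


def paper_generator_loss(pred_real, pred_fake):
    # L_G^Ra written directly: -E[log(1 - D_Ra(x_r, x_f))] - E[log D_Ra(x_f, x_r)]
    mr, mf = mean(pred_real), mean(pred_fake)
    real_part = -mean([math.log(1.0 - sigmoid(r - mf)) for r in pred_real])
    fake_part = -mean([math.log(sigmoid(f - mr)) for f in pred_fake])
    return real_part + fake_part


if __name__ == "__main__":
    pred_real = [2.0, 1.5, 0.5]
    pred_fake = [-1.0, 0.2, -0.3]
    print("current: ", generator_loss_current(pred_real, pred_fake))
    print("proposed:", generator_loss_proposed(pred_real, pred_fake))
    print("paper:   ", paper_generator_loss(pred_real, pred_fake))
```

The proposed loss equals exactly half of the paper's sum. The factor 1/2 only rescales the adversarial term, so it can be absorbed into the adversarial loss weight if one prefers to match the paper literally.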