Adding Wasserstein Loss to the CycleGAN Implementation Code (PyTorch)
Using a Wasserstein loss instead of mean squared error for the CycleGAN discriminator loss makes training somewhat more stable.
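For reference, the two discriminator objectives side by side (my notation, not from the original repo; $x$ is a real sample and $\tilde{x}$ a generated one):

$$\mathcal{L}_D^{\text{LSGAN}} = \tfrac{1}{2}\,\mathbb{E}_x\big[(D(x)-1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{\tilde{x}}\big[D(\tilde{x})^2\big]$$

$$\mathcal{L}_D^{\text{WGAN}} = \mathbb{E}_{\tilde{x}}\big[D(\tilde{x})\big] - \mathbb{E}_x\big[D(x)\big]$$

With the WGAN loss the discriminator acts as a critic that outputs an unbounded score, which in turn requires a Lipschitz constraint on $D$ (weight clipping in the original WGAN, or a gradient penalty as used below).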
Discriminator A loss function
# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
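Here criterion_GAN, valid, and fake come from the surrounding training script. In the linked repo they are roughly the following (a simplified sketch; D_A.output_shape is the repo's PatchGAN output shape):

criterion_GAN = torch.nn.MSELoss()  # LSGAN-style adversarial loss
# Per-patch targets shaped like the PatchGAN discriminator output
valid = torch.ones((real_A.size(0), *D_A.output_shape), requires_grad=False).cuda()
fake = torch.zeros((real_A.size(0), *D_A.output_shape), requires_grad=False).cuda()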
Discriminator B loss function
# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
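fake_A_buffer and fake_B_buffer above are instances of the repo's ReplayBuffer, which feeds the discriminator a history of previously generated images rather than only the latest batch. A sketch of the idea (not the repo's exact code):

import random
import torch

class ReplayBuffer:
    """Keep a pool of up to max_size previously generated images."""
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        to_return = []
        for element in batch.detach():
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Pool not full yet: store the new fake and return it
                self.data.append(element)
                to_return.append(element)
            elif random.random() > 0.5:
                # Half the time, return an old fake and replace it with the new one
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                # Otherwise return the new fake unchanged
                to_return.append(element)
        return torch.cat(to_return)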
WGAN Loss with Gradient Penalty
The original WGAN simply clamped every weight to the range [-0.01, 0.01], but a gradient penalty can provide the same regularization effect instead.
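Concretely, WGAN-GP (Gulrajani et al., 2017) adds the following penalty term, evaluated at random interpolates $\hat{x} = \alpha x + (1-\alpha)\tilde{x}$ with $\alpha \sim U[0,1]$:

$$\lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]$$

The function below computes exactly this term with $\lambda = 10$.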
import torch
from torch.autograd import grad as torch_grad

def calc_gradient_penalty(netD, real_data, generated_data):
    # GP strength
    LAMBDA = 10
    b_size = real_data.size(0)
    # Calculate interpolation: random points on the line between real and fake
    alpha = torch.rand(b_size, 1, 1, 1)
    alpha = alpha.expand_as(real_data)
    alpha = alpha.cuda()
    interpolated = alpha * real_data.detach() + (1 - alpha) * generated_data.detach()
    interpolated = interpolated.cuda()
    interpolated.requires_grad_(True)
    # Critic scores for the interpolated examples
    prob_interpolated = netD(interpolated)
    # Calculate gradients of the critic scores with respect to the examples
    gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
                           create_graph=True, retain_graph=True)[0]
    # Gradients have shape (batch_size, num_channels, img_height, img_width),
    # so flatten to easily take the norm per example in the batch
    gradients = gradients.view(b_size, -1)
    # Derivatives of the gradient close to 0 can cause problems because of
    # the square root, so manually calculate the norm and add an epsilon
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    # Return gradient penalty
    return LAMBDA * ((gradients_norm - 1) ** 2).mean()
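A quick way to sanity-check the function (hypothetical shapes; assumes D_A is built and that everything lives on a CUDA device, as the .cuda() calls above require):

real = torch.randn(4, 3, 128, 128).cuda()
fake = torch.randn(4, 3, 128, 128).cuda()
gp = calc_gradient_penalty(D_A, real, fake)
print(gp.item())  # a non-negative scalar that gets added to the critic loss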
Modified Discriminator A loss function
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
grad_penalty_A = calc_gradient_penalty(D_A, real_A, fake_A_.detach())
loss_D_A = torch.mean(D_A(fake_A_.detach())) - torch.mean(D_A(real_A)) + grad_penalty_A
Modified Discriminator B loss function
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
grad_penalty_B = calc_gradient_penalty(D_B, real_B, fake_B_.detach())
loss_D_B = torch.mean(D_B(fake_B_.detach())) - torch.mean(D_B(real_B)) + grad_penalty_B
CycleGAN implementation code: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cyclegan