Adding a Wasserstein Loss to a CycleGAN Implementation (PyTorch)

이산한하루 2019. 9. 8. 13:51

Replacing the mean squared error used for the original CycleGAN discriminator loss with the Wasserstein loss allows for more stable training.
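As a rough sketch (illustrative tensors only, not code from the repository), the difference between the two objectives on raw discriminator outputs looks like this:

import torch

# Illustrative only: stand-in critic outputs for a batch of real and generated images
d_real = torch.randn(8, 1, 16, 16)  # D(real) patch scores
d_fake = torch.randn(8, 1, 16, 16)  # D(G(x)) patch scores

# Original CycleGAN objective: least-squares (MSE) GAN loss
mse = torch.nn.MSELoss()
valid = torch.ones_like(d_real)
fake = torch.zeros_like(d_fake)
loss_lsgan = (mse(d_real, valid) + mse(d_fake, fake)) / 2

# Wasserstein critic loss: push D(real) up and D(fake) down
loss_wgan = d_fake.mean() - d_real.mean()  # needs weight clipping or a gradient penalty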

Discriminator A loss function

# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2

Discriminator B loss function

# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2

WGAN Loss with Gradient Penalty

The original WGAN simply clamped the discriminator weights to the range -0.01 to 0.01, but a gradient penalty can be used instead to achieve the same regularizing effect.
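For reference, the weight clipping mentioned above amounts to roughly the following (an illustrative sketch, not part of the CycleGAN code); the gradient penalty function below replaces it:

# Original WGAN weight clipping (illustrative sketch only)
for p in D_A.parameters():
    p.data.clamp_(-0.01, 0.01)  # clamp every critic weight to [-0.01, 0.01] after each optimizer step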

import torch
from torch.autograd import Variable
from torch.autograd import grad as torch_grad

def calc_gradient_penalty(netD, real_data, generated_data):
    # GP strength
    LAMBDA = 10

    b_size = real_data.size()[0]

    # Calculate interpolation
    alpha = torch.rand(b_size, 1, 1, 1)
    alpha = alpha.expand_as(real_data)
    alpha = alpha.cuda()

    interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
    interpolated = Variable(interpolated, requires_grad=True)
    interpolated = interpolated.cuda()

    # Calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
                           create_graph=True, retain_graph=True)[0]

    # Gradients have shape (batch_size, num_channels, img_width, img_height),
    # so flatten to easily take norm per example in batch
    gradients = gradients.view(b_size, -1)

    # Derivatives of the gradient close to 0 can cause problems because of
    # the square root, so manually calculate norm and add epsilon
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return LAMBDA * ((gradients_norm - 1) ** 2).mean()

Modified Discriminator A loss function

fake_A_ = fake_A_buffer.push_and_pop(fake_A)
grad_penalty_A = calc_gradient_penalty(D_A, real_A, fake_A_)
# Wasserstein critic loss: D_A should score real images higher than buffered fakes
loss_D_A = torch.mean(D_A(fake_A_)) - torch.mean(D_A(real_A)) + grad_penalty_A

Modified Discriminator B loss function

fake_B_ = fake_B_buffer.push_and_pop(fake_B)
grad_penalty_B = calc_gradient_penalty(D_B, real_B, fake_B_)
# Wasserstein critic loss: D_B should score real images higher than buffered fakes
loss_D_B = torch.mean(D_B(fake_B_)) - torch.mean(D_B(real_B)) + grad_penalty_B
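If the discriminators are trained with the Wasserstein loss, the adversarial part of the generator loss should change accordingly. A minimal sketch, assuming the variable names of the linked implementation (loss_GAN_AB and loss_GAN_BA are not shown in the original post):

# Sketch: Wasserstein adversarial loss for the generators (cycle/identity losses stay unchanged)
loss_GAN_AB = -torch.mean(D_B(fake_B))  # G_AB tries to raise D_B's score on its fakes
loss_GAN_BA = -torch.mean(D_A(fake_A))  # G_BA tries to raise D_A's score on its fakes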

CycleGAN implementation code: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cyclegan

 
