ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • loss.backward(retain_graph=True)의 필요성 & GAN 학습할때
    파이토치&텐서플로우&관련코딩 2021. 1. 5. 12:05

    가장 기본적인 GAN을 학습할 때를 예로 들면, 아래 코드를 참조하자.

     

    판별기부터 훈련 시키고 생성기를 훈련 시키는 모습이다. 방법은 여러가지가 있음.

    1. latent_z를 만들고 gene_img를 만들어서 real_img와 비교해서 discriminator를 학습한 뒤에, 다시 latent_z를 만들고 gene_img만들어서 generator를 학습.  => 한번의 에폭에 서로다른 gene_img로 학습시킴. 그다지 좋지 않은듯.

     

    2. discirminator를 훈련하는 것은 1과 동일. latent_z를 만들지 않고 기존의 latent_z로 gene_img를 만들어서 학습시킴. => generator를 두번 돌려야 되므로 시간아까움

     

    여기서 retrain_graph를 활용할 수 있다. 지금 문제는 discriminator가 뒤에 있는데 먼저 학습한 뒤에 앞에 있는 generator를 학습 하려면 안되는 것이다. 전체적인 과정을 그래프로 나타내면 아래와 같다. 

    먼저 discriminator를 학습하는 것은 괜찮다. 하지만 학습 직후 파이토치는 메모리 절약을 위해 학습한 경로를 메모리 해제시킨다. 그러고 나서 generator를 학습 하려고 하면 이미 generator는 그래프에서 없어졌기 때문에 학습이 안되는 것. 따라서 retain_graph를 True로 하면 discriminator를 학습 시키고 그래프를 보존하여 generator까지 갈 수 있는것. 

     

    3. 위의 코드에서 discrimintor에 retain_graph를 추가하면 된다. (detach는 빼고)

     

    4. 위의 코드처럼 discriminator를 훈련 시킬때 detach를 쓰면 generator까지 그래프가 형성되지 않는다. 즉 위 그림에서 gene_img가 독립적으로 튀어나오고 discriminator와 연결됨. 그리고 밑에서 gene_img를 그대로 사용함으로서 generator와 연결 후 backprop을 할 수 있다.

     

     

     

    안되는 예제.

    위처럼 하면 discriminator 훈련 파트에서는 discriminator만 하고 있고 아래에서는 detach가 되어있으므로 generator까지 연결되어 있지 않다. 따라서 이 코드는 discriminator만 학습되고 generaotor는 학습되지 않는다. 

    댓글

Designed by Tistory.