ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • SoftmaxGAN에서 모델 그래프 구조 공부
    파이토치&텐서플로우&관련코딩 2021. 5. 5. 23:29

    아래 그림은 SoftGAN의 그래프이다. 이대로 구현 하면 되는데 이전의 GAN들과 다른 점은 Generator를 훈련시키기 위해 real logit도 쓴다는 점이다.
    **순서에 주의하자. 먼저 D를 업데이트 시키고 업데이트 된 D에서 logit을 추출하여 G_loss를 만든다. 또한 D를 훈련시킬때에는 optimizer_D를 사용하고 G는 optimizer_G를 사용하므로 그래프가 연결되어 있어도 업데이트는 각각 될 수 밖에없다. (다시말하면, optimizer_D.step()을 하면 D의 parameter만 업데이트 된다.) **

    아래 코드가 안되는 이유는 D를 훈련시키고 retain_graph=True로 하여 그래프를 남긴다. 그 뒤에 G_loss를 계산하기 위해 처음에 계산한 real_logit과 gene_logit을 사용한다. (Z_B도 마찬가지.)그런데 문제는 D가 훈련되면서 파이토치에서 parameter가 업데이트 되는 원리는 inplace연산으로, (메모리를 아끼기 위해 이렇게 한다.) 모델이 가만히 있는 것 처럼 보인다. x += 2를 생각하자. 이 때 원래 계산되었던 real_logit과 gene_logit을 사용하면서 optimizer_G.step()이 실행되는데 D는 업데이트가 되지 않지만 inplace연산을 이미 거친 그래프부분을 이용했기 때문에 계산이 안된다.

    이제 이 문제를 해결 해보자.

    첫번째 방법은 아래 코드처럼 그냥 gene_logit과 real_logit을 두개 만들면 된다.

    이 때, D를 훈련시키는 코드에서는 gene_logit을 생성하는 부분에서 detach를 사용하였기 때문에 그래프에 G가 연결되지 않는다. 따라서 D_loss.backward()에 매개변수로 retain_graph=True를 주지 않아도 된다.

    두번째 방법은 아래 코드처럼 retain graph를 활용하여 detach를 없앨 수 있다. 원리는 원래 backward를 실행하고 나면 리프부터 루트까지 연결된 부분을 없앤다. 하지만 retain graph를 True로 하면 사라지지 않아 G를 남길 수 있다. * optimizer_D.step()을 하여도 G는 하나도 바뀌지 않는다. 대신 그래프가 안 사라지게 할 수는 있는것. 여기에서 위에처럼 D를 하나 더 만들어서 업데이트 시킨다.

    세번째 방법이 특이한데, 먼저 그냥 D_loss와 G_loss를 구한다. 문제점의 코드와 다른점은 먼제 다 계산하고 업데이트를 시킨다. 코드를 확인하자.
    원래는 아래처럼 G를 2배 훈련시키지 않아도 돌아가는데 D의 힘이 더 센지 그냥

    if idx % 2 ==0:
      ...

    부분만 살려서 쓰면 생성이 안되고 그냥 검정색 사진만 나온다. 그래서 한번은 G만 훈련하면 좀 나아진다.

    전체적인 결과는 깃헙에서 확인.

    댓글

Designed by Tistory.