-
SoftmaxGAN에서 모델 그래프 구조 공부파이토치&텐서플로우&관련코딩 2021. 5. 5. 23:29
아래 그림은 SoftGAN의 그래프이다. 이대로 구현 하면 되는데 이전의 GAN들과 다른 점은 Generator를 훈련시키기 위해 real logit도 쓴다는 점이다.
**순서에 주의하자. 먼저 D를 업데이트 시키고 업데이트 된 D에서 logit을 추출하여 G_loss를 만든다. 또한 D를 훈련시킬때에는 optimizer_D를 사용하고 G는 optimizer_G를 사용하므로 그래프가 연결되어 있어도 업데이트는 각각 될 수 밖에없다. (다시말하면, optimizer_D.step()을 하면 D의 parameter만 업데이트 된다.) **아래 코드가 안되는 이유는 D를 훈련시키고 retain_graph=True로 하여 그래프를 남긴다. 그 뒤에 G_loss를 계산하기 위해 처음에 계산한 real_logit과 gene_logit을 사용한다. (Z_B도 마찬가지.)그런데 문제는 D가 훈련되면서 파이토치에서 parameter가 업데이트 되는 원리는 inplace연산으로, (메모리를 아끼기 위해 이렇게 한다.) 모델이 가만히 있는 것 처럼 보인다. x += 2를 생각하자. 이 때 원래 계산되었던 real_logit과 gene_logit을 사용하면서 optimizer_G.step()이 실행되는데 D는 업데이트가 되지 않지만 inplace연산을 이미 거친 그래프부분을 이용했기 때문에 계산이 안된다.
이제 이 문제를 해결 해보자.
첫번째 방법은 아래 코드처럼 그냥 gene_logit과 real_logit을 두개 만들면 된다.
이 때, D를 훈련시키는 코드에서는 gene_logit을 생성하는 부분에서 detach를 사용하였기 때문에 그래프에 G가 연결되지 않는다. 따라서 D_loss.backward()에 매개변수로 retain_graph=True를 주지 않아도 된다.두번째 방법은 아래 코드처럼 retain graph를 활용하여 detach를 없앨 수 있다. 원리는 원래 backward를 실행하고 나면 리프부터 루트까지 연결된 부분을 없앤다. 하지만 retain graph를 True로 하면 사라지지 않아 G를 남길 수 있다. * optimizer_D.step()을 하여도 G는 하나도 바뀌지 않는다. 대신 그래프가 안 사라지게 할 수는 있는것. 여기에서 위에처럼 D를 하나 더 만들어서 업데이트 시킨다.
세번째 방법이 특이한데, 먼저 그냥 D_loss와 G_loss를 구한다. 문제점의 코드와 다른점은 먼제 다 계산하고 업데이트를 시킨다. 코드를 확인하자.
원래는 아래처럼 G를 2배 훈련시키지 않아도 돌아가는데 D의 힘이 더 센지 그냥if idx % 2 ==0: ...
부분만 살려서 쓰면 생성이 안되고 그냥 검정색 사진만 나온다. 그래서 한번은 G만 훈련하면 좀 나아진다.
전체적인 결과는 깃헙에서 확인.
'파이토치&텐서플로우&관련코딩' 카테고리의 다른 글
[미]파이토치 DataParallel, DistributedDataParallel, apex + amp 정리 (0) 2021.06.13 (미결) GAN학습 할 때 G부터? D부터? + 추론시간은 모델 크기에 비례? (0) 2021.05.31 torch gradient 계산 정리 (0) 2021.04.28 LR Schedulr 정리 (0) 2021.04.25 view vs reshape (0) 2021.04.23