-
torch gradient 계산 정리파이토치&텐서플로우&관련코딩 2021. 4. 28. 11:38
만약 FloatTensor만 선언한다면 requres_grad는 False로 지정되어있어서 backward에서 오류가 난다. 무조건 지정 해줘야함. 단, label같은건 안해도 되지.
그러고 곱셈을 한다. 이때에는 asd의 grad는 none이다. 그 다음에 backward를 실행하면 c의 grad_fn은 multiplication으로 되고 asd와 qwe의 grad가 정해진다.torch.autograd.grad : WGAN-GP를 구현하다 사용한 함수
역할은 그래디언트를 직접적으로 텐서로 볼 수 있는 함수이다. 그래디언트를 직접 계산 해보거나 그래디언트가 손실 함수에 들어가야 할 때 쓸 수 있다.
위에있는 이미지는 실제 WGAN-GP에서 쓰인 gp 항을 계산하기 위한 코드이다. 그 중 torch.autograd.grad함수만 보면,
outputs : 최종 결과물
inputs : outputs을 이걸로 미분할 것이다.
grad_outputs : 그냥 최종 결과물에 몇을 곱하는 것으로 이해함. 이게 값이 2배가 되면 미분 값도 2배가 되서 나온다.
retain_graph : 이게 False로 되어있으면 한번 backward하면 그래프가 없어짐. True이면 그래프가 여러번 backward되어도 안사라진다.
create_graph : 원래 이게 먼저 나와야 맞지 않나 싶다. 그래프를 생성해야 이것도 최종 로스에 들어가는 것이 의미가 있다. 즉, 이 gradient도 그래프로 만들어야 역전파에 관여되면서 추가적인 fine-tune가 이루어 질 수 있다.
only_inputs : 잘 모르겟음.'파이토치&텐서플로우&관련코딩' 카테고리의 다른 글
(미결) GAN학습 할 때 G부터? D부터? + 추론시간은 모델 크기에 비례? (0) 2021.05.31 SoftmaxGAN에서 모델 그래프 구조 공부 (0) 2021.05.05 LR Schedulr 정리 (0) 2021.04.25 view vs reshape (0) 2021.04.23 파이토치 torchvision.utils.save_image (0) 2021.04.22