상세 컨텐츠

본문 제목

[6주차 / 진유석 / 논문리뷰] Generative Adversarial Nets

방학 세션/CV

by 진유석 2023. 2. 22. 17:47

본문

0. Introduction

 

최근까지도 많이 쓰이고 있는 GAN을 제시한 Ian Gootfellow의 논문입니다. GAN은 Generator(생성모델)와 Discriminator(판별모델)의 서로 다른 2개의 네트워크로 구성되어 있습니다. GAN을 설명할 때 가장 많이 드는 예시가 경찰과 위조지폐범 이야기입니다.

위조지폐범은 경찰을 속이려고 하고, 경찰은 돈을 보고 이것이 위조지폐인지 진짜 지폐인지 구분합니다. 위조지폐범과 경찰은 각각 생성 모델과 판별 모델에 대응됩니다. 판별 모델은 생성 모델이 만들어낸 데이터를 입력으로 받습니다. 그 이후 입력된 데이터가 가짜라고 생각하면 0, 진짜라고 생각하면 1을 출력하게 됩니다. 생성 모델은 경찰이 나의 가짜 데이터를 진짜 데이터로 분류하도록 속일 수 있는 방향으로 학습합니다. 진짜같은 데이터를 만들어내기에 'Generative'하고, 속일 수 있도록/구별해낼 수 있도록 서로 적대적으로 학습하기에  'Adversarial'한 모델인 것입니다. 계속해서 생성 모델을 학습시키고, 경찰이 구분하지 못할 정도의 진짜같은 데이터를 만들어낼 수 있습니다.

모델의 구조는 위 그림과 같습니다. random noise z를 입력으로 받은 생성 모델은 Fake image들을 만들어내게 되고, 판별 모델은 진짜 데이터와 가짜 데이터를 구분하려고 노력합니다.

minimax 구조의 위 수식에서 D는 Discriminator(판별 모델), G는 Generator(생성 모델)입니다. 또 x는 진짜 데이터, G(z)는 random noise를 입력으로 받아 생성 모델이 생성한 가짜 데이터입니다.

 

desmos에서 y = logx와 y=log(1-x)의 그래프를 그려 왔습니다. 파란색이 y = logx, 초록색이 y = log(1-x)입니다. 

 

D의 관점에서 보면, D가 잘 학습되었다는 것은 '실제 데이터는 실제 데이터로, 가짜 데이터는 가짜 데이터로' 정확하게 판별해낼 수 있다는 의미입니다. 다시 말해 실제 데이터를 입력받으면 1을, 가짜 데이터를 입력받으면 0을 출력해야 잘 학습된 것인데요, 위 식에서 실제 데이터 x를 입력받은 D(x)는 1, 가짜 데이터를 입력받은 D(G(z))는 0으로 출력해야 합니다. 그렇게 되었을 때 두 식은 logD(x) = log(1- D(G(z))) = 0이 됩니다. D는 1 이상의 값을 만들 수 없고, 로그함수는 열린 구간 (0,1) 에서 음수 값을 갖습니다. 진짜 데이터가 들어간 파란색 그래프는 1에서 최대값을 가지고, 가짜 데이터가 들어간 초록색 그래프는 0에서 최댓값을 갖습니다. 즉, 진짜 데이터는 1로, 가짜 데이터는 0으로 학습을 잘 하고 싶다는 것은 V(D, G)를 '최대화'하고 싶다는 의미입니다. 

 

G의 관점에서는, G가 잘 학습되었다는 것은 판별 모델이 가짜 데이터와 진짜 데이터를 구분하지 못할 만큼 진짜같은 데이터를 만들어냈다는 의미입니다. 즉 D(G(z))가 1이 될 수 있도록 학습하는 것입니다. 1이 된다는 것은 초록색 그래프가 계속 내려가게, '최소화되게' 만들고 싶다는 것입니다.

 

* 하지만 실제에서는 문제점이 있기에 조금 다르게 만듭니다. 초창기에 random noise로부터 만든 생성 모델의 결과물은 어떨까요? 아마 QR코드처럼 진짜같은 이미지가 아니라 노이즈처럼 생겼을 것입니다. 아직 생성 실력이 부족하니까요. 이 때 판별 모델은 자신 있게 이건 가짜라고 말할 수 있습니다. 즉 D(G(z))는 처음에는 0에 근접해 있을 것입니다.

이번에도 desmos에서 그래프를 그려 왔는데요, 파란색이 y = -logx, 초록색이 y = log(1-x)입니다. 두 그래프를 비교해 보면, 0 근처에서 파란색의 gradient가 훨씬 큰 것을 알 수 있습니다. 만약 초록색 그래프로 학습하게 되면, gradient가 작아 잘 학습되지 않습니다. 이 때문에 실제 학습에서는 파란색 그래프를 사용했다고 밝히고 있습니다. (사실 실제로는 y= logx를 사용했는데, max log(D(G(z)))나 min -log(D(G(Z)))나 똑같아서 이렇게 그래프를 그려 왔습니다.)

논문의 그림보다 좋은 그림을 찾아서 가져왔습니다. starGAN을 만드신 최윤제 님의 발표자료인데요, 위 그래프처럼 x4에 해당하는 실제 이미작 많으면 p(x)가 높고, x9의 경우는 p(x)가 낮을 것입니다. Generator의 목적은, 실제 데이터 분포가 나타내는 확률분포 그래프와 유사한 확률 분포를 가지도록 학습합니다.

 

아래 유사 코드는 학습이 이루어지는 방식입니다. 동시에 학습하는 것이 아니라, k번만큼 discriminator를 학습시킨 뒤 generator을 1번 학습시키는 방향으로 따로따로 학습합니다. (논문에서는 k=1로 설정)

아래는 Optimality의 증명이고, 쿠빅 CV 세션 강의자료에서 가져왔습니다. 논문의 글 + 수식보다 더 자세하기에 이해하기 편한 것 같습니다.

1) Discriminator

2) Generator

관련글 더보기

댓글 영역