📑 논문

논문 | Aligning Text-to-Image Models using Human Feedback

노바깅 2023. 4. 1. 17:56

 최근 chatGPT, GPT4가 공개되면서 human feedback이 굉장히 각광받고 있다. 오늘은 human feedback 논문 중 하나를 리뷰해보려고 한다. (글을 쓰고 있는 시점 기준으로 under review 상태라고 한다.) 일단, 이 논문은 text-to-image 모델들이 텍스트와 이미지가 잘 align 되어 있지 않다는 점을 꼬집으며 [yellow]Human Feedback[/yellow]을 이용해서 모델을 fine-tuning 하는 메소드를 제시한다. 

 

 위에서도 언급했듯이 human feedback을 이용해서 모델을 fine-tuning 한다. 크게 3가지 과정으로 구성되어 있다.

   1. human feedback 수집

   2. human labeled image-text dataset을 이용해서 reward function 학습 (human feedback을 잘 예측하도록)

   3. reward-weighted likelihood를 최대화 하도록 모델을 fine-tuning


Method

 위에서 언급했듯이 본 모델은 3단계로 구성되어 있다. 각 단계에 대해서 조금 더 자세히 알아보자.

 

Human Data Collection

- Image-text dataset.

 본 논문에서는 count, color, background 3가지 텍스트 prompt를 고려한다. 각 종류의 prompt 마다 단어나 구를 일부 객체와 결합하여 prompt를 생성한다. 예를 들어서 green과 dog을 결합해서 'a green dog' 이라는 prompt를 생성할 수 있다. 뿐만 아니라 3가지 종류의 prompt를 모두 섞어서 'two green dogs in a city' 와 같은 하나의 prompt를 만들 수 있다.

 

- Human feedback.

 여러 사람에게 image-text dataset에 대한 binary feedback을 받는다. 사람은 동일한 prompt로 생성된 이미지 3장을 보고, 각 이미지가 잘 텍스트와 잘 align 되어 있는지 평가한다. (평가를 명확히 하기 위해서 binary feedback을 받는다.)

 

Reward Learning

 image-text 사이의 alignment 정도를 평가 하기 위해서 reward function $r_{\phi}(\textbf{x}, \textbf{z})$ ($\phi$는 reward function의 파라미터)을 학습한다. reward function은 image $\textbf{x}$와 text prompt $\textbf{z}$를 scalar value에 매핑하며, human feedback $y \in \{0, 1\}$를 예측하도록 학습된다.

 

 human feedback dataset $D^{human} = \{(\textbf{x}, \textbf{z}, y)\}$가 주어지면 reward function $r_{\phi}$은 MSE를 최소화 하도록 학습된다.

 

- Prompt classification

 reward를 학습하기 위해 위의 MSE를 최소화 하는 loss 뿐만 아니라 prompt classification loss도 추가해준다. 각 good으로 label 된 image-text pair마다 N-1개의 original text prompt와 다른 의미를 가지는 text prompts를 생성한다.  예를 들어 original prompt가 Red dog 이었다면, {Blue dog, ..., Green dog}과 같이 추가로 text prompts를 생성할 수 있다. 이렇게 생성된 데이터셋을 $D^{txt} = \{(\textbf{x}, \{\textbf{z}_j\}_{j=1}^N, i^{\prime})\}$ 이라고 하며, image $\textbf{x}$에 대해서 인덱스 $i^{\prime}$가 original prompt이다.

 이렇게 생성한 augmented prompt와 original prompt를 분류하도록 하며 이 loss는 reward function을 학습하기 위한 auxiliary loss로 사용된다. prompt classifier는 아래 수식처럼 reward function을 사용한다. 아래 수식에서 분모는 작아지고, 분자는 커지도록 함으로써 original prompt ($i^{\prime})의 확률을 가장 높게 만들어줘야 한다. (T는 0보다 큰 temperature)

 따라서 prompt classifier를 학습하기 위한 loss는 아래 수식과 같으며, $L^{CE}$는 일반적인 cross-entropy loss이다.

 

 최종적으로 reward function을 학습하기 위한 loss는 아래와 같다. ($\lambda$는 하이퍼파라미터)

 

Updating the Text-to-Image Model

 이 단계에서는 앞 단계에서 학습한 reward function $r_{\phi}$을 가지고 모델을 fine-tuning 한다. 아래 수식에서 $\theta$는 text-to-image model의 파라미터를 의미한다.

 $D^{model}$은 tested text prompt에 의해 만들어진 이미지를 담고 있는 model-generated dataset을 의미하고, $D^{pre}$은 pre-training dataset을 의미한다. 2번째 항은, model-generated dataset의 다양성의 한계로 발생할 수 있는 overfitting을 방지하기 위해 추가되었다.

 

Experiments

 위의 그림은 original model과 본 논문에서 제시한 방식으로 모델을 fine-tuning 했을 때 동일한 prompt로 생성한 이미지를 보여준다. 이미지에 볼 수 있듯이 original model은 detail 면에서 많이 떨어지지만, 본 논문의 모델이 생성한 이미지는 prompt와 잘 매치된다. 추가적인 실험 부분은 직접 논문에서 보는 것을 추천한다. (appendix에도 많은 실험 결과 이미지가 존재한다.)

 

Discussion

 Discussion 파트에서는 본 논문의 한계도 같이 제시해주고 있다.

  •  More nuanced human feedback. training set에서 비슷한 이미지들이 상위에 rank되어 있기 때문에 생성되는 이미지들이 비슷하다.
  •  Diverse and large human dataset. 본 논문에서는 좋다/아니다 (1 또는 0)으로만 human feedback을 받고 있는데, 이를 좀 더 다양하게 확장할 수도 있을 것이다.
  •  Different objectives and algorithms. 본 논문에서는 text-to-image model을 update하기 위해 reward-weighted likelihood를 최대화 했지만, NLP 분야에서 많이 사용된 RLHF를 쓰면 더 좋을 것이다.

 ChatGPT, GPT4 등의 등장으로 Human Feedback이 주목받고 있는 와중에 vision 분야에서의 human feedback도 이제 슬슬 연구가 진행되고 있는 듯 하다. 기대된다...ㅎㅎ 

 

 

논문: https://arxiv.org/abs/2302.12192