새소식

ML | AI/내용 정리 - 2022.10.08

Gradient Checkpointing이란

  • -
 

GitHub - cybertronai/gradient-checkpointing: Make huge neural nets fit in memory

Make huge neural nets fit in memory. Contribute to cybertronai/gradient-checkpointing development by creating an account on GitHub.

github.com

모든 사진과, 글은 위의 링크를 참조했습니다.

Graident checkpointing?

GPU 사용 시 사용 가능한 메모리를 늘리기 위한 방법 중 하나이다. 이를 통해 연산 시간이 늘어나는 대신, 메모리 사용량 이 줄어든다. 아래 그래프는 ResNet 모델에서 최대 메모리 사용량을 비교했는데 blocks(N)이 늘어나 파라미터 수가 늘어날 수록 효과가 더 큰 것을 확인할 수 있다.

그래프의 점선인 $\sqrt x$과 optimized의 memory 사용량이 같음을 확인할 수 있다. 이는 gradient checkpointing의 결과로메모리 사용량이 $O(\sqrt n)$으로 줄어듦을 의미한다. 물론 $O$는 checkpoint를 어떻게 설정하는 지에 따라 달라지겠지만, 결과적으로 줄어든다는 사실은 변하지 않는다.

Gradient checkpointing이 메모리 사용량을 어떻게 줄이는지 확인하기 위해선 딥러닝 모델 학습 과정에 대한 이해가 필요하다. 실제로 모델 학습 과정에서 가장 많은 메모리 사용을 요구하는 순간은 역전파를 통해 Loss 값을 구하는 순간이다. 왜 그런지는 순전파 - 역전파로 이어지는 과정을 살펴보면 알 수 있다.


Graident checkpointing!

아래 그림을 보면 역전파 과정에서 빠른 계산을 위해, 당장 사용하지 않는 노드의 값이더라도 저장(푸른색)해두는 것을 알 수 있다. 이는 연산 속도 측면에선 장점이 있지만 그만큼 저장해야할 가중치가 늘어나 메모리 사용량이 늘어난다는 단점이 있다.

기존 역전파 방식

연산 속도를 생각하지 않으면 역전파 과정에서 모든 노드의 가중치를 저장해둘 필요는 없다. 이 경우, 이전 단계부터 순전파 과정을 새로 해야한다는 단점이 있지만 메모리 사용량은 확실히 줄어든다는 것을 알 수 있다. 하지만 순전파 과정이 2번씩 일어나게 되기 때문에 결국 $O(N^2)$의 시간 복잡도가 발생하게 된다. 이 경우, N이 큰 딥러닝 특성상 연산 속도가 너무 느려질 것이다.

가중치를 저장 안하는 역전파 방식

Gradient checkpointing은 모든 값을 저장하는 방식, 어떤 값도 저장하지 않는 방식 사이의 절충안이다. 따라서 메모리 사용량을 줄이면서 적당한 속도까지 확보할 수 있다. 방법은 일부 노드만 선택한 후 그 노드의 gradient만 저장하는 것이다. 이 를 통해 checkpoint 이후의 노드까지 순전파를 빠르게 수행할 수 있게 된다.

앞서, memory 사용량을 보여준 그래프에서는 gradient checkpointing의 결과로 $O(sqrt(n))$의 시간 복잡도를 기록했다. 이는 $\sqrt N$ 간격으로 checkpoint를 설정한다면 얻을 수 있다.

 

Gradient checkpointing


단순한 FFN 말고 다른 경우엔, gradient 구하는 과정이 조금 다를 수 있기 때문에 고려할 사항이 좀 더 생긴다. 또한 실제 활용 시엔 당연하게도 Trade-off 관계인 '메모리' - '연산 속도 사이'에서 고민이 필요하다. 더 궁금한 것이 있다면 아래 링크를 참고하면 된다.

 

Fitting larger networks into memory.

TLDR; we (OpenAI) release the python/Tensorflow package openai/gradient-checkpointing, that lets you fit 10x larger neural nets into memory…

medium.com

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.