자동 미분

IT 위키

자동 미분(Automatic Differentiation, AD)은 함수의 기울기(도함수)를 연산 그래프를 기반으로 자동으로 계산하는 기술이다. 심볼릭 미분(symbolic differentiation)이나 수치 미분(numerical differentiation)과 달리, 자동 미분은 기계 오차를 최소화하면서 정확한 기울기를 효율적으로 계산할 수 있어 딥러닝과 최적화 알고리즘의 핵심 기술로 사용된다.

개요[편집 | 원본 편집]

자동 미분은 함수의 구성 요소(덧셈, 곱셈, 행렬 연산 등)에 대해 미분 규칙을 알고 있기 때문에, 프로그래머는 단지 함수를 정의하기만 하면 프레임워크가 역전파(backpropagation)를 통해 기울기를 자동 계산한다.

예: PyTorch, TensorFlow, JAX는 모두 자동 미분 엔진을 핵심으로 갖는다.

자동 미분의 종류[편집 | 원본 편집]

자동 미분은 크게 다음 두 가지 방식으로 나뉜다.

Forward Mode AD[편집 | 원본 편집]

입력 변수에서 출발하여 출력까지 미분 값을 전달하는 방식.

특징:

  • 입력 개수가 적고 출력 개수가 많은 함수에 적합
  • 계산 흐름이 단순함
  • JAX의 `jax.jvp`, `jax.forward_ad` 등이 사용

예: f: ℝ → ℝⁿ 형태에서 효율적

Reverse Mode AD (역전파, Backpropagation)[편집 | 원본 편집]

출력에서부터 역방향으로 기울기를 계산하는 방식.

특징:

  • 출력이 스칼라(예: 손실 값)이고 입력 차원이 큰 함수에 최고 효율
  • 신경망 학습의 표준 기법
  • PyTorch, TensorFlow, JAX 등 딥러닝 프레임워크 대부분이 사용

예: 신경망과 같이 파라미터 수가 매우 많을 때 가장 효율적

→ 딥러닝에서 사용하는 backpropagation은 사실상 Reverse Mode AD의 특수한 형태이다.

자동 미분의 원리[편집 | 원본 편집]

자동 미분은 보통 다음처럼 동작한다:

  1. 연산을 **그래프(Computation Graph)** 로 기록
  2. 각 노드(연산)의 미분 규칙을 알고 있으므로
  3. Chain rule(연쇄 법칙)을 이용해 전체 미분을 계산

예시 연산:

y = x*2 + x*3
dy/dx = (2 + 3) = 5

프레임워크는 사용자가 직접 미분 규칙을 쓸 필요 없이 내부적으로 규칙을 조합해 기울기를 계산한다.

수치 미분 및 심볼릭 미분과의 비교[편집 | 원본 편집]

방식 설명 장점 단점
수치 미분 (Finite Differences) h를 이용한 근사 차분 구현 쉬움 오차 큼, 비용 비쌈
심볼릭 미분 함수식을 기호적 조작 정확한 표현식 표현식 폭발, 복잡함
자동 미분 그래프 기반 미분 정확 + 빠름 심볼릭보다 엄밀하진 않음 (추상적)

주요 프레임워크의 자동 미분[편집 | 원본 편집]

PyTorch (Autograd)[편집 | 원본 편집]

  • 동적 그래프 기반
  • 연산 수행 시 그래프를 기록
  • `.backward()` 호출 시 기울기 계산
  • Eager execution이 직관적

TensorFlow (GradientTape)[편집 | 원본 편집]

  • 2.x에서 eager + 자동 그래프 변환
  • `tf.GradientTape()`로 연산 추적

JAX (Autograd + XLA)[편집 | 원본 편집]

  • 순수 함수 기반
  • `jax.grad` 로 함수의 미분 생성
  • JIT + XLA 최적화로 매우 빠른 성능
  • Forward/Reverse/Vector-Jacobian 등 다양한 모드 지원

JAX는 AD가 프레임워크 핵심이자 철학이며, 성능 최적화가 매우 뛰어나다.

딥러닝에서의 자동 미분[편집 | 원본 편집]

딥러닝 학습 과정은 다음과 같이 자동 미분을 필수로 사용한다:

  1. 손실 함수 L(θ) 계산
  2. 파라미터 θ에 대한 ∂L/∂θ 계산 (역전파)
  3. 옵티마이저(Adam, SGD 등)로 θ 업데이트

이 기울기 계산 과정 전체가 자동 미분 엔진에 의해 수행된다.

Jacobian, Hessian 등 고차 미분[편집 | 원본 편집]

자동 미분은 1차 미분뿐 아니라 Jacobian, Hessian 같은 고차 미분도 계산할 수 있다.

대표적 예:

  • Hessian-vector product (HVP)
  • Jacobian-vector product (JVP)
  • JAX는 특히 고차 미분에 강함 (`jax.jacobian`, `jax.hessian`)

장점[편집 | 원본 편집]

  • 사용자는 미분 공식을 직접 작성할 필요 없음
  • 복잡한 수식·모델에서도 안정적으로 동작
  • 딥러닝 모델 학습의 필수 기술
  • Forward/Reverse 구조로 고성능을 보장
  • XLA/Triton 등과 결합하면 매우 빠름

한계[편집 | 원본 편집]

  • 메모리 사용량이 증가할 수 있음
  • 함수형 제약(JAX) 등 사용 난이도 증가 가능
  • 비미분 가능 연산(discrete ops)은 별도 처리 필요
  • 계산 그래프가 너무 크면 성능 저하 발생

함께 보기[편집 | 원본 편집]

각주[편집 | 원본 편집]