자동 미분
자동 미분(Automatic Differentiation, AD)은 함수의 기울기(도함수)를 연산 그래프를 기반으로 자동으로 계산하는 기술이다. 심볼릭 미분(symbolic differentiation)이나 수치 미분(numerical differentiation)과 달리, 자동 미분은 기계 오차를 최소화하면서 정확한 기울기를 효율적으로 계산할 수 있어 딥러닝과 최적화 알고리즘의 핵심 기술로 사용된다.
개요[편집 | 원본 편집]
자동 미분은 함수의 구성 요소(덧셈, 곱셈, 행렬 연산 등)에 대해 미분 규칙을 알고 있기 때문에, 프로그래머는 단지 함수를 정의하기만 하면 프레임워크가 역전파(backpropagation)를 통해 기울기를 자동 계산한다.
예: PyTorch, TensorFlow, JAX는 모두 자동 미분 엔진을 핵심으로 갖는다.
자동 미분의 종류[편집 | 원본 편집]
자동 미분은 크게 다음 두 가지 방식으로 나뉜다.
Forward Mode AD[편집 | 원본 편집]
입력 변수에서 출발하여 출력까지 미분 값을 전달하는 방식.
특징:
- 입력 개수가 적고 출력 개수가 많은 함수에 적합
- 계산 흐름이 단순함
- JAX의 `jax.jvp`, `jax.forward_ad` 등이 사용
예: f: ℝ → ℝⁿ 형태에서 효율적
Reverse Mode AD (역전파, Backpropagation)[편집 | 원본 편집]
출력에서부터 역방향으로 기울기를 계산하는 방식.
특징:
- 출력이 스칼라(예: 손실 값)이고 입력 차원이 큰 함수에 최고 효율
- 신경망 학습의 표준 기법
- PyTorch, TensorFlow, JAX 등 딥러닝 프레임워크 대부분이 사용
예: 신경망과 같이 파라미터 수가 매우 많을 때 가장 효율적
→ 딥러닝에서 사용하는 backpropagation은 사실상 Reverse Mode AD의 특수한 형태이다.
자동 미분의 원리[편집 | 원본 편집]
자동 미분은 보통 다음처럼 동작한다:
- 연산을 **그래프(Computation Graph)** 로 기록
- 각 노드(연산)의 미분 규칙을 알고 있으므로
- Chain rule(연쇄 법칙)을 이용해 전체 미분을 계산
예시 연산:
y = x*2 + x*3
dy/dx = (2 + 3) = 5
프레임워크는 사용자가 직접 미분 규칙을 쓸 필요 없이 내부적으로 규칙을 조합해 기울기를 계산한다.
수치 미분 및 심볼릭 미분과의 비교[편집 | 원본 편집]
| 방식 | 설명 | 장점 | 단점 |
|---|---|---|---|
| 수치 미분 (Finite Differences) | h를 이용한 근사 차분 | 구현 쉬움 | 오차 큼, 비용 비쌈 |
| 심볼릭 미분 | 함수식을 기호적 조작 | 정확한 표현식 | 표현식 폭발, 복잡함 |
| 자동 미분 | 그래프 기반 미분 | 정확 + 빠름 | 심볼릭보다 엄밀하진 않음 (추상적) |
주요 프레임워크의 자동 미분[편집 | 원본 편집]
PyTorch (Autograd)[편집 | 원본 편집]
- 동적 그래프 기반
- 연산 수행 시 그래프를 기록
- `.backward()` 호출 시 기울기 계산
- Eager execution이 직관적
TensorFlow (GradientTape)[편집 | 원본 편집]
- 2.x에서 eager + 자동 그래프 변환
- `tf.GradientTape()`로 연산 추적
JAX (Autograd + XLA)[편집 | 원본 편집]
- 순수 함수 기반
- `jax.grad` 로 함수의 미분 생성
- JIT + XLA 최적화로 매우 빠른 성능
- Forward/Reverse/Vector-Jacobian 등 다양한 모드 지원
JAX는 AD가 프레임워크 핵심이자 철학이며, 성능 최적화가 매우 뛰어나다.
딥러닝에서의 자동 미분[편집 | 원본 편집]
딥러닝 학습 과정은 다음과 같이 자동 미분을 필수로 사용한다:
- 손실 함수 L(θ) 계산
- 파라미터 θ에 대한 ∂L/∂θ 계산 (역전파)
- 옵티마이저(Adam, SGD 등)로 θ 업데이트
이 기울기 계산 과정 전체가 자동 미분 엔진에 의해 수행된다.
Jacobian, Hessian 등 고차 미분[편집 | 원본 편집]
자동 미분은 1차 미분뿐 아니라 Jacobian, Hessian 같은 고차 미분도 계산할 수 있다.
대표적 예:
- Hessian-vector product (HVP)
- Jacobian-vector product (JVP)
- JAX는 특히 고차 미분에 강함 (`jax.jacobian`, `jax.hessian`)
장점[편집 | 원본 편집]
- 사용자는 미분 공식을 직접 작성할 필요 없음
- 복잡한 수식·모델에서도 안정적으로 동작
- 딥러닝 모델 학습의 필수 기술
- Forward/Reverse 구조로 고성능을 보장
- XLA/Triton 등과 결합하면 매우 빠름
한계[편집 | 원본 편집]
- 메모리 사용량이 증가할 수 있음
- 함수형 제약(JAX) 등 사용 난이도 증가 가능
- 비미분 가능 연산(discrete ops)은 별도 처리 필요
- 계산 그래프가 너무 크면 성능 저하 발생