PyTorch Parameter 클래스
IT 위키
Parameter(PyTorch의 torch.nn.Parameter)는 학습 가능한 모델 파라미터를 나타내는 특수 텐서 클래스이다. 일반 텐서와 달리 모듈(nn.Module)에 속성으로 할당되면 자동으로 모델의 학습 대상 파라미터로 등록된다.
개요[편집 | 원본 편집]
- Parameter는 torch.Tensor를 상속한 클래스이다. [1]
- 생성 형식은 다음과 같다: Parameter(data, requires_grad=True)
- 기본적으로 requires_grad=True로 설정되어 역전파 시 gradient가 계산된다.
- nn.Module 내부에 self.weight = Parameter(...) 처럼 할당하면 자동으로 파라미터로 추적되며, parameters() 나 named_parameters()의 반환값에 포함된다. [2]
특징[편집 | 원본 편집]
- 일반 Tensor는 학습 대상 파라미터로 자동 등록되지 않지만, Parameter 객체는 등록된다.
- Parameter.grad 속성은 역전파로 계산된 gradient를 저장한다.
- requires_grad=False로 설정하면 해당 파라미터는 gradient 계산 대상에서 제외된다.
관련 클래스[편집 | 원본 편집]
- UninitializedParameter — 초기화되지 않은 파라미터로, shape가 불명확한 경우 사용됨. [3]
- ParameterList — 복수 개의 Parameter를 리스트 형태로 관리하는 모듈. [4]
- ParameterDict — 이름이 있는 Parameter들의 딕셔너리 구조. [5]
사용 예제[편집 | 원본 편집]
import torch
from torch.nn import Parameter, Module
class MyModule(Module):
def __init__(self):
super().__init__()
self.weight = Parameter(torch.randn(10, 10))
self.bias = Parameter(torch.zeros(10))
def forward(self, x):
return x @ self.weight + self.bias
model = MyModule()
for name, param in model.named_parameters():
print(name, param.shape, param.requires_grad)