PyTorch Parameter 클래스

IT 위키

Parameter(PyTorch의 torch.nn.Parameter)는 학습 가능한 모델 파라미터를 나타내는 특수 텐서 클래스이다. 일반 텐서와 달리 모듈(nn.Module)에 속성으로 할당되면 자동으로 모델의 학습 대상 파라미터로 등록된다.

개요[편집 | 원본 편집]

  • Parametertorch.Tensor를 상속한 클래스이다. [1]
  • 생성 형식은 다음과 같다: Parameter(data, requires_grad=True)
  • 기본적으로 requires_grad=True로 설정되어 역전파 시 gradient가 계산된다.
  • nn.Module 내부에 self.weight = Parameter(...) 처럼 할당하면 자동으로 파라미터로 추적되며, parameters()named_parameters()의 반환값에 포함된다. [2]

특징[편집 | 원본 편집]

  • 일반 Tensor는 학습 대상 파라미터로 자동 등록되지 않지만, Parameter 객체는 등록된다.
  • Parameter.grad 속성은 역전파로 계산된 gradient를 저장한다.
  • requires_grad=False로 설정하면 해당 파라미터는 gradient 계산 대상에서 제외된다.

관련 클래스[편집 | 원본 편집]

  • UninitializedParameter — 초기화되지 않은 파라미터로, shape가 불명확한 경우 사용됨. [3]
  • ParameterList — 복수 개의 Parameter를 리스트 형태로 관리하는 모듈. [4]
  • ParameterDict — 이름이 있는 Parameter들의 딕셔너리 구조. [5]

사용 예제[편집 | 원본 편집]

import torch
from torch.nn import Parameter, Module

class MyModule(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter(torch.randn(10, 10))
        self.bias = Parameter(torch.zeros(10))

    def forward(self, x):
        return x @ self.weight + self.bias

model = MyModule()

for name, param in model.named_parameters():
    print(name, param.shape, param.requires_grad)