๐Ÿš€ PyTorch๋กœ CNN ๋ชจ๋ธ ๋‹จ๊ณ„๋ณ„๋กœ ์ •๋ณตํ•˜๊ธฐ

์˜ค๋Š˜์€ ๋งˆ์น˜ ๋ ˆ๊ณ  ๋ธ”๋ก์„ ํ•˜๋‚˜์”ฉ ์Œ“์•„ ๋ฉ‹์ง„ ์„ฑ์„ ๋งŒ๋“ค๋“ฏ, ๊ฐ„๋‹จํ•œ ์‹ ๊ฒฝ๋ง์—์„œ ์‹œ์ž‘ํ•ด ์™„์ „ํ•œ CNN ๋ชจ๋ธ์„ ๋‹จ๊ณ„๋ณ„๋กœ ๊ตฌ์ถ•ํ•˜๋Š” ๊ณผ์ •์„ ์ง„ํ–‰ํ–ˆ๋‹ค. Fashion MNIST ๋ฐ์ดํ„ฐ์…‹์„ ๊ฐ€์ง€๊ณ , ๊ฐ ๋ถ€ํ’ˆ(๋ ˆ์ด์–ด)์ด ๋ชจ๋ธ ์„ฑ๋Šฅ์— ์–ด๋–ค ์˜ํ–ฅ์„ ๋ฏธ์น˜๋Š”์ง€ ์ง์ ‘ ํ™•์ธํ•ด๋ณด๋Š” ์‹œ๊ฐ„์ด์—ˆ๋‹ค.

๐Ÿง โ€œCNN, ์™œ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ์— ๊ฐ•๋ ฅํ• ๊นŒ?โ€

์ด๋ฏธ์ง€๋Š” ํ”ฝ์…€์˜ ์œ„์น˜์™€ ๊ด€๊ณ„, ์ฆ‰ ๊ณต๊ฐ„์  ํŠน์ง•์ด ๋งค์šฐ ์ค‘์š”ํ•˜๋‹ค. CNN์€ ์ปจ๋ณผ๋ฃจ์…˜(Convolution) ๊ณผ ํ’€๋ง(Pooling) ์ด๋ผ๋Š” ํŠน๋ณ„ํ•œ ์—ฐ์‚ฐ์œผ๋กœ ์ด ๊ณต๊ฐ„ ์ •๋ณด๋ฅผ ๋˜‘๋˜‘ํ•˜๊ฒŒ ์œ ์ง€ํ•˜๋ฉฐ ํ•™์Šตํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด๋ฏธ์ง€ ์ธ์‹์—์„œ ๋›ฐ์–ด๋‚œ ์„ฑ๋Šฅ์„ ๋ฐœํœ˜ํ•œ๋‹ค.


๐Ÿ› ๏ธ CNN์˜ ํ•ต์‹ฌ ์žฌ๋ฃŒ: ์ปจ๋ณผ๋ฃจ์…˜๊ณผ ํ’€๋ง

  • ์ปจ๋ณผ๋ฃจ์…˜(Convolution): ์ด๋ฏธ์ง€ ์œ„๋ฅผ ์Šฌ๋ผ์ด๋”ฉํ•˜๋Š” ํ•„ํ„ฐ(์ปค๋„) ๋ฅผ ํ†ตํ•ด ์œค๊ณฝ์„ , ์งˆ๊ฐ ๊ฐ™์€ ๊ณ ์œ ํ•œ ํŠน์ง•์„ ์ถ”์ถœํ•œ๋‹ค. ์ด ๊ฒฐ๊ณผ๋ฌผ์„ ํŠน์ง• ๋งต(Feature Map) ์ด๋ผ๊ณ  ๋ถ€๋ฅธ๋‹ค.
  • ํ’€๋ง(Pooling): ํŠน์ง• ๋งต์˜ ํฌ๊ธฐ๋ฅผ ์ค„์—ฌ(Sub-sampling) ์ค‘์š”ํ•œ ์ •๋ณด๋งŒ ๋‚จ๊ธด๋‹ค. ๋งฅ์Šค ํ’€๋ง(Max Pooling) ์€ ํŠน์ • ์˜์—ญ์—์„œ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ํŠน์ง•(๊ฐ€์žฅ ํฐ ๊ฐ’)๋งŒ ๋ฝ‘์•„๋‚ด ๋ชจ๋ธ์„ ๋” ๊ฐ•์ธํ•˜๊ฒŒ ๋งŒ๋“ ๋‹ค.


โœจ ๋ชจ๋ธ ๋นŒ๋“œ์—… ๊ณผ์ •: ๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ์ฒœ์ฒœํžˆ

Fashion MNIST ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ์–ด๋–ป๊ฒŒ ์ง„ํ™”ํ•˜๋Š”์ง€ ๋‹จ๊ณ„๋ณ„๋กœ ๋”ฐ๋ผ๊ฐ€ ๋ณด์•˜๋‹ค.

1๏ธโƒฃ Step 1: ๊ธฐ๋ณธ DNN ๋ชจ๋ธ (Baseline)

  • ๊ฐ€์žฅ ๋จผ์ €, ์ด๋ฏธ์ง€๋ฅผ 1์ฐจ์›์œผ๋กœ ๊ธธ๊ฒŒ ํŽผ์ณ์„œ ์ผ๋ฐ˜์ ์ธ ์™„์ „ ์—ฐ๊ฒฐ ์‹ ๊ฒฝ๋ง(DNN)์œผ๋กœ ํ•™์Šต์‹œ์ผฐ๋‹ค.
  • ๊ฒฐ๊ณผ: Test Accuracy 88.66%. ๋‚˜์˜์ง€ ์•Š์€ ์‹œ์ž‘์ด๋‹ค!

2๏ธโƒฃ Step 2: ์ปจ๋ณผ๋ฃจ์…˜(CNN) ๋ ˆ์ด์–ด ์ถ”๊ฐ€

  • DNN ๋ชจ๋ธ ์•ž์— ์ปจ๋ณผ๋ฃจ์…˜ ๋ ˆ์ด์–ด๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ์ด๋ฏธ์ง€์˜ ๊ณต๊ฐ„ ์ •๋ณด๋ฅผ ํ™œ์šฉํ•˜๋„๋ก ํ–ˆ๋‹ค.
  • ๊ฒฐ๊ณผ: Test Accuracy 88.37%. ์ •ํ™•๋„๋Š” ๋น„์Šทํ–ˆ์ง€๋งŒ, ํ•™์Šต ๊ณผ์ •์ด ๋” ์•ˆ์ •ํ™”๋˜์—ˆ๋‹ค.

3๏ธโƒฃ Step 3: ๋งฅ์Šคํ’€๋ง(Max Pooling)์œผ๋กœ ์„ฑ๋Šฅ ์ ํ”„!

  • ์ปจ๋ณผ๋ฃจ์…˜ ๋ ˆ์ด์–ด ๋’ค์— ๋งฅ์Šค ํ’€๋ง์„ ์ถ”๊ฐ€ํ•ด ํŠน์ง•์„ ์••์ถ•ํ•˜๊ณ  ์—ฐ์‚ฐ ํšจ์œจ์„ ๋†’์˜€๋‹ค.
  • ๊ฒฐ๊ณผ: Test Accuracy 91.22%. ์„ฑ๋Šฅ์ด ๋ˆˆ์— ๋„๊ฒŒ ํ–ฅ์ƒ๋˜์—ˆ๋‹ค! ์—ญ์‹œ CNN์˜ ํ•ต์‹ฌ์€ ํ’€๋ง์— ์žˆ๋Š” ๊ฒƒ ๊ฐ™๋‹ค.

4๏ธโƒฃ Step 4: ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€ & ์•ˆ์ •ํ™” (Dropout, Batch Norm)

  • ๋ชจ๋ธ์ด ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์—๋งŒ ๋„ˆ๋ฌด ์ต์ˆ™ํ•ด์ง€๋Š” ๊ณผ์ ํ•ฉ์„ ๋ง‰๊ธฐ ์œ„ํ•ด ๋“œ๋กญ์•„์›ƒ(Dropout) ๊ณผ ๋ฐฐ์น˜ ์ •๊ทœํ™”(Batch Normalization) ๋ฅผ ์ถ”๊ฐ€ํ–ˆ๋‹ค. ์ด๋“ค์€ ๋ชจ๋ธ์„ ๋” ์•ˆ์ •์ ์ด๊ณ  ์ผ๋ฐ˜ํ™” ์„ฑ๋Šฅ์ด ์ข‹๊ฒŒ ๋งŒ๋“ค์–ด์ค€๋‹ค.
# ์ตœ์ข… ๋ชจ๋ธ์˜ ์ผ๋ถ€
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32), # ๋ฐฐ์น˜ ์ •๊ทœํ™”
        nn.ReLU(),
        nn.MaxPool2d(2),    # ๋งฅ์Šค ํ’€๋ง
        nn.Dropout(0.3)     # ๋“œ๋กญ์•„์›ƒ
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout(0.3)
    )
    self.fc1 = nn.Linear(64 * 7 * 7, 10) # ์™„์ „ ์—ฐ๊ฒฐ ๋ ˆ์ด์–ด

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = x.view(x.size(0), -1) # flatten
    x = self.fc1(x)
    return x
  • ์ตœ์ข… ๊ฒฐ๊ณผ: Test Accuracy ~93%. ๋ชจ๋“  ๋ถ€ํ’ˆ์„ ์กฐ๋ฆฝํ•˜๋‹ˆ ํ›จ์”ฌ ๊ฒฌ๊ณ ํ•˜๊ณ  ๊ฐ•๋ ฅํ•œ ๋ชจ๋ธ์ด ์™„์„ฑ๋˜์—ˆ๋‹ค. ๐Ÿ‘

โœจ ์˜ค๋Š˜์˜ ํšŒ๊ณ 

๋‹จ์ˆœํ•œ DNN ๋ชจ๋ธ์—์„œ ์‹œ์ž‘ํ•ด ์ปจ๋ณผ๋ฃจ์…˜, ํ’€๋ง, ๋“œ๋กญ์•„์›ƒ, ๋ฐฐ์น˜ ์ •๊ทœํ™” ๊ฐ™์€ ๊ฐœ๋…๋“ค์„ ํ•˜๋‚˜์”ฉ ์ถ”๊ฐ€ํ•˜๋ฉฐ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ๋ˆˆ์— ๋„๊ฒŒ ์ข‹์•„์ง€๋Š” ๊ฒƒ์„ ์ง์ ‘ ํ™•์ธํ•˜๋‹ˆ ์ •๋ง ํฅ๋ฏธ๋กœ์› ๋‹ค. ๊ฐ ๋ ˆ์ด์–ด๊ฐ€ ์–ด๋–ค ์—ญํ• ์„ ํ•˜๋Š”์ง€, ์™œ ํ•„์š”ํ•œ์ง€๋ฅผ ๋ชธ์†Œ ์ฒด๊ฐํ•  ์ˆ˜ ์žˆ๋Š” ์ตœ๊ณ ์˜ ์‹ค์Šต์ด์—ˆ๋‹ค.

ํŠนํžˆ ๋งฅ์Šค ํ’€๋ง์„ ์ถ”๊ฐ€ํ–ˆ์„ ๋•Œ ์„ฑ๋Šฅ์ด ํฌ๊ฒŒ ์˜ค๋ฅด๋Š” ๊ฒƒ์„ ๋ณด๊ณ , ํŠน์ง•์„ ์ž˜ ์••์ถ•ํ•˜๊ณ  ๊ฐ•์กฐํ•˜๋Š” ๊ฒƒ์ด ์–ผ๋งˆ๋‚˜ ์ค‘์š”ํ•œ์ง€ ๊นจ๋‹ฌ์•˜๋‹ค. ๋‹ค์Œ์—๋Š” ์˜ค๋Š˜ ๋งŒ๋“  ๋ชจ๋ธ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋” ๋ณต์žกํ•œ ๋ฐ์ดํ„ฐ์…‹์— ๋„์ „ํ•˜๊ฑฐ๋‚˜, VGG๋‚˜ ResNet ๊ฐ™์€ ์œ ๋ช…ํ•œ CNN ์•„ํ‚คํ…์ฒ˜๋ฅผ ๋ถ„์„ํ•ด๋ด์•ผ๊ฒ ๋‹ค. ๐Ÿ˜„