본문 바로가기

AI

TensorFlow 함수형 API 로 AlexNet 논문 구현

AlexNet (2012)

  • 딥러닝이 주목받는 계기가 된 모델로, 5개의 Convolutional Layer와 3개의 Fully Connected Layer로 구성
  • ReLU 활성화 함수와 Dropout을 도입하여 학습 성능 향상.

A. Krizhevsky, I. Sutskever and G. E. Hinton, "ImageNet Classification with Deep Convolutional Neural Networks," in Advances in Neural Information Processing Systems (NIPS), 2012.

  • AlexNet은 2012년 ImageNet 대회에서 우승하며 딥러닝의 가능성을 널리 알린 모델입니다.
  • 구성: 5개의 Convolutional Layer와 3개의 Fully Connected Layer로 이루어져 있으며, 특히 Relu 활성화 함수와 Dropout 기법을 도입하여 신경망 학습에서 과적합을 방지했습니다.
  • ReLU 활성화는 비선형성을 높이고 학습 속도를 크게 향상시켰으며, Dropout은 네트워크 일부를 무작위로 차단하여 과적합을 줄였습니다.
  • 특징: 다중 GPU를 사용하여 병렬 처리를 통해 학습 시간을 크게 단축하고 메모리 제약을 극복 했습니다.
import tensorflow as tf
input_ = tf.keras.Input((227,227,3)) # 논문에는 224 로 나와 있지만 data augmentation 으로 실제 227
x = tf.keras.layers.Conv2D(48, 11, 4)(input_) # 필터 48개, 필터 크기 11 X 11, stride 4
x = tf.keras.layers.MaxPool2D(3, 2)(x) # 윈도우 크기 3x3, stride 2
xx = tf.keras.layers.Conv2D(128, 5, padding='same')(x)  # 필터 128개, 필터 크기 5X5

y = tf.keras.layers.Conv2D(48, 11, 4)(input_)
y = tf.keras.layers.MaxPool2D(3, 2)(y)
yy = tf.keras.layers.Conv2D(128, 5, padding='same')(y)

i = tf.keras.layers.Concatenate()([xx,yy])
i = tf.keras.layers.Conv2D(192, 3, padding='same')(i)
i = tf.keras.layers.Conv2D(192, 3, padding='same')(i)
i = tf.keras.layers.Conv2D(128, 3, padding='same')(i)
i = tf.keras.layers.MaxPool2D(3, 2)(i)
i = tf.keras.layers.Flatten()(i)

j = tf.keras.layers.Concatenate()([xx,yy])
j = tf.keras.layers.Conv2D(192, 3, padding='same')(j)
j = tf.keras.layers.Conv2D(192, 3, padding='same')(j)
j = tf.keras.layers.Conv2D(128, 3, padding='same')(j)
j = tf.keras.layers.MaxPool2D(3, 2)(j)
j = tf.keras.layers.Flatten()(j)

k = tf.keras.layers.Concatenate()([i,j])
k = tf.keras.layers.Dense(2048)(k)

l = tf.keras.layers.Concatenate()([i,j])
l = tf.keras.layers.Dense(2048)(l)

m = tf.keras.layers.Concatenate()([k,l])
m = tf.keras.layers.Dense(2048)(m)

n = tf.keras.layers.Concatenate()([k,l])
n = tf.keras.layers.Dense(2048)(n)

t = tf.keras.layers.Concatenate()([n,m])
t = tf.keras.layers.Dense(1000, activation='softmax')(t)
model = tf.keras.Model(input_, t)
tf.keras.utils.plot_model(model, rankdir='LR', show_shapes=True)

model.summary()
odel: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)              ┃ Output Shape           ┃        Param # ┃ Connected to           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)  │ (None, 227, 227, 3)    │              0 │ -                      │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d (Conv2D)           │ (None, 55, 55, 48)     │         17,472 │ input_layer[0][0]      │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_2 (Conv2D)         │ (None, 55, 55, 48)     │         17,472 │ input_layer[0][0]      │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ max_pooling2d             │ (None, 27, 27, 48)     │              0 │ conv2d[0][0]           │
│ (MaxPooling2D)            │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ max_pooling2d_1           │ (None, 27, 27, 48)     │              0 │ conv2d_2[0][0]         │
│ (MaxPooling2D)            │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_1 (Conv2D)         │ (None, 27, 27, 128)    │        153,728 │ max_pooling2d[0][0]    │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_3 (Conv2D)         │ (None, 27, 27, 128)    │        153,728 │ max_pooling2d_1[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate (Concatenate) │ (None, 27, 27, 256)    │              0 │ conv2d_1[0][0],        │
│                           │                        │                │ conv2d_3[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate_1             │ (None, 27, 27, 256)    │              0 │ conv2d_1[0][0],        │
│ (Concatenate)             │                        │                │ conv2d_3[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_4 (Conv2D)         │ (None, 27, 27, 192)    │        442,560 │ concatenate[0][0]      │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_7 (Conv2D)         │ (None, 27, 27, 192)    │        442,560 │ concatenate_1[0][0]    │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_5 (Conv2D)         │ (None, 27, 27, 192)    │        331,968 │ conv2d_4[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_8 (Conv2D)         │ (None, 27, 27, 192)    │        331,968 │ conv2d_7[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_6 (Conv2D)         │ (None, 27, 27, 128)    │        221,312 │ conv2d_5[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2d_9 (Conv2D)         │ (None, 27, 27, 128)    │        221,312 │ conv2d_8[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ max_pooling2d_2           │ (None, 13, 13, 128)    │              0 │ conv2d_6[0][0]         │
│ (MaxPooling2D)            │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ max_pooling2d_3           │ (None, 13, 13, 128)    │              0 │ conv2d_9[0][0]         │
│ (MaxPooling2D)            │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ flatten (Flatten)         │ (None, 21632)          │              0 │ max_pooling2d_2[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ flatten_1 (Flatten)       │ (None, 21632)          │              0 │ max_pooling2d_3[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate_2             │ (None, 43264)          │              0 │ flatten[0][0],         │
│ (Concatenate)             │                        │                │ flatten_1[0][0]        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate_3             │ (None, 43264)          │              0 │ flatten[0][0],         │
│ (Concatenate)             │                        │                │ flatten_1[0][0]        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ dense (Dense)             │ (None, 2048)           │     88,606,720 │ concatenate_2[0][0]    │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ dense_1 (Dense)           │ (None, 2048)           │     88,606,720 │ concatenate_3[0][0]    │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate_5             │ (None, 4096)           │              0 │ dense[0][0],           │
│ (Concatenate)             │                        │                │ dense_1[0][0]          │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate_4             │ (None, 4096)           │              0 │ dense[0][0],           │
│ (Concatenate)             │                        │                │ dense_1[0][0]          │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ dense_3 (Dense)           │ (None, 2048)           │      8,390,656 │ concatenate_5[0][0]    │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ dense_2 (Dense)           │ (None, 2048)           │      8,390,656 │ concatenate_4[0][0]    │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ concatenate_6             │ (None, 4096)           │              0 │ dense_3[0][0],         │
│ (Concatenate)             │                        │                │ dense_2[0][0]          │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ dense_4 (Dense)           │ (None, 1000)           │      4,097,000 │ concatenate_6[0][0]    │
└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘
 Total params: 200,425,832 (764.56 MB)
 Trainable params: 200,425,832 (764.56 MB)
 Non-trainable params: 0 (0.00 B)