import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D
from tensorflow.keras.layers import Flatten, Dense, Dropout
#######################################
# 32x32 픽셀의 6만개 컬러이미지 포함
# 각 이미지는 10개의 클래스로 라벨링
# 50000개 이미지는 트레이닝 용도로 사용
# 10000개 이미지는 테스트 용도로 사용
#######################################
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.reshape(-1, 32, 32, 3) # 텐서로 변환, 높이, 너비, 채널
x_test = x_test.reshape(-1, 32, 32, 3)
print(x_train.shape, x_test.shape)
print(y_train.shape, y_test.shape)
  # 정규화 작업 전에 데이터 샘플 출력
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print("Train samples:", x_train.shape, y_train.shape)
print("Test samples:", x_test.shape, y_test.shape)
plt.figure(figsize=(10,10))
for i in range(25):
  plt.subplot(5, 5, i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(x_train[i])
  plt.xlabel(class_names[y_train[i][0]])
plt.show()
    # 정규화 작업은 class_names 의 값들이 출력이된후 작업해야 함 ! 
    # 아니면 다 깨짐
    # 255.0 정규화
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
 
 
# CNN 모델구축
cnn = Sequential()
cnn.add(Conv2D(input_shape=(32,32,3), kernel_size=(3,3), filters=32, activation='relu'))
cnn.add(Conv2D(kernel_size=(3,3), filters=64, activation='relu'))
cnn.add(MaxPool2D(pool_size=(2,2)))
cnn.add(Dropout(0.25))
cnn.add(Flatten())  # 3차원 텐서를 1차원 벡터로 변환
cnn.add(Dense(128, activation='relu'))  # 은닉층 개념
cnn.add(Dropout(0.5))
cnn.add(Dense(10, activation='softmax'))    # 출력층
# CNN 모델 컴파일 및 학습
cnn.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])
hist = cnn.fit(x_train, y_train, batch_size=128, epochs=30, validation_data=(x_test, y_test))
 
cnn.evaluate(x_test, y_test)    # 모델 정확도 평가
 
 
# 정확도 및 손실 1
import matplotlib.pyplot as plt
plt.plot(hist.history['accuracy'])
plt.plot(hist.history['val_accuracy'])
plt.title('Accuracy Trend')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='best')
plt.grid()
plt.show()
 
 
 
# 정확도 및 손실 2
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('Loss Trend')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='best')
plt.grid()
plt.show()
