4 글 보임 - 1 에서 4 까지 (총 4 중에서)
-
글쓴이글
-
2022년 8월 7일 02:19 #41121
이정범참가자제가 학교 과제로 이미지를 통해 동물을 판별하는 모델을 제작 중에 있습니다. 인터넷 자료들을 참고하며 모델을 제작중인데 model을 학습시키는 단계에서 에러가 납니다. anaconda 가상환경에서 jupyter notebook을 이용해 작업중이며 tensorflow-gpu는 2.9.1버전 사용중이고 cuda = 11.2, cudnn = 8.1입니다 -----코드 입니다---- train_datagen = ImageDataGenerator( rescale=1./255, rotation_range=30, shear_range=0.3, horizontal_flip=True, width_shift_range=0.1, height_shift_range=0.1, zoom_range=0.25, ) valid_datagen = ImageDataGenerator( rescale=1./255, ) -----------------------------------
batch_size = 64 img_width = 128 img_height = 128
train_data = train_datagen.flow_from_directory( 'images/data/endangered_species/train', batch_size=batch_size, target_size=(img_width, img_height), shuffle=True, ) valid_data = valid_datagen.flow_from_directory( 'images/data/endangered_species/test', target_size=(img_width, img_height), batch_size=batch_size, shuffle=False, ) -----------------------------------------
def visualize_images(images, labels): figure, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 14)) classes = list(train_data.class_indices.keys()) img_no = 0 for i in range(3): for j in range(3): img = images[img_no] label_no = np.argmax(labels[img_no])
ax[i,j].imshow(img) ax[i,j].set_title(classes[label_no]) ax[i,j].set_axis_off() img_no += 1
images, labels = next(train_data) visualize_images(images, labels) -------------------------------- base = MobileNetV2(input_shape=(img_width, img_height,3),include_top=False,weights='imagenet') base.trainable = True model = Sequential() model.add(base) model.add(GlobalAveragePooling2D()) model.add(Dense(128, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(6, activation='softmax')) opt = Adam(learning_rate=0.001) model.compile(optimizer=opt,loss = 'categorical_crossentropy',metrics=['accuracy']) -------------------------- reduce_lr = ReduceLROnPlateau(monitor = 'val_accuracy',patience = 1,verbose = 1) early_stop = EarlyStopping(monitor = 'val_accuracy',patience = 5,verbose = 1,restore_best_weights = True) check_point = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True) --------------------------- history = model.fit(train_data, epochs=50, validation_data = valid_data, callbacks=[early_stop,reduce_lr,check_point]) ------------------------------ 무엇때문에 에러가 나는걸까요?
-
글쓴이글
4 글 보임 - 1 에서 4 까지 (총 4 중에서)
- 답변은 로그인 후 가능합니다.