• 로그인
  • 장바구니에 상품이 없습니다.

home2 게시판 Python, SQL 게시판 model.fit()에서 에러가 납니다

model.fit()에서 에러가 납니다

4 글 보임 - 1 에서 4 까지 (총 4 중에서)
  • 글쓴이
  • #41121

    이정범
    참가자
    제가 학교 과제로 이미지를 통해 동물을 판별하는 모델을 제작 중에 있습니다.
    인터넷 자료들을 참고하며 모델을 제작중인데 model을 학습시키는 단계에서 에러가 납니다.
    anaconda 가상환경에서 jupyter notebook을 이용해 작업중이며 
    tensorflow-gpu는 2.9.1버전 사용중이고 cuda = 11.2, cudnn = 8.1입니다
    -----코드 입니다----
    train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=30,
            shear_range=0.3,
            horizontal_flip=True,
            width_shift_range=0.1,
            height_shift_range=0.1,
            zoom_range=0.25,
    )
    valid_datagen = ImageDataGenerator(
            rescale=1./255,
    )
    -----------------------------------
    
    batch_size = 64
    img_width = 128
    img_height = 128
    train_data = train_datagen.flow_from_directory(
        'images/data/endangered_species/train',
        batch_size=batch_size,
        target_size=(img_width, img_height),
        shuffle=True,
    )
    valid_data = valid_datagen.flow_from_directory(
        'images/data/endangered_species/test',
        target_size=(img_width, img_height),
        batch_size=batch_size,
        shuffle=False,
    )
    -----------------------------------------
    
    def visualize_images(images, labels):
        figure, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 14))
        classes = list(train_data.class_indices.keys())
        img_no = 0
        for i in range(3):
            for j in range(3):
                img = images[img_no]
                label_no = np.argmax(labels[img_no])
                ax[i,j].imshow(img)
                ax[i,j].set_title(classes[label_no])
                ax[i,j].set_axis_off()
                img_no += 1
    images, labels = next(train_data)
    visualize_images(images, labels)
    --------------------------------
    base = MobileNetV2(input_shape=(img_width, img_height,3),include_top=False,weights='imagenet')
    base.trainable = True
    model = Sequential()
    model.add(base)
    model.add(GlobalAveragePooling2D())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(6, activation='softmax'))
    opt = Adam(learning_rate=0.001)
    model.compile(optimizer=opt,loss = 'categorical_crossentropy',metrics=['accuracy'])
    --------------------------
    reduce_lr = ReduceLROnPlateau(monitor = 'val_accuracy',patience = 1,verbose = 1)
    early_stop = EarlyStopping(monitor = 'val_accuracy',patience = 5,verbose = 1,restore_best_weights = True)
    check_point = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True)
    ---------------------------
    history = model.fit(train_data, epochs=50, 
                        validation_data = valid_data, 
                        callbacks=[early_stop,reduce_lr,check_point])
    ------------------------------
    무엇때문에 에러가 나는걸까요?
    #41131

    codingapple
    키 마스터
    무슨에러입니까
    #41188

    이정범
    참가자
    아 죄송합니다
    제가 에러를 안보여 드렸네요
    에러와 전체적인 코드입니다
    
    
    #41256

    codingapple
    키 마스터
    라벨이 4갠데 예측결과는 6개나 나오게 설정했다는 에러같군요 레이어들을 확인해봅시다
4 글 보임 - 1 에서 4 까지 (총 4 중에서)
  • 답변은 로그인 후 가능합니다.

About

현재 월 700명 신규수강중입니다.

  (09:00~20:00) 빠른 상담은 카톡 플러스친구 코딩애플 (링크)
  admin@codingapple.com
  이용약관
ⓒ Codingapple, 강의 예제, 영상 복제 금지
top

© Codingapple, All rights reserved. 슈퍼로켓 에듀케이션 / 서울특별시 강동구 고덕로 19길 30 / 사업자등록번호 : 212-26-14752 온라인 교육학원업 / 통신판매업신고번호 : 제 2017-서울강동-0002 호 / 개인정보관리자 : 박종흠