Keras transfer learning, fine-tuning, and defining a model's predict function
Published: 2021-11-21 04:41:31 | Category: Technical articles



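    from keras.layers import Dense, GlobalAveragePooling2D
    from keras.models import Model

    def add_new_last_layer(base_model, nb_classes):
        """Add last layer to the convnet

        Args:
            base_model: keras model excluding top
            nb_classes: # of classes
        Returns:
            new keras model with last layer
        """
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # FC_SIZE is a module-level constant: the width of the new fully connected layer
        x = Dense(FC_SIZE, activation='relu')(x)
        predictions = Dense(nb_classes, activation='softmax')(x)
        # Keras 2 names these arguments inputs/outputs (Keras 1 used input/output)
        model = Model(inputs=base_model.input, outputs=predictions)
        return model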

Load a pre-trained model as the front-end network and fine-tune it on your own dataset. It is best to proceed in the following two steps:

  1. Transfer learning: freeze the pre-trained convolutional base and train only the newly added Dense layers
  2. Fine-tuning: unfreeze the upper convolutional layers of the base and retrain more layers

Doing both, in that order, will ensure a more stable and consistent training. This is because the large gradient updates triggered by randomly initialized weights could wreck the learned weights in the convolutional base if not frozen. Once the last layer has stabilized (transfer learning), then we move onto retraining more layers (fine-tuning).
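To make the argument above concrete, here is a minimal sketch on a toy model rather than a real network. It assumes a one-weight "base" and a one-weight "head" (pred = head_w * base_w * x) with squared error; the function name, the starting values, and the learning rate are all made up for illustration and do not come from the original article.

```python
def sgd_step(base_w, head_w, x, y, lr, freeze_base):
    """One SGD step on the toy model pred = head_w * (base_w * x), squared error."""
    err = head_w * base_w * x - y
    grad_head = 2.0 * err * base_w * x
    grad_base = 2.0 * err * head_w * x
    new_head = head_w - lr * grad_head
    # A frozen base simply skips its own update.
    new_base = base_w if freeze_base else base_w - lr * grad_base
    return new_base, new_head

PRETRAINED_BASE = 2.0    # stands in for a well-trained convolutional base
BADLY_INIT_HEAD = -5.0   # stands in for a randomly initialized last layer
x, y, lr = 1.0, 6.0, 0.01

frozen_base, _ = sgd_step(PRETRAINED_BASE, BADLY_INIT_HEAD, x, y, lr, freeze_base=True)
unfrozen_base, _ = sgd_step(PRETRAINED_BASE, BADLY_INIT_HEAD, x, y, lr, freeze_base=False)

print("base weight after one step, frozen:  ", frozen_base)
print("base weight after one step, unfrozen:", unfrozen_base)
```

In this toy setting the badly initialized head produces a large error, and that error flows straight into the base weight unless the base is frozen, which is the effect the two-step recipe is meant to avoid.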

Transfer learning

    def setup_to_transfer_learn(model, base_model):
        """Freeze all layers of the base model and compile the model"""
        for layer in base_model.layers:
            layer.trainable = False
        model.compile(optimizer='rmsprop',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

Fine-tune

    from keras.optimizers import SGD

    def setup_to_finetune(model):
        """Freeze the bottom NB_IV3_LAYERS_TO_FREEZE layers and retrain the remaining top layers.

        Note: NB_IV3_LAYERS_TO_FREEZE is chosen so that the unfrozen part
        corresponds to the top 2 inception blocks in the InceptionV3 architecture.

        Args:
            model: keras model
        """
        for layer in model.layers[:NB_IV3_LAYERS_TO_FREEZE]:
            layer.trainable = False
        for layer in model.layers[NB_IV3_LAYERS_TO_FREEZE:]:
            layer.trainable = True
        # Newer Keras releases name the lr argument learning_rate
        model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
                      loss='categorical_crossentropy')
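The slicing in setup_to_finetune can be checked without loading a real network. The sketch below uses a hypothetical FakeLayer class that only carries a name and a trainable flag, so it is a stand-in for Keras layers rather than the Keras API itself; the layer names and the cut-off of 4 are made up for illustration.

```python
class FakeLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    def __init__(self, name):
        self.name = name
        self.trainable = True

def freeze_bottom(layers, n_freeze):
    """Freeze layers[:n_freeze], unfreeze the rest, return the trainable names."""
    for layer in layers[:n_freeze]:
        layer.trainable = False
    for layer in layers[n_freeze:]:
        layer.trainable = True
    return [layer.name for layer in layers if layer.trainable]

# Five "convolutional" layers followed by one output layer.
layers = [FakeLayer("conv_%d" % i) for i in range(5)] + [FakeLayer("dense_out")]
trainable_names = freeze_bottom(layers, n_freeze=4)
print(trainable_names)
```

The two slices never overlap and together cover the whole list, so every layer ends up either frozen or trainable, never left in an unintended state from an earlier phase.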

When fine-tuning, it's important to use a learning rate lower than the one used when training from scratch (here lr=0.0001); otherwise the optimization could destabilize and the loss diverge.
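A toy illustration of why the step size matters, assuming plain gradient descent on f(w) = w**2 rather than a real network. The two learning rates are picked only to show the contrast and are unrelated to the lr=0.0001 used in the article.

```python
def descend(lr, steps=20, w0=1.0):
    """Run plain gradient descent on f(w) = w**2 and return the final w."""
    w = w0
    for _ in range(steps):
        w -= lr * 2.0 * w  # the gradient of w**2 is 2*w
    return w

small_lr_result = descend(lr=0.1)   # each step shrinks w by a factor of 0.8
large_lr_result = descend(lr=1.1)   # each step multiplies w by -1.2, so |w| grows

print("lr=0.1 ->", small_lr_result)
print("lr=1.1 ->", large_lr_result)
```

With the small rate the iterate moves steadily toward the minimum at 0, while with the large rate every step overshoots and the distance from the minimum keeps growing.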

Training

Now we're all set for training. Use fit_generator for both transfer learning and fine-tuning, running the two stages one after the other.

    # Keras 1.x argument names
    history = model.fit_generator(
        train_generator,
        samples_per_epoch=nb_train_samples,
        nb_epoch=nb_epoch,
        validation_data=validation_generator,
        nb_val_samples=nb_val_samples,
        class_weight='auto')
    model.save(args.output_model_file)


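In Keras 2.0 and later the argument names changed (samples_per_epoch became steps_per_epoch, nb_epoch became epochs, and nb_val_samples became validation_steps; the step arguments count batches, not samples):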

    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False)  # randomly flip images vertically

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        steps_per_epoch=x_train.shape[0] // batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test))
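Since the Keras 2 arguments count batches rather than samples, old sample counts have to be divided by the batch size. Below is a minimal sketch of that conversion; keras2_generator_args is a hypothetical helper written for this note, not part of Keras, and the sample counts are made-up numbers.

```python
import math

def keras2_generator_args(nb_train_samples, nb_val_samples, nb_epoch, batch_size):
    """Translate Keras 1 sample counts into Keras 2 step counts (hypothetical helper)."""
    return {
        # ceil keeps the final, smaller batch instead of dropping it
        "steps_per_epoch": math.ceil(nb_train_samples / batch_size),
        "epochs": nb_epoch,
        "validation_steps": math.ceil(nb_val_samples / batch_size),
    }

fit_kwargs = keras2_generator_args(
    nb_train_samples=2000, nb_val_samples=800, nb_epoch=3, batch_size=32)
print(fit_kwargs)
```

The listing above uses floor division (x_train.shape[0] // batch_size), which skips the last partial batch in each epoch; rounding up, as in this sketch, visits every sample. Either choice is reasonable as long as it is made deliberately.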

Prediction function:

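    import numpy as np
    from keras.preprocessing import image
    # Assumption: the InceptionV3 helpers, matching the architecture named in setup_to_finetune
    from keras.applications.inception_v3 import preprocess_input, decode_predictions

    def predict(model, img, target_size, top_n=3):
        """Run model prediction on image

        Args:
            model: keras model
            img: PIL format image
            target_size: (width, height) tuple
            top_n: # of top predictions to return
        Returns:
            list of predicted labels and their probabilities
        """
        if img.size != target_size:
            img = img.resize(target_size)
        x = image.img_to_array(img)
        # Adding the batch axis is the key step: with the TensorFlow backend,
        # Keras model tensors have shape (batch_size, h, w, c).
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        preds = model.predict(x)
        # decode_predictions maps to ImageNet labels, so it only applies
        # when the model keeps the original 1000-class output.
        return decode_predictions(preds, top=top_n)[0]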
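decode_predictions looks up ImageNet class names, so for a model whose last layer was replaced with nb_classes custom outputs the labels have to come from your own class list. A minimal sketch of that lookup in plain Python follows; top_n_predictions, the class names, and the probabilities are all hypothetical and stand in for one row of model.predict output.

```python
def top_n_predictions(probs, class_names, top_n=3):
    """Pair each class name with its probability and return the top_n most likely."""
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]

# Made-up softmax output for a three-class model.
class_names = ["cat", "dog", "bird"]
probs = [0.1, 0.7, 0.2]

top = top_n_predictions(probs, class_names, top_n=2)
print(top)
```

The order of class_names must match the order of the output units; with a directory-based generator that is usually the order recorded in its class index mapping.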

Reposted from: https://blog.csdn.net/xiaojiajia007/article/details/73826222
