keras data generation, python生成器
发布日期:2021-11-21 04:41:32 浏览次数:44 分类:技术文章

本文共 2373 字,大约阅读时间需要 7 分钟。

Implement fit_generator( ) in Keras

Here is an example of fit_generator():

model.fit_generator(generator(), samples_per_epoch=50, nb_epoch=10)

Breaking it down:

generator() generates batches of samples indefinitely

sample_per_epoch number of samples you want to train in each epoch

nb_epoch number of epochs

As you can manually definesample_per_epoch andnb_epoch , you have to provide codes forgenerator . Here is an example:

Assume features is an array of data with shape (100,64,64,3) and labels is an array of data with shape (100,1). We use data from features and labels to train our model.

def generator(features, labels, batch_size): # Create empty arrays to contain batch of features and labels# batch_features = np.zeros((batch_size, 64, 64, 3)) batch_labels = np.zeros((batch_size,1)) while True:   for i in range(batch_size):     # choose random index in features     index= random.choice(len(features),1)     batch_features[i] = some_processing(features[index])     batch_labels[i] = labels[index]   yield batch_features, batch_labels

在python中,当你定义一个函数,使用了yield关键字时,这个函数就是一个生成器" (也就是说,只要有yield这个词出现,你在用def定义函数的时候,系统默认这就不是一个函数啦,而是一个生成器)。如果需要生成器返回(下一个)值,需要调用.next()函数。其实当系统判断def是生成器时,就会自动支持.next()函数,例如:

def fib(max):          a, b = 1, 1          while a < max:              yield a              a, b = b, a+b            for n in fib(15):          print n            m = fib(13)      print m      print m.next()      print m.next()      print m.next()

1. 每个生成器只能使用一次。比如上个例子中的m生成器,一旦打印完m的6个值,就没有办法再打印m的值了,因为已经吐完了。生成器每次运行之后都会在运行到yield的位置时候,保存暂时的状态,跳出生成器函数,在下次执行生成器函数的时候会从上次截断的位置继续开始执行循环。

2. yield一般都在def生成器定义中搭配一些循环语句使用,比如for或者while,以防止运行到生成器末尾跳出生成器函数,就不能再yield了。有时候,为了保证生成器函数永远也不会执行到函数末尾,会用while True: 语句,这样就会保证只要使用next(),这个生成器就会生成一个值,是处理无穷序列的常见方法。

拿上面那个为例, 每次继续开始执行上次没处理完成的位置,但后面的每次循环都只在while True这个循环体内部运行,之前的非循环体batch_feature...  batch_label ...并没有执行,因为它们只在第一次进入生成其函数的时候才有效地运行过一次。

With the generator above, if we definebatch_size = 10 , that means it will randomly taking out 10 samples fromfeatures and labels to feed into each epoch until an epoch hits 50 sample limit. Then fit_generator() destroys the used data and move on repeating the same process in new epoch.

One great advantage aboutfit_generator() besides saving memory is user can integrate random augmentation inside the generator, so it will always provide model with new data to train on the fly.

转载地址:https://blog.csdn.net/xiaojiajia007/article/details/74909518 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:卷积与反卷积(转置卷积)关系的公式推导 及其各自的形式
下一篇:tensorflow poolallocator

发表评论

最新留言

路过按个爪印,很不错,赞一个!
[***.219.124.196]2024年04月17日 23时55分25秒