tensorflow2之数据管道Dataset
发布日期:2022-02-14 23:02:49 浏览次数:43 分类:技术文章

本文共 2771 字,大约阅读时间需要 9 分钟。

原则

  • 数据量不大,直接入内存计算即可
  • 数据量过大,无法一次性载入内存,需要分批读入:tf.data的API构建数据输入管道

构建

  • numpy

    ds = tf.data.Dataset.from_tensor_slices((['train_x'], ['train_y']))
  • pandas:同上df.to_dict('list')

  • generator

    def generator():    for features, labels in ds:         yield (features, labels)ds = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.int32))
  • csv

    tf.data.experimental.make_csv_dataset(file_pattern = ['x.csv', 'xx.csv'], batch_size=3, label_name='survived',                                         na_value='', num_epochs=1, ignore_errors=True)
  • 文本:

    tf.data.TextLineDataset(filenames = ['x.csv', 'xx.csv']).skip(1) # 去掉第一行的header
  • 文件路径:

    tf.data.Dataset.list_files('./*/*.jpg')
  • tfrecords文件

    • 缺点:复杂,需要对样本构建tf.Example后压缩城字符串写到tfrecords文件,读取后再解析成tf.Example
    • 优点:压缩后文件较小,便于网络传播,加载速度快

管道提升

  • 模型训练耗时的两个部分
    • 数据准备:构建高效的数据管道来提升

      • 使用prefetch方法让数据准备和参数迭代两个过程相互并行

        # 模拟数据准备def generator():    for i in range(10):        time.sleep(2)        yield i# 模拟参数迭代def train_step():    time.sleep(1)# 一般情况下的串行,耗时:10 * 2 + 10 * 1 = 30sds = tf.data.Dataset.from_generator(generator, output_types=(tf.int32))for x in ds:    train_step()# prefetch实现数据准备和参数迭代相互并行,耗时:max(10 * 2, 10 * 1) = 20sfor x in ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE):    train_step()
      • 使用interleave方法可以让数据读取过程多进程执行,并将不同来源数据夹在一起

        ds_files = tf.data.Dataset.list_files("./data/titanic/*.csv")# flat_map单进程ds = ds_files.flat_map(lambda x:tf.data.TextLineDataset(x).skip(1))# interleave多进程ds = ds_files.interleave(lambda x:tf.data.TextLineDataset(x).skip(1))
      • 使用map时设置num_parallel_calls让数据转换过程多进程执行

        ds = tf.data.Dataset.list_files("./*/*.jpg")def load_image(img_path,size = (32,32)):     label = 1 if tf.strings.regex_full_match(img_path,".*/automobile/.*") else 0 # 文件夹automobile下的label为1,否则为0     img = tf.io.read_file(img_path)     img = tf.image.decode_jpeg(img)     img = tf.image.resize(img,size)     return(img,label)# 单进程ds_map = ds.map(load_image)for _ in ds_map:     opera# 多进程ds_map_parallel = ds.map(load_image, num_parallel_calls = tf.data.experimental.AUTOTUNE)for _ in ds_map_parallel :     opera
      • 使用cache方法让数据在第一个epoch后缓存到内存中,仅限于数据集不大的情况

        ds = tf.data.Dataset.from_generator(generator,output_types = (tf.int32)).cache()
      • 使用map转换时,先batch,然后采用向量化的转换方法对每个batch进行转换

        ds = tf.data.Dataset.range(100000)# 先map后batchds_map_batch = ds.map(lambda x: x ** 2).batch(20)for x in ds_map_batch :     opera     # 先batch后mapds_batch_map = ds.batch(20).map(lambda x: x ** 2)for x in ds_batch_map :     opera
    • 参数迭代:依赖GPU来提升

数据转换

  • map:同Python的map,将转化函数映射到数据集每一个元素
  • flat_map:映射后将多维压平成一维
  • interleave:类似flat_map,但可以将不同来源的数据夹在一起
  • filter:过滤某些元素
  • zip:横向铰合
  • concatenate:纵向铰合
  • reduce:归并
  • batch:构建批次,每次一批。逆操作unbatch
  • padded_batch:构建批次,类似batch,但可以填充到相同的形状
  • window:滑动窗口
  • shuffle:同np.shuffle
  • repeat:重复数据若干次
  • shard:采样,从某个位置开始隔固定距离采样一个元素
  • take:采样,类似top(n), head(n)

转载地址:https://blog.csdn.net/fish2009122/article/details/106238114 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:神经网络常见激活函数(包含tensorflow2的api)
下一篇:tensorflow2之数学运算

发表评论

最新留言

能坚持,总会有不一样的收获!
[***.219.124.196]2024年04月20日 22时26分16秒