caffe lstm_unit_layer.cu源码解析
发布日期:2021-08-19 11:10:18 浏览次数:4 分类:技术文章

本文共 5117 字,大约阅读时间需要 17 分钟。

源码的解析都写到文件里面如下

#include 
#include
#include
#include "caffe/layer.hpp"#include "caffe/layers/lstm_layer.hpp"namespace caffe {//sigmoid函数 1/(1+e^(-x))template
__device__ Dtype sigmoid(const Dtype x) { return Dtype(1) / (Dtype(1) + exp(-x));}//tanh函数 2*sigmoid(2*x) -1template
__device__ Dtype tanh(const Dtype x) { return Dtype(2) * sigmoid(Dtype(2) * x) - Dtype(1);}//X_acts:经过激活函数之后X的值template
__global__ void LSTMActsForward(const int nthreads, const int dim, const Dtype* X, Dtype* X_acts) { CUDA_KERNEL_LOOP(index, nthreads) { const int x_dim = 4 * dim; const int d = index % x_dim; if (d < 3 * dim) { X_acts[index] = sigmoid(X[index]);//对应于黄色的delta模块 } else { X_acts[index] = tanh(X[index]);//对应于tanh模块 } }}template
__global__ void LSTMUnitForward(const int nthreads, const int dim, const Dtype* C_prev, const Dtype* X, const Dtype* cont, Dtype* C, Dtype* H) { CUDA_KERNEL_LOOP(index, nthreads) { const int n = index / dim; const int d = index % dim; const Dtype* X_offset = X + 4 * dim * n; const Dtype i = X_offset[d]; //i(t) const Dtype f = X_offset[1 * dim + d];//f(t) const Dtype o = X_offset[2 * dim + d];//o(t) const Dtype g = X_offset[3 * dim + d];//c(~t) const Dtype c_prev = C_prev[index]; //C(t-1) const Dtype c = cont[n] * f * c_prev + i * g;//对应于C(t)=f(t)*C(t-1) + i(t)*C(~t) C[index] = c; const Dtype tanh_c = tanh(c); H[index] = o * tanh_c; //对应于 h(t) = o(t)*tanh(C(t)) }}template
void LSTMUnitLayer
::Forward_gpu(const vector
*>& bottom, const vector
*>& top) { const int count = top[1]->count(); const Dtype* C_prev = bottom[0]->gpu_data();//输入C(t-1) const Dtype* X = bottom[1]->gpu_data(); //输入x(t) const Dtype* cont = bottom[2]->gpu_data();//应该是h(t-1) ? Dtype* X_acts = X_acts_.mutable_gpu_data(); Dtype* C = top[0]->mutable_gpu_data();//一个输出C(t) Dtype* H = top[1]->mutable_gpu_data();//另一个输出h(t),相当于return的两个结果 const int X_count = bottom[1]->count(); // NOLINT_NEXT_LINE(whitespace/operators) LSTMActsForward
<<
>>( X_count, hidden_dim_, X, X_acts); CUDA_POST_KERNEL_CHECK; // NOLINT_NEXT_LINE(whitespace/operators) LSTMUnitForward
<<
>>( count, hidden_dim_, C_prev, X_acts, cont, C, H); CUDA_POST_KERNEL_CHECK;}template
__global__ void LSTMUnitBackward(const int nthreads, const int dim, const Dtype* C_prev, const Dtype* X, const Dtype* C, const Dtype* H, const Dtype* cont, const Dtype* C_diff, const Dtype* H_diff, Dtype* C_prev_diff, Dtype* X_diff) { CUDA_KERNEL_LOOP(index, nthreads) { const int n = index / dim; const int d = index % dim; const Dtype* X_offset = X + 4 * dim * n; const Dtype i = X_offset[d]; const Dtype f = X_offset[1 * dim + d]; const Dtype o = X_offset[2 * dim + d]; const Dtype g = X_offset[3 * dim + d]; const Dtype c_prev = C_prev[index]; const Dtype c = C[index]; const Dtype tanh_c = tanh(c); Dtype* c_prev_diff = C_prev_diff + index; Dtype* X_diff_offset = X_diff + 4 * dim * n; Dtype* i_diff = X_diff_offset + d; //相当于fc层的一个输出 Dtype* f_diff = X_diff_offset + 1 * dim + d; Dtype* o_diff = X_diff_offset + 2 * dim + d; Dtype* g_diff = X_diff_offset + 3 * dim + d; const Dtype c_term_diff = C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c); const Dtype cont_n = cont[n]; *c_prev_diff = cont_n * c_term_diff * f;//C(t-1)改变量 *i_diff = c_term_diff * g;//i(t)改变量 *f_diff = cont_n * c_term_diff * c_prev;//f(t)改变量 *o_diff = H_diff[index] * tanh_c;//o(t)改变量 *g_diff = c_term_diff * i; //c(~t)改变量 }}//激活函数部分的反向template
__global__ void LSTMActsBackward(const int nthreads, const int dim, const Dtype* X_acts, const Dtype* X_acts_diff, Dtype* X_diff) { CUDA_KERNEL_LOOP(index, nthreads) { const int x_dim = 4 * dim; const int d = index % x_dim; const Dtype X_act = X_acts[index]; if (d < 3 * dim) { X_diff[index] = X_acts_diff[index] * X_act * (Dtype(1) - X_act); } else { X_diff[index] = X_acts_diff[index] * (Dtype(1) - X_act * X_act); } }}template
void LSTMUnitLayer
::Backward_gpu(const vector
*>& top, const vector
& propagate_down, const vector
*>& bottom) { CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators."; if (!propagate_down[0] && !propagate_down[1]) { return; } const int count = top[1]->count(); const Dtype* C_prev = bottom[0]->gpu_data(); const Dtype* X_acts = X_acts_.gpu_data(); const Dtype* cont = bottom[2]->gpu_data(); const Dtype* C = top[0]->gpu_data(); const Dtype* H = top[1]->gpu_data(); const Dtype* C_diff = top[0]->gpu_diff(); const Dtype* H_diff = top[1]->gpu_diff(); Dtype* C_prev_diff = bottom[0]->mutable_gpu_diff(); Dtype* X_acts_diff = X_acts_.mutable_gpu_diff(); LSTMUnitBackward
// NOLINT_NEXT_LINE(whitespace/operators) <<
>>(count, hidden_dim_, C_prev, X_acts, C, H, cont, C_diff, H_diff, C_prev_diff, X_acts_diff); CUDA_POST_KERNEL_CHECK; const int X_count = bottom[1]->count(); Dtype* X_diff = bottom[1]->mutable_gpu_diff(); LSTMActsBackward
// NOLINT_NEXT_LINE(whitespace/operators) <<
>>( X_count, hidden_dim_, X_acts, X_acts_diff, X_diff); CUDA_POST_KERNEL_CHECK;}INSTANTIATE_LAYER_GPU_FUNCS(LSTMUnitLayer);} // namespace caffe

转载于:https://www.cnblogs.com/hellokittyblog/p/9128459.html

转载地址:https://blog.csdn.net/weixin_30924087/article/details/97031043 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:centos7下kubernetes(6。kubernetes创建资源的两种方式)
下一篇:ADO.Net的小知识(连接数据库)

发表评论

最新留言

第一次来,支持一个
[***.219.124.196]2024年04月01日 22时17分44秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

【C++】攻克哈希表(unordered_map) 2019-04-27
转:【答学员问】- 该如何根据岗位学习相关技能 2019-04-27
转:【答学员问】有什么经验教训,是你在面试很多次之后才知道的? 2019-04-27
消息队列:解耦、异步、削峰,现有MQ对比以及新手入门该如何选择MQ? 2019-04-27
【奇技淫巧】-- 三角形最小路径和 2019-04-27
【小技巧】argc和argv的用法 2019-04-27
学不下去了怎么办? 2019-04-27
二叉树的前中后序遍历(迭代法)(带动画) 2019-04-27
【小技巧】【XShell】【Xftp】Windows桌面与Linux虚拟机互传文件 2019-04-27
【redis入门】Centos下安装redis 2019-04-27
【redis入门】redis安装后相关知识串讲 2019-04-27
【redis】来吧,展示一下redis 发布-订阅模式 2019-04-27
讲通C/C++预编译/条件编译指令 #ifdef,#ifndef,#endif,#define,… 2019-04-27
【redis6.0.6】redis源码慢慢学,慢慢看 -- 第二天:空间配置(zmalloc) 2019-04-27
当下热点词再学:redis缓存预热、更新、降级,限流 2019-04-27
【redis6.0.6】redis源码慢慢学,慢慢看 -- 第五天:adlist 2019-04-27
别抖,OK? 操作系统抖动现象、网络抖动与延迟、函数抖动之防抖与节流,串讲 2019-04-27
第六天:网络处理(anet部分)-- redis源码慢慢学,慢慢看【redis6.0.6】 2019-04-27
通过域名获取主机IP -- struct addrinfo 2019-04-27
【C++】算法集锦(8):从两数和问题拓展到一百数和问题 2019-04-27