欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 健康 > 养生 > TopK activation function(TopK激活函数)

TopK activation function(TopK激活函数)

2025/4/4 23:55:20 来源:https://blog.csdn.net/u010165147/article/details/139993233  浏览:    关键词:TopK activation function(TopK激活函数)

最近看了一篇关于topk激活函数的文章[1]就顺便实现了一下,测试了一下收敛、运行速度和最后的精度基本和ReLU差别不大,topk激活函数有一个优点就是激活的节点数是确定的,不会产生死区,也可以自由控制特征向量的稀疏程度,相对来说ReLU则不可控。顺便也实现了基于topk的池化,这个池化速度较慢。有需要的可以参考。

import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as Fclass TopKLU(nn.Module):def __init__(self,active_ratio = 0.5):  super(TopKLU, self).__init__()self.active_ratio = active_ratiodef forward(self, x):size = x.size()topk = int(size[1]*self.active_ratio)topk = 1 if topk<1 else topkwith torch.no_grad():z = torch.zeros_like(x)_,indices = torch.topk(x,topk,1,True,False)z.scatter_(1,indices,1)return z*xclass TopKPool2d(nn.Module):def __init__(self,active_ratio = 0.5,kernel_size=3,stride = 1,padding = 0):  super(TopKPool2d, self).__init__()assert kernel_size>=2,"kernel_size should >= 2"assert stride>=1,"stride should >= 1"self.kernel_size=kernel_sizeself.padding = paddingself.stride = strideself.topk = int(active_ratio * kernel_size * kernel_size)self.topk = 1 if self.topk<1 else self.topkdef forward(self, x):size = x.size()with torch.no_grad():col = F.unfold(x,self.kernel_size,dilation=1,padding=self.padding,stride=self.stride)col_t = col.transpose(2,1)col_tr = col_t.reshape(col_t.size(0),col_t.size(1)*size[1],-1)_,indices = torch.topk(col_tr,self.topk,-1,True,False)z = torch.zeros_like(col_tr)z.scatter_(2,indices,1)z = z.reshape(col_t.size())z = z.transpose(2,1)z = F.fold(z,(size[2],size[3]),self.kernel_size,dilation=1,padding=self.padding,stride=self.stride)z = z.gt(0).to(torch.float32)return z * xif __name__=="__main__":act = TopKPool2d(kernel_size=2,stride=2)x = torch.randn(2,2,4,4)print(x)print(act(x))act = TopKLU()x = torch.randn(1,2,3,3)print(x)print(act(x))x = torch.randn(2,4)print(x)print(act(x))

参考

  1. Scaling and evaluating sparse autoencoders

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词