注意力机制 YOLOv8添加注意力机制_yolov8引入注意力机制-程序员宅基地

技术标签: YOLO  python  yolov8  人工智能  开发语言  

一、注意力机制介绍:

注意力机制(Attention Mechanism)是深度学习中一种重要的技术,它可以帮助模型更好地关注输入数据中的关键信息,从而提高模型的性能。注意力机制最早在自然语言处理领域的序列到序列(seq2seq)模型中得到广泛应用,后来逐渐扩展到了计算机视觉、语音识别等多个领域。 

注意力机制的基本思想是为输入数据的每个部分分配一个权重,这个权重表示该部分对于当前任务的重要程度。在自然语言处理任务中,这通常意味着对输入句子中的每个单词分配一个权重,而在计算机视觉任务中,这可能意味着为输入图像的每个像素或区域分配一个权重。

添加方法

总结:1.在conv.py加入注意力代码

           2.在__init.oy__和tasks.py引用GAM

           3.修改yaml文件

1.在conv.py代码中加入注意力代码

conv.py的路径:ultralytics-main\ultralytics\nn\modules\conv.py 

如图下所示:

在conv.py的最下面添加注意力代码:

代码如下:

#-----------注意力机制代码-----------------
import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels,c2, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

效果如图下所示:

 

 2.注册及引用GAM_Attention

__init__.py文件中引用GAM_Attention

路径:ultralytics-main\ultralytics\nn\modules\__init__.py

如图下:

在__init__.py文件中,在导包里面找到from .conv import和__all__,最后面添加GAM_Attention。

如图下所示:

tasks.py 文件中引用GAM_Attention

路径:ultralytics-main\ultralytics\nn\tasks.py

如图下:

在tasks.py文件中,在导包里面找到from ultralytics.nn.modules最后面添加GAM_Attention

如图下所示:

 在tasks.py里写入调用方式

打开tasks.py,Ctrl键+F查找n = 1(有空格)就可以找到添加的位置,如效果图:

        # """**************add Attention***************"""
        elif m in {GAM_Attention}:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

效果如图下所示:

 3.修改自己的yolov8.yaml文件:

路径如下:ultralytics-main\ultralytics\cfg\models\v8\my_yolov8.yaml

如图下所示:

 修改后的代码如下(可以直接复制到自己的yaml里面):

# Ultralytics YOLO , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8-SPPCSPC.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, GAM_Attention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

自己修改可以根据下图(修改后的图)红色箭头是需要修改的:

 

 完成以上就可以训练了

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/2301_79152843/article/details/132555870

智能推荐

【1+X Web前端等级考证 】 | Web前端开发中级理论 (附答案)_1+xweb前端开发中级-程序员宅基地

文章浏览阅读3.4w次,点赞77次,收藏438次。# 前言2020 12月 1+X Web 前端开发中级 模拟题大致就更这么多,我的重心不在这里,就不花太多时间在这里面了。但是,说说1+X Web前端开发等级考证这个证书,总有人跑到网上问:这个证书有没有用? 这个证书含金量高不高?# 关于考不考因为这个是工信部从2019年才开始实施试点的,目前还在各大院校试点中,就目前情况来看,知名度并不是很高,有没有用现在无法一锤定音,看它以后办的怎么样把,软考以前也是慢慢地才知名起来。能考就考吧,据所知,大部分学校报考,基本不用交什么报考费(小部分学校,个别除._1+xweb前端开发中级

Linux安装mysql8.0(官方教程!)-程序员宅基地

文章浏览阅读2.5w次,点赞46次,收藏345次。Linux安装mysql(官方教程!)_linux安装mysql8.0

win10微软账号登陆报错:0x80190001解决方案_微软账户登录0x8019001-程序员宅基地

文章浏览阅读3.5w次,点赞19次,收藏18次。win10微软账号登陆报错:0x80190001解决方案_微软账户登录0x8019001

从微软AzureDevOps看实施基于DevOps全流程软件交付-程序员宅基地

文章浏览阅读1.5k次。Azure DevOpsAzure DevOps 汇集人员、流程和技术,实现软件交付自动化,为用户提供持续的价值。借助 Azure DevOps 解决方案,帮助您全流程构建你的软件产品,它使流程和产品更可靠。Azure DevOps帮助你用敏捷工具计划项目;用Git管理你的代码;..._azure devops approve

Typora中使用LaTeX:多行公式左对齐_typora对齐公式-程序员宅基地

文章浏览阅读1.4w次,点赞7次,收藏40次。Typora中使用LaTeX:多行公式左对齐有时候公式太长,用=号对齐很难看(有的公式左边很长,右边很短),此时难免需要进行"公式左对齐"。所需要的环境还是"align"(或者是align*,不带公式编号)。语法如下:\begin{align*}\label{2} & X(0) = x(0)W_{N}^{0\cdot0} + x(1)W_{N}^{0\cdot1} + \cdots + x(N-1)W_{N}^{0\cdot(N-1)}\\ & X(1) = x(0)W_{N}_typora对齐公式

springboot配置文件加载顺序, java启动参数优先级_nacos默认覆盖本地吗-程序员宅基地

文章浏览阅读1k次。(12)、jar包外面的 Profile-specific application properties (application- {profile} .properties和YAML)(13)、jar包内的 Profile-specific application properties (application-{profile}.properties和YAML)(1)、在您的HOME目录设置的Devtools全局属性(~/.spring-boot-devtools.properties)。_nacos默认覆盖本地吗

随便推点

公钥私钥加解密原理_公钥加解密-程序员宅基地

文章浏览阅读4.5k次。一、文章来由网络安全课花了不少篇幅讲解非对称加密技术,做一个整理。二、基本概念 公开密钥加密,也称为非对称加密(asymmetric cryptography)。在這種密碼學方法中,需要一對金鑰,一個是私人金鑰,另一個則是公開金鑰。这两个密钥是数学相关,用某用户密钥加密后所得的信息,只能用该用户的解密密钥才能解密。如果知道了其中一个,并不能计算出另外一个。因此如果公开了一对密钥中的一个,并不会危_公钥加解密

项目管理-概述_项目过程控制 确保 项目 资金-程序员宅基地

文章浏览阅读119次。项目管理的工具虽多,但要记住一点:所有的工具,只有在对的时间,用在对的地方,才能真正指导实际工作。管理是一门大的学问,理论知识只是基础,更多的需要靠实践,从被管理者到管理者的过程,不仅仅只是一个角色的转变,正所谓不在其位,不谋其政,很多东西,也只有走上管理岗位才能慢慢体会了。_项目过程控制 确保 项目 资金

lua面向对象编程之点号与冒号的差异详细比较-程序员宅基地

文章浏览阅读41次。首先,先来一段在lua创建一个类与对象的代码 1 Class = {} 2 Class.__index = Class 3 4 function Class:new(x,y) 5 local temp = {} 6 setmetatable(temp, Class) 7 temp.x = x 8 temp.y = y 9 return...

百度云虚假下载_虚假新闻:关于公共云的5种常见误解-程序员宅基地

文章浏览阅读212次。百度云虚假下载 In the complex world of IT, there are many misconceptions about migrating to the public cloud. Some of these portray the public cloud as the panacea for every IT issue, whereas others consider..._from diggers to data centres从淘金企业到数据中心

Tesseract图像识别OCR的学习1_tesseract doocr-程序员宅基地

文章浏览阅读1.1k次。领导让做一个识别发票的服务,之前都是写增删改查,完全没接触过图像识别这种高大上的东西,记录一下吧新建一个项目,导入tess4j <dependency> <groupId>net.sourceforge.tess4j</groupId> <artifactId>tess4j&l..._tesseract doocr

推荐文章

热门文章

相关标签