pytorch实现bnn_FPGA硅农的博客-程序员ITS304

技术标签: python  

链接

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Function

# ********************* 二值(+-1) ***********************
# A
class Binary_a(Function):

    @staticmethod
    def forward(self, input):
        self.save_for_backward(input)
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        input, = self.saved_tensors
        #*******************ste*********************
        grad_input = grad_output.clone()
        #****************saturate_ste***************
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input
# W
class Binary_w(Function):

    @staticmethod
    def forward(self, input):
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input
# ********************* 三值(+-1、0) ***********************
class Ternary(Function):

    @staticmethod
    def forward(self, input):
        # **************** channel级 - E(|W|) ****************
        E = torch.mean(torch.abs(input), (3, 2, 1), keepdim=True)
        # **************** 阈值 ****************
        threshold = E * 0.7
        # ************** W —— +-1、0 **************
        output = torch.sign(torch.add(torch.sign(torch.add(input, threshold)),torch.sign(torch.add(input, -threshold))))
        return output, threshold

    @staticmethod
    def backward(self, grad_output, grad_threshold):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input

# ********************* A(特征)量化(二值) ***********************
class activation_bin(nn.Module):
  def __init__(self, A):
    super().__init__()
    self.A = A
    self.relu = nn.ReLU(inplace=True)

  def binary(self, input):
    output = Binary_a.apply(input)
    return output

  def forward(self, input):
    if self.A == 2:
      output = self.binary(input)
      # ******************** A —— 1、0 *********************
      #a = torch.clamp(a, min=0)
    else:
      output = self.relu(input)
    return output
# ********************* W(模型参数)量化(三/二值) ***********************
def meancenter_clampConvParams(w):
    mean = w.data.mean(1, keepdim=True)
    w.data.sub(mean) # W中心化(C方向)
    w.data.clamp(-1.0, 1.0) # W截断
    return w
class weight_tnn_bin(nn.Module):
  def __init__(self, W):
    super().__init__()
    self.W = W

  def binary(self, input):
    output = Binary_w.apply(input)
    return output

  def ternary(self, input):
    output = Ternary.apply(input)
    return output

  def forward(self, input):
    if self.W == 2 or self.W == 3:
        # **************************************** W二值 *****************************************
        if self.W == 2:
            output = meancenter_clampConvParams(input) # W中心化+截断
            # **************** channel级 - E(|W|) ****************
            E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)
            # **************** α(缩放因子) ****************
            alpha = E
            # ************** W —— +-1 **************
            output = self.binary(output)
            # ************** W * α **************
            output = output * alpha # 若不需要α(缩放因子),注释掉即可
            # **************************************** W三值 *****************************************
        elif self.W == 3:
            output_fp = input.clone()
            # ************** W —— +-1、0 **************
            output, threshold = self.ternary(input)
            # **************** α(缩放因子) ****************
            output_abs = torch.abs(output_fp)
            mask_le = output_abs.le(threshold)
            mask_gt = output_abs.gt(threshold)
            output_abs[mask_le] = 0
            output_abs_th = output_abs.clone()
            output_abs_th_sum = torch.sum(output_abs_th, (3, 2, 1), keepdim=True)
            mask_gt_sum = torch.sum(mask_gt, (3, 2, 1), keepdim=True).float()
            alpha = output_abs_th_sum / mask_gt_sum # α(缩放因子)
            # *************** W * α ****************
            output = output * alpha # 若不需要α(缩放因子),注释掉即可
    else:
      output = input
    return output

# ********************* 量化卷积(同时量化A/W,并做卷积) ***********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        A=2,
        W=2
      ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # 实例化调用A和W量化器
        self.activation_quantizer = activation_bin(A=A)
        self.weight_quantizer = weight_tnn_bin(W=W)
          
    def forward(self, input):
        # 量化A和W
        bin_input = self.activation_quantizer(input)
        tnn_bin_weight = self.weight_quantizer(self.weight)    
        #print(bin_input)
        #print(tnn_bin_weight)
        # 用量化后的A和W做卷积
        output = F.conv2d(
            input=bin_input, 
            weight=tnn_bin_weight, 
            bias=self.bias, 
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation, 
            groups=self.groups)
        return output

# *********************量化(三值、二值)卷积*********************
class Tnn_Bin_Conv2d(nn.Module):
    # 参数:last_relu-尾层卷积输入激活
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):
        super(Tnn_Bin_Conv2d, self).__init__()
        self.A = A
        self.W = W
        self.last_relu = last_relu

        # ********************* 量化(三/二值)卷积 *********************
        self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.tnn_bin_conv(x)
        x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self, cfg = None, A=2, W=2):
        super(Net, self).__init__()
        # 模型结构与搭建
        if cfg is None:
            cfg = [192, 160, 96, 192, 192, 192, 192, 192]
        self.tnn_bin = nn.Sequential(
                nn.Conv2d(3, cfg[0], kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(cfg[0]),
                Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, A=A, W=W),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, A=A, W=W),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                Tnn_Bin_Conv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, last_relu=1, A=A, W=W),
                nn.Conv2d(cfg[7],  10, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(10),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
                )

    def forward(self, x):
        x = self.tnn_bin(x)
        x = x.view(x.size(0), -1)
        return x
import sys
import math
import numpy as np
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import os

device = torch.device('cuda:0')

# 随机种子——训练结果可复现
def setup_seed(seed):
    torch.manual_seed(seed)                                 
    torch.cuda.manual_seed_all(seed)           
    np.random.seed(seed)                       
    torch.backends.cudnn.deterministic = True

# 训练lr调整
def adjust_learning_rate(optimizer, epoch):
    update_list = [10,20,30,40,50]
    if epoch in update_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5
    return

# 模型训练
def train(epoch):
    model.train()

    for batch_idx, (data, target) in enumerate(trainloader):
        # 前向传播
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        loss = criterion(output, target)

        # 反向传播
        optimizer.zero_grad()
        loss.backward() # 求梯度
        optimizer.step() # 参数更新

        # 显示训练集loss(/100个batch)
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR: {}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.data.item(),
                optimizer.param_groups[0]['lr']))
    return

# 模型测试
def test():
    global best_acc
    model.eval()
    test_loss = 0
    average_test_loss = 0
    correct = 0

    for data, target in testloader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        # 前向传播
        output = model(data)
        test_loss += criterion(output, target).data.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # 测试准确率
    acc = 100. * float(correct) / len(testloader.dataset)

    print(acc)

if __name__=='__main__':
    setup_seed(1)#随机种子——训练结果可复现

    # 训练集:随机裁剪 + 水平翻转 + 归一化
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    # 测试集:归一化
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    # 数据加载
    trainset = torchvision.datasets.CIFAR10(root='./data',train = True, download = True, transform = transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) # 训练集数据

    testset = torchvision.datasets.CIFAR10(root='./data',train = False, download = True, transform = transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=2) # 测试集数据

    # cifar10类别
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    print('******Initializing model******')
    # ******************** 在model的量化卷积中同时量化A(特征)和W(模型参数) ************************
    model = Net(A=2, W=2)
    best_acc = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
    

    # cpu、gpu
    model.to(device)
    # 打印模型结构
    print(model)

    # 超参数
    param_dict = dict(model.named_parameters())
    params = []
    for key, value in param_dict.items():
        params += [{
    'params':[value], 'lr': 0.01, 'weight_decay':0.0}]

    # 损失函数
    criterion = nn.CrossEntropyLoss()
    # 优化器
    optimizer = optim.Adam(params, lr=0.01, weight_decay=0.0)

    # 训练模型
    for epoch in range(1, 300):
        adjust_learning_rate(optimizer, epoch)
        train(epoch)
        test()

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_40268672/article/details/108730643

智能推荐

DCMTK开发笔记(一):我的第一个DCMTK demo_CaLMdoWN_的博客-程序员ITS304_dcmtk

实验平台Visual Studio 2010已编译的 DCMTK 3.6.2 Debug x64版本实验步骤在VS2010中新建Visual C++ Win32 控制台应用程序 空项目,命名为DcmtkDemo在源文件中添加新建项 .cpp文件,命名为main.cpp粘贴如下测试代码#include <stdio.h>#include <tchar.h>...

如何使用pycharm在工程中新建venv环境--(venv 三)_绛洞花主敏明的博客-程序员ITS304_pycharm venv

问题:使用pycharm打开从git上下载的项目后,会发现项目实际上中并不存在项目需要的环境,此时,就需要根据项目中的requirement.txt文件新建环境。实现步骤一、首先使用pycharm打开项目,发现项目中不存在venv环境。二、file --> setting --> project --> project interpreter --> 新建环境...

unity 监听build前、build完成后事件_凡情的博客-程序员ITS304

using System.Collections;using System.Collections.Generic;using System.IO;using UnityEditor;using UnityEditor.Build;using UnityEditor.Callbacks;using UnityEngine;// 实现接口的方式public class BuildReport:IPostprocessBuildWithReport,IPreprocessBuildWithRe

Webapi之文件上传_Song_Lynn的博客-程序员ITS304_webapi 文件上传

Webapi之文件上传范例说明:前端:vue.js + element-ui + axios后端:c# webapi先上传存储起来,然后再读取文件仅尝试过在本地调试,未验证服务器前端部分使用element-ui的上传组件// html 直接调用api<el-upload class="upload-demo" ref...

java中Scanner类nextLine()和next()的使用方法和注意事项_羽涵w的博客-程序员ITS304_scanner.nextline

Scanner实现字符串的输入有两种方法,一种是next(),一种nextLine()。next():一定要读取到有效字符后才可以结束输入,对输入有效字符之前遇到的空格键、Tab键或Enter键等结束符,next()方法会自动将其去掉,只有在输入有效字符之后,next()方法才将其后输入的空格键、Tab键或Enter键等视为分隔符或结束符。简单地说:next()查找并返回来自此扫描器的下一个完整标记。完整标记的前后是与分隔模式匹配的输入信息,所以next方法不能得到带空格的字符串。nextLine(

React-native 安装基础篇_srxboys的博客-程序员ITS304_安装reactnative基础库

React-native 安装基础篇 RN官方文档 (0.55): - http://facebook.github.io RN 中文翻译 文档 (0.51): - https://reactnative.cn 推荐博客 ES6 语法学习(阮一峰) - http://es6.ruanyifeng.com以下基于MacOS 一...

随便推点

IP协议介绍_杜某1997的博客-程序员ITS304_ip协议解决什么问题

1、IP协议解决的问题实际的互联网络是错综复杂的,物理设备通过使用IP协议,屏蔽了物理网络之间的差异,网络中的主机使用IP协议连接时,就无需关注网络细节。1、使复杂的实际网络变成一个虚拟互联网络。2、使网络层可以屏蔽掉底层细节,专注于网络层的数据转发。3、解决了在虚拟网络中数据报传输路径的问题。2、IP数据报在物理层中传输的数据是比特流,在数据链路层中将数据封装成帧,在网络层中将帧数据表示成IP数据报。IP数据报分为IP首部和IP数据报的数据。其中IP首部是重点学习内容。4位版本号:指I

全国地区对应身份证号码值关系----身份证号前6位_花言简的博客-程序员ITS304

省份码值 省份 城市码值 城市 县级 县级 11 北京市 1100 北京市 110000 北京市 11 北京市 1101 北京市市辖区 110101 东城区 11 北京市 1101 北京市市辖区 110102 西城区 11 北京市 1101 北京...

C# WinAPI 编程详解(一)_yang_B621的博客-程序员ITS304_c# winapi

C# WIN32 API编程最近要实现一个微信/QQ自动定时发送推送的小工具 ,用到API编程,下面一起开始学习Win32 API编程吧!!!C# 用户经常提出2两个问题:“我为什么要另外编写代码来使用内置于Windows中的功能?在框架中为什么没有相应的内容可以让我们直接完成这一任务呢?”当框架小组构建它们的.NET部分时,他们评估了为使.NET程序猿可以使用Win32...

wx:key 详解及其 警告处理_荒--的博客-程序员ITS304

<block wx:for="{{movies}}" wx:for-item="movie">该代码在循环的时候控制台会警告warning, 如果明确知道该列表是静态,或者不必关注其顺序,可以选择忽略不影响使用可以修改如下<block wx:for="{{movies}}" wx:key="movies" wx:for-item="movie">wx:key是用来告诉程序按照某个key去排序这个组件,例如wx:key="Id",此时组件顺序就会按照你arr中Id..

itoa函数使用--c语言[email protected]桐同学的博客-程序员ITS304_c语言itoa函数用法

1.随意输出整数的二进制形式 这个时候我们可以任意打印整数的二进制形式我们如果要想看-1的二进制的话 我们会看到-1的补码#include <stdio.h>#include <stdlib.h>int main(){ char t[50];//注意这个数组的大小要足够包含我们想要的内容 _itoa(9, t, 2);//itoa i to arry i整数转换到数组或字符串(里面包含\0) 这句代码意思是将整数9放到这个数组里面并以二进制形式储存 puts(

推荐文章

热门文章

相关标签