Java学习之日撸代码300行(71-80天,BP 神经网络)_每天练习敲300行代码-程序员宅基地

技术标签: 机器学习  

原博文:minfanphd

第71天:BP神经网络基础类 (数据读取与基本结构)

以下资料来自于 《神经网络与深度学习》-邱锡鹏 一书

在这里插入图片描述

每一层的神经元可以接收前一层神经元的信号,并产生信号输出到下一层。第0 层称为输入层,最后一层称
为输出层,其他中间层称为隐藏层。

z ( l ) = W ( l ) a ( l − 1 ) + b ( l ) , \begin{aligned} z^{(l)} = W^{(l)}a^{(l-1)}+b^{(l)}, \end{aligned} z(l)=W(l)a(l1)+b(l),
a ( l ) = f l ( z ( l ) ) . \begin{aligned} a^{(l)} = f_l(z^{(l)}). \end{aligned} a(l)=fl(z(l)).

z ( l ) z^{(l)} z(l) 表示第 l l l 层的净输入,也就是值没有经过激活函数的输入。
a ( l ) a^{(l)} a(l) 则是指的经过激活函数后的输出。
W , b W,b W,b表示网络中所有层的连接权重和偏置。

package MachineLearning.ann;

import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

/**
 * @description:抽象类Ann
 * @learner: Qing Zhang
 * @time: 07
 */
public abstract class GeneralAnn {
    
    //数据集
    Instances dataset;

    //层的数量
    int numLayers;

    //每层的节点数量,如[3, 4, 6, 2]表示输入层有三个节点,隐藏层分别有4个和6个节点,输出层有两个节点,二分类
    int[] layerNumNodes;

    //动量系数(Momentum coefficient)
    public double mobp;

    //学习率
    public double learningRate;

    //随机种子
    Random random = new Random();

    /** 
    * @Description: 构造函数
    * @Param: [paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp]
    * @return: 
    */
    public GeneralAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
    
        try{
    
            FileReader tempReader = new FileReader(paraFileName);
            dataset = new Instances(tempReader);
            dataset.setClassIndex(dataset.numAttributes()-1);
            tempReader.close();
        }catch (Exception ee){
    
            System.out.println("Error occurred while trying to read \'" + paraFileName
                    + "\' in GeneralAnn constructor.\r\n" + ee);
            System.exit(0);
        }
        
        //接收参数
        layerNumNodes = paraLayerNumNodes;
        numLayers = layerNumNodes.length;
        learningRate = paraLearningRate;
        layerNumNodes[0] = dataset.numAttributes() - 1;
        layerNumNodes[numLayers - 1] = dataset.numClasses();
        mobp = paraMobp;
    }

    /**
    * @Description: 前向预测
    * @Param: [paraInput]
    * @return: double[]
    */
    public abstract double[] forward(double[] paraInput);



    /**
    * @Description: 反向传播
    * @Param: [paraTarget]
    * @return: void
    */
    public abstract void backPropagation(double[] paraTarget);


    /**
    * @Description: 训练
    * @Param: []
    * @return: void
    */
    public void train(){
    
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];
        for (int i = 0; i < dataset.numInstances(); i++) {
    
            //填充数据
            for (int j = 0; j < tempInput.length; j++) {
    
                tempInput[j] = dataset.instance(i).value(j);
            }

            //填充类标签
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            //使用该实例训练
            forward(tempInput);
            backPropagation(tempTarget);
        }
    }

    /**
     * @Description: 获取数组的最大值对应的索引
     * @Param: [paraArray]
     * @return: int
     */
    public static int argmax(double[] paraArray) {
    
        int resultIndex = -1;
        double tempMax = -1e10;
        for (int i = 0; i < paraArray.length; i++) {
    
            if (tempMax < paraArray[i]) {
    
                tempMax = paraArray[i];
                resultIndex = i;
            }
        }

        return resultIndex;
    }

    /**
     * @Description: 使用数据集测试
     * @Param: []
     * @return: double
     */
    public double test() {
    
        double[] tempInput = new double[dataset.numAttributes() - 1];

        double tempNumCorrect = 0;
        double[] tempPrediction;
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
    
            //填充数据
            for (int j = 0; j < tempInput.length; j++) {
    
                tempInput[j] = dataset.instance(i).value(j);
            }

            //使用该实例训练
            tempPrediction = forward(tempInput);
            System.out.println("prediction: " + Arrays.toString(tempPrediction));
            tempPredictedClass = argmax(tempPrediction);
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
    
                tempNumCorrect++;
            }
        }

        System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

        return tempNumCorrect / dataset.numInstances();
    }
}

第72天:固定激活函数的BP神经网络 (1. 网络结构理解)

  1. layerNumNodes 表示网络基本结构. 如: [3, 4, 6, 2] 表示:
    a) 输入端口有 3 个,即数据有 3 个条件属性. 如果与实际数据不符, 代码会自动纠正, 见 GeneralAnn.java 54 行.
    b) 输出端口有 2 个, 即数据的决策类别数为 2. 如果与实际数据不符, 代码会自动纠正, 见 GeneralAnn.java 55 行. 对于分类问题, 数据是哪个类别, 对应于输出值最大的端口.
    c) 有两个中间层(也就是隐藏层), 分别为 4 个和 6 个节点.
    纠正的原因主要还是需要跟数据集一致,毕竟这里的参数是人为设置,那么可能会出现错误,因此根据数据集的实际情况做纠正会更加严谨。
  2. layerNodeValues 表示各网络节点的值. 如上例, 网络的节点有 4 层, 即 layerNodeValues.length 为 4. 总结点数为 3 + 4 + 6 + 2 = 15 \mathbf{3 + 4 + 6 + 2 = 15} 3+4+6+2=15 个, 即 layerNodeValues[0].length = 3, layerNodeValues[1].length = 4, layerNodeValues[2].length = 6, layerNodeValues[3].length = 2. Java 支持这种不规则的矩阵 (不同行的列数不同), 因为二维矩阵被当作一维向量的一维向量.
  3. layerNodeErrors 表示各网络节点上的误差. 该数组大小于 layerNodeValues 一致.
  4. edgeWeights 表示各条边的权重. 由于两层之间的边为多对多关系 (二维数组), 多个层的边就成了三维数组. 例如, 上面例子的第 0 层就应该有 ( 3 + 1 ) × 4 = 16 \mathbf{( 3 + 1 ) \times 4 = 16} (3+1)×4=16 条边, 这里 + 1 \mathbf{+1} +1 表示有偏移量 offset. 总共的层数为 4 − 1 = 3 \mathbf{4 − 1 = 3} 41=3 , 即边的层数要比节点层数少 1. 这也是写程序过程中非常容易出错的地方.
  5. edgeWeightsDelta 与 edgeWeights 具有相同大小, 它辅助后者进行调整.

这里需要了解一下相应的优化函数,目前使用的是 momentum 动量法,具体的思想可以移步至这篇帖子
深度学习优化函数详解(4)-- momentum 动量法

下面是核心代码:

package MachineLearning.ann;

import weka.core.Instances;

import java.io.FileReader;

/**
 * @description:
 * @learner: Qing Zhang
 * @time: 07
 */
public class SimpleAnn extends GeneralAnn {
    


    //前向传播过程中每个节点变化的值。第一维表示层,第二维表示节点
    public double[][] layerNodeValues;

    //反向传播过程中每个节点变化的错误。第一维表示层,第二维表示节点
    public double[][] layerNodeErrors;

    //边的权值。第一维表示层,第二维表示该层的节点下标,第三维表示下一层的节点下标
    public double[][][] edgeWeights;

    //边的权值变化值。它的大小与edgeWeights相同
    public double[][][] edgeWeightsDelta;

    /**
     * @Description: 构造函数
     * @Param: [paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp]
     * @return:
     */
    public SimpleAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
    
        super(paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp);

        //层层初始化
        layerNodeValues = new double[numLayers][];
        layerNodeErrors = new double[numLayers][];
        edgeWeights = new double[numLayers - 1][][];
        edgeWeightsDelta = new double[numLayers - 1][][];

        //层内初始化
        for (int l = 0; l < numLayers; l++) {
    
            layerNodeValues[l] = new double[layerNumNodes[l]];
            layerNodeErrors[l] = new double[layerNumNodes[l]];

            //后面初始化边时需要少一层,因为每条边穿过两层
            if (l + 1 == numLayers) {
    
                break;
            }

            //在 layerNumNodes[l] + 1,最后一个为偏移保留。
            edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
            edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];

            for (int i = 0; i < layerNumNodes[l] + 1; i++) {
    
                for (int j = 0; j < layerNumNodes[l + 1]; j++) {
    
                    //初始化权值
                    edgeWeights[l][i][j] = random.nextDouble();
                }
            }
        }
    }


    @Override
    public double[] forward(double[] paraInput) {
    
        //初始化输入层
        for (int i = 0; i < layerNodeValues[0].length; i++) {
    
            layerNodeValues[0][i] = paraInput[i];
        }

        //计算每层的节点值
        double z;
        for (int l = 1; l < numLayers; l++) {
    
            for (int j = 0; j < layerNodeValues[l].length; j++) {
    
                //根据偏置初始化,偏置为 +1
                //这里是先加上偏置
                z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
                //将所有边的加权和给该节点使用
                for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
    
                    z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
                }

                //Sigmoid 激活函数
                //对于其他激活函数,这一行应该更改。
                layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
            }
        }
        return layerNodeValues[numLayers - 1];
    }

    @Override
    public void backPropagation(double[] paraTarget) {
    
        //初始化输出层错误
        int l = numLayers - 1;
        for (int j = 0; j < layerNodeErrors[l].length; j++) {
    
            layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * (paraTarget[j] - layerNodeValues[l][j]);
        }

        //反向传播直到 l==0
        while (l > 0) {
    
            l--;
            //第l层的每个节点
            for (int j = 0; j < layerNumNodes[l]; j++) {
    
                double z = 0.0;
                //下一层的每个节点
                for (int i = 0; i < layerNumNodes[l + 1]; i++) {
    
                    if (l > 0) {
    
                        z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
                    }

                    //调整权值
                    edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i] + learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
                    edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];

                    if (j == layerNumNodes[l] - 1) {
    
                        //调整偏移部分的权值
                        edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
                                + learningRate * layerNodeErrors[l + 1][i];
                        edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
                    }
                }

                //根据Sigmoid的微分记录错误。
                //对于其他激活函数,这一行应该更改。
                layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
            }
        }
    }

    public static void main(String[] args) {
    
        int[] tempLayerNodes = {
    4, 8, 8, 3};
        BPNeuralNetwork tempNetwork = new BPNeuralNetwork("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff", tempLayerNodes, 0.01,
                0.6);

        for (int round = 0; round < 5000; round++) {
    
            tempNetwork.train();
        }

        double tempAccuray = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuray);
    }
}

第73天:固定激活函数的BP神经网络 (2. 训练与测试过程理解)

  1. Forward 就是利用当前网络对一条数据进行预测的过程.
  2. BackPropagation 就是根据误差进行网络权重调节的过程.
  3. 训练的时候需要前向与后向, 测试的时候只需要前向.
  4. 这里只实现了 sigmoid 激活函数, 反向传播时的导数与正向传播时的激活函数相对应. 如果要换激活函数, 需要两个地方同时换.(这里需要重点去理解一下,因为后向结合了优化函数,因此需要根据相应的优化函数以及激活函数去调整代码)

第74天:通用BP神经网络 (1. 集中管理激活函数)

  1. 激活与求导是一个, 前者用于 forward, 后者用于 back-propagation.
  2. 有很多的激活函数, 它们的设计有相应准则, 如分段可导.
  3. 查资料补充几个未实现的激活函数.
  4. 进一步测试.

Sigmoid:
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述


Tanh:
σ ( x ) = 2 1 + e ( − 2 x ) − 1 \sigma(x) = \frac{2}{1+e^{(-2x)}}-1 σ(x)=1+e(2x)21
在这里插入图片描述


Arctan:
σ ( x ) = arctan ⁡ ( x ) \sigma(x) = \arctan(x) σ(x)=arctan(x)
在这里插入图片描述


Elu:
σ ( x ) = { x , x ≥ 0 α ( e x − 1 ) , x < 0 \sigma(x) = \begin{cases} x,x\geq 0\\ \alpha(e^x-1), x<0 \end{cases} σ(x)={ x,x0α(ex1),x<0
在这里插入图片描述


Identity:
σ ( x ) = x \sigma(x) = x σ(x)=x
在这里插入图片描述


Soft Sign:
σ ( x ) = { x 1 + x , x ≥ 0 x 1 − x , x < 0 \sigma(x) = \begin{cases} \frac{x}{1+x},x\geq 0\\ \frac{x}{1-x}, x<0 \end{cases} σ(x)={ 1+xx,x01xx,x<0
在这里插入图片描述


Soft Plus:
σ ( x ) = log ⁡ ( 1 + e x ) \sigma(x) = \log(1+e^x) σ(x)=log(1+ex)
在这里插入图片描述


Relu:
σ ( x ) = { x , x ≥ 0 0 , x < 0 \sigma(x) = \begin{cases} x,x\geq 0\\ 0, x<0 \end{cases} σ(x)={ x,x00,x<0
在这里插入图片描述


Leaky Relu:
σ ( x ) = { x , x ≥ 0 α x , x < 0 \sigma(x) = \begin{cases} x,x\geq 0\\ \alpha x, x<0 \end{cases} σ(x)={ x,x0αx,x<0

图像源码:

from matplotlib import pyplot as plt
import numpy as np
import math


def sigmoid_function(x):
    fz = []
    for num in x:
        fz.append(1 / (1 + math.exp(-num)))
    return fz


def sigmoid_test():
    x = np.arange(-10, 10, 0.01)
    fz = sigmoid_function(x)
    show_graph('Sigmoid Function', 'x', 'σ(x)', x, fz)


def tanh_function(x):
    fz = []
    for num in x:
        fz.append(2 / (1 + math.exp(-2 * num)) - 1)
    return fz

def tanh_test():
    x = np.arange(-10, 10, 0.01)
    fz = tanh_function(x)
    show_graph('Tanh Function', 'x', 'σ(x)', x, fz)


def arctan_function(x):
    fz = []
    for num in x:
        fz.append(math.atan(num))
    return fz


def arctan_test():
    x = np.arange(-50, 50, 0.01)
    fz = arctan_function(x)
    show_graph('Arctan Function', 'x', 'σ(x)', x, fz)


def elu_function(x, alpha):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(alpha * (math.exp(num) - 1))
    return fz


def elu_test():
    x = np.arange(-50, 50, 0.01)
    fz = elu_function(x, 0.5)
    show_graph('Elu Function', 'x', 'σ(x)', x, fz)


def identity_function(x):
    fz = []
    for num in x:
        fz.append(num)
    return fz


def identity_test():
    x = np.arange(-10, 10, 0.01)
    fz = identity_function(x)
    show_graph('Identity Function', 'x', 'σ(x)', x, fz)


def leakyRelu_function(x, alpha):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(alpha * num)
    return fz


def leakyRelu_test():
    x = np.arange(-10, 10, 0.01)
    alpha = 0.5
    fz = leakyRelu_function(x, alpha)
    show_graph('Leaky Relu Function', 'x', 'σ(x)', x, fz)


def softSign_function(x):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num / (1 + num))
        else:
            fz.append(num / (1 - num))
    return fz


def softSign_test():
    x = np.arange(-10, 10, 0.01)
    fz = softSign_function(x)
    show_graph('Soft Sign Function', 'x', 'σ(x)', x, fz)


def softPlus_function(x):
    fz = []
    for num in x:
        fz.append(math.log(1 + math.exp(num)))
    return fz


def softPlus_test():
    x = np.arange(-10, 10, 0.01)
    fz = softPlus_function(x)
    show_graph('Soft Plus Function', 'x', 'σ(x)', x, fz)


def relu_function(x):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(0)
    return fz


def relu_test():
    x = np.arange(-10, 10, 0.01)
    fz = relu_function(x)
    show_graph('Relu Function', 'x', 'σ(x)', x, fz)


def show_graph(title, xlable, ylable, x, fz):
    plt.title(title)
    plt.xlabel(xlable)
    plt.ylabel(ylable)
    plt.plot(x, fz)
    plt.show()


if __name__ == '__main__':
    sigmoid_test()
    tanh_test()
    arctan_test()
    elu_test()
    identity_test()
    softSign_test()
    softPlus_test()
    relu_test()
    leakyRelu_test()


package MachineLearning.ann;

/**
 * @description:激活函数
 * @learner: Qing Zhang
 * @time: 07
 */
public class Activator {
    
    // Arc tan.
    public final char ARC_TAN = 'a';

    // Elu.
    public final char ELU = 'e';

    // Gelu.
    public final char GELU = 'g';

    // Hard logistic.
    public final char HARD_LOGISTIC = 'h';

    // Identity.
    public final char IDENTITY = 'i';

    // Leaky relu, also known as parametric relu.
    public final char LEAKY_RELU = 'l';

    // Relu.
    public final char RELU = 'r';

    // Soft sign.
    public final char SOFT_SIGN = 'o';

    // Sigmoid.
    public final char SIGMOID = 's';

    // Tanh.
    public final char TANH = 't';

    // Soft plus.
    public final char SOFT_PLUS = 'u';

    // Swish.
    public final char SWISH = 'w';

    // The activator.
    private char activator;

    // Alpha for elu.
    double alpha;

    // Beta for leaky relu.
    double beta;

    // Gamma for leaky relu.
    double gamma;

    /**
    * @Description: 构造函数
    * @Param: [paraActivator]
    * @return:
    */
    public Activator(char paraActivator) {
    
        activator = paraActivator;
    }

    /**
    * @Description: 设置
    * @Param: [paraActivator]
    * @return: void
    */
    public void setActivator(char paraActivator) {
    
        activator = paraActivator;
    }

    /**
    * @Description: 获取
    * @Param: []
    * @return: char
    */
    public char getActivator() {
    
        return activator;
    }

    /**
    * @Description: 设置α
    * @Param: [paraAlpha]
    * @return: void
    */
    void setAlpha(double paraAlpha) {
    
        alpha = paraAlpha;
    }// Of setAlpha

    /**
    * @Description: 设置β
    * @Param: [paraBeta]
    * @return: void
    */
    void setBeta(double paraBeta) {
    
        beta = paraBeta;
    }

    /**
    * @Description: 设置γ
    * @Param: [paraGamma]
    * @return: void
    */
    void setGamma(double paraGamma) {
    
        gamma = paraGamma;
    }

    /**
    * @Description: 根据设置的激活函数激活
    * @Param: [paraValue]
    * @return: double
    */
    public double activate(double paraValue) {
    
        double resultValue = 0;
        switch (activator) {
    
            case ARC_TAN:
                resultValue = Math.atan(paraValue);
                break;
            case ELU:
                if (paraValue >= 0) {
    
                    resultValue = paraValue;
                } else {
    
                    resultValue = alpha * (Math.exp(paraValue) - 1);
                }
                break;
            // case GELU:
            // resultValue = ?;
            // break;
            // case HARD_LOGISTIC:
            // resultValue = ?;
            // break;
            case IDENTITY:
                resultValue = paraValue;
                break;
            case LEAKY_RELU:
                if (paraValue >= 0) {
    
                    resultValue = paraValue;
                } else {
    
                    resultValue = alpha * paraValue;
                }
                break;
            case SOFT_SIGN:
                if (paraValue >= 0) {
    
                    resultValue = paraValue / (1 + paraValue);
                } else {
    
                    resultValue = paraValue / (1 - paraValue);
                }
                break;
            case SOFT_PLUS:
                resultValue = Math.log(1 + Math.exp(paraValue));
                break;
            case RELU:
                if (paraValue >= 0) {
    
                    resultValue = paraValue;
                } else {
    
                    resultValue = 0;
                }
                break;
            case SIGMOID:
                resultValue = 1 / (1 + Math.exp(-paraValue));
                break;
            case TANH:
                resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
                break;
            // case SWISH:
            // resultValue = ?;
            // break;
            default:
                System.out.println("Unsupported activator: " + activator);
                System.exit(0);
        }

        return resultValue;
    }

    /**
    * @Description: 根据激活函数求导。有些使用x,有些使用f(x)
    * @Param: [paraValue:x, paraActivatedValue:f(x)]
    * @return: double
    */
    public double derive(double paraValue, double paraActivatedValue) {
    
        double resultValue = 0;
        switch (activator) {
    
            case ARC_TAN:
                resultValue = 1 / (paraValue * paraValue + 1);
                break;
            case ELU:
                if (paraValue >= 0) {
    
                    resultValue = 1;
                } else {
    
                    resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
                }
                break;
            // case GELU:
            // resultValue = ?;
            // break;
            // case HARD_LOGISTIC:
            // resultValue = ?;
            // break;
            case IDENTITY:
                resultValue = 1;
                break;
            case LEAKY_RELU:
                if (paraValue >= 0) {
    
                    resultValue = 1;
                } else {
    
                    resultValue = alpha;
                }
                break;
            case SOFT_SIGN:
                if (paraValue >= 0) {
    
                    resultValue = 1 / (1 + paraValue) / (1 + paraValue);
                } else {
    
                    resultValue = 1 / (1 - paraValue) / (1 - paraValue);
                }
                break;
            case SOFT_PLUS:
                resultValue = 1 / (1 + Math.exp(-paraValue));
                break;
            case RELU: // Updated
                if (paraValue >= 0) {
    
                    resultValue = 1;
                } else {
    
                    resultValue = 0;
                }
                break;
            case SIGMOID: // Updated
                resultValue = paraActivatedValue * (1 - paraActivatedValue);
                break;
            case TANH: // Updated
                resultValue = 1 - paraActivatedValue * paraActivatedValue;
                break;
            // case SWISH:
            // resultValue = ?;
            // break;
            default:
                System.out.println("Unsupported activator: " + activator);
                System.exit(0);
        }

        return resultValue;
    }


    public String toString() {
    
        String resultString = "Activator with function '" + activator + "'";
        resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;

        return resultString;
    }


    public static void main(String[] args) {
    
        Activator tempActivator = new Activator('s');
        double tempValue = 0.6;
        double tempNewValue;
        tempNewValue = tempActivator.activate(tempValue);
        System.out.println("After activation: " + tempNewValue);

        tempNewValue = tempActivator.derive(tempValue, tempNewValue);
        System.out.println("After derive: " + tempNewValue);
    }

}


在这里插入图片描述

第75天:通用BP神经网络 (2. 单层实现)

  1. 仅实现单层 ANN.
  2. 可以有自己的激活函数.
  3. 正向计算输出, 反向计算误差并调整权值.

这里对单层的ANN进行了编码,同时进行测试,可以结合之前创建的 Activator 类调整激活函数。

package MachineLearning.ann;

import java.util.Arrays;
import java.util.Random;

/**
 * @description: Ann层
 * @learner: Qing Zhang
 * @time: 07
 */
public class AnnLayer {
    

    //输入数量
    int numInput;

    //输出数量
    int numOutput;

    //学习率
    double learningRate;

    //动量系数
    double mobp;

    //权值矩阵
    double[][] weights, deltaWeights;

    double[] offset, deltaOffset, errors;

    //输入
    double[] input;

    //输出
    double[] output;

    //激活后的输出
    double[] activatedOutput;

    //输入
    Activator activator;

    //输入
    Random random = new Random();

    public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator, double paraLearningRate, double paraMobp) {
    
        numInput = paraNumInput;
        numOutput = paraNumOutput;
        learningRate = paraLearningRate;
        mobp = paraMobp;

        weights = new double[numInput + 1][numOutput];
        deltaWeights = new double[numInput + 1][numOutput];
        for (int i = 0; i < numInput + 1; i++) {
    
            for (int j = 0; j < numOutput; j++) {
    
                weights[i][j] = random.nextDouble();
            }
        }

        offset = new double[numOutput];
        deltaOffset = new double[numOutput];
        errors = new double[numInput];

        input = new double[numInput];
        output = new double[numOutput];
        activatedOutput = new double[numOutput];

        activator = new Activator(paraActivator);
    }

    /**
     * @Description: 前向预测
     * @Param: [paraInput]
     * @return: double[]
     */
    public double[] forward(double[] paraInput) {
    
        //拷贝数据
        for (int i = 0; i < numInput; i++) {
    
            input[i] = paraInput[i];
        }

        //计算加权和以求得每个输出
        for (int i = 0; i < numOutput; i++) {
    
            output[i] = weights[numInput][i];
            for (int j = 0; j < numInput; j++) {
    
                output[i] += input[j] * weights[j][i];
            }

            activatedOutput[i] = activator.activate(output[i]);
        }

        return activatedOutput;
    }


    /**
     * @Description: 反向传播并改变权值
     * @Param: [paraInput]
     * @return: double[]
     */
    public double[] backPropagation(double[] paraErrors) {
    
        //拷贝数据
        for (int i = 0; i < paraErrors.length; i++) {
    
            paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
        }

        //计算当前的错误
        for (int i = 0; i < numInput; i++) {
    
            errors[i] = 0;
            for (int j = 0; j < numOutput; j++) {
    
                errors[i] += paraErrors[j] * weights[i][j];
                deltaWeights[i][j] = mobp * deltaWeights[i][j] + learningRate * paraErrors[j] * input[i];
                weights[i][j] += deltaWeights[i][j];

                if (i == numInput - 1) {
    
                    //调整偏置
                    deltaOffset[j] = mobp * deltaOffset[j] + learningRate * paraErrors[j];
                    offset[j] += deltaOffset[j];
                }
            }
        }

        return errors;
    }

    /**
     * @Description: 获取最后一层的错误
     * @Param: [paraTarget]
     * @return: double[]
     */
    public double[] getLastLayerErrors(double[] paraTarget) {
    
        double[] resultErrors = new double[numOutput];
        for (int i = 0; i < numOutput; i++) {
    
            resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
        }

        return resultErrors;
    }

    @Override
    public String toString() {
    
        String resultString = "";
        resultString += "Activator: " + activator;
        resultString += "\r\n weights = " + Arrays.deepToString(weights);
        return resultString;
    }

    /**
     * @Description: 单元测试
     * @Param: []
     * @return: void
     */
    public static void unitTest() {
    
        AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
        double[] tempInput = {
    1, 4};

        System.out.println(tempLayer);

        double[] tempOutput = tempLayer.forward(tempInput);
        System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

        double[] tempError = tempLayer.backPropagation(tempOutput);
        System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
    }

    public static void main(String[] args) {
    
        unitTest();
    }
}

在这里插入图片描述

第76天:通用BP神经网络 (3. 综合测试)

  1. 自己尝试其它的激活函数.
package MachineLearning.ann;

/**
 * @description: 完整的神经网络
 * @learner: Qing Zhang
 * @time: 07
 */
public class FullAnn extends GeneralAnn {
    

    AnnLayer[] layers;


    public FullAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp, String paraActivators) {
    
        super(paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp);

        //初始化层
        layers = new AnnLayer[numLayers - 1];
        for (int i = 0; i < layers.length; i++) {
    
            layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate, paraMobp);
        }
    }

    @Override
    public double[] forward(double[] paraInput) {
    
        double[] resultArray = paraInput;
        for (int i = 0; i < numLayers - 1; i++) {
    
            resultArray = layers[i].forward(resultArray);
        }
        return resultArray;
    }

    @Override
    public void backPropagation(double[] paraTarget) {
    
        double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
        for (int i = numLayers - 2; i >= 0; i--) {
    
            tempErrors = layers[i].backPropagation(tempErrors);
        }
    }

    @Override
    public String toString() {
    
        String resultString = "I am a full ANN with " + numLayers + " layers";
        return resultString;
    }

    public static void main(String[] args) {
    
        int[] tempLayerNodes = {
    4, 8, 8, 3};
        FullAnn tempNetwork = new FullAnn("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff", tempLayerNodes, 0.01, 0.6, "sss");

        for (int round = 0; round < 5000; round++) {
    
            tempNetwork.train();
        }

        double tempAccuray = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuray);
        System.out.println("FullAnn ends.");
    }
}

Sigmoid函数:
在这里插入图片描述
SOFT_SIGN:
在这里插入图片描述

SOFT_PLUS:
在这里插入图片描述

RELU:
在这里插入图片描述
LEAKY_RELU:

在这里插入图片描述
ELU:
在这里插入图片描述

第77天:GUI (1. 对话框相关控件)

  1. ApplicationShowdown.java 仅用于退出图形用户界面 GUI.
  2. 只生成了一个静态的实例对象. 构造方法是 private 的, 不允许在该类之外 new. 这是一个有意思的小技巧.
package MachineLearning.gui;

import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowEvent;
import java.awt.event.WindowListener;

/**
 * @description:通过窗口事件或者按钮事件关闭应用程序
 * @learner: Qing Zhang
 * @time: 07
 */
public class ApplicationShutdown implements WindowListener, ActionListener {
    

    //只能存在一个对象
    public static ApplicationShutdown applicationShutdown = new ApplicationShutdown();


    //构造函数是私人的,因为只能存在一个对象,而静态对象已经声明了。
    private ApplicationShutdown() {
    
    }


    //关闭系统
    public void windowClosing(WindowEvent comeInWindowEvent) {
    
        System.exit(0);
    }// Of windowClosing.

    public void windowActivated(WindowEvent comeInWindowEvent) {
    
    }

    public void windowClosed(WindowEvent comeInWindowEvent) {
    
    }

    public void windowDeactivated(WindowEvent comeInWindowEvent) {
    
    }

    public void windowDeiconified(WindowEvent comeInWindowEvent) {
    
    }

    public void windowIconified(WindowEvent comeInWindowEvent) {
    
    }

    public void windowOpened(WindowEvent comeInWindowEvent) {
    
    }


    public void actionPerformed(ActionEvent ee) {
    
        System.exit(0);
    }

}

DialogCloser.java 用于关闭窗口, 而不是整个的 GUI.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;

/**
 * @description:关闭当前窗口
 * @learner: Qing Zhang
 * @time: 07
 */
public class DialogCloser extends WindowAdapter implements ActionListener {
    

    //当前打开的窗口
    private Dialog currentDialog;


    public DialogCloser() {
    
        super();
    }


    public DialogCloser(Dialog paraDialog) {
    
        currentDialog = paraDialog;
    }// Of the second constructor

    /**
    * @Description: 关闭窗口
     * 点击窗口右上角时
    * @Param: [paraWindowEvent]
    * @return: void
    */
    public void windowClosing(WindowEvent paraWindowEvent) {
    
        paraWindowEvent.getWindow().dispose();
    }

    /**
     ***************************
     * Close the dialog while pushing an "OK" or "Cancel" button.
     *
     * @param paraEvent
     *            Not considered.
     ***************************
     */
    /**
    * @Description: 关闭窗口
     * 当点击“OK”或者“Cancel”按钮时
    * @Param: [paraEvent]
    * @return: void
    */
    public void actionPerformed(ActionEvent paraEvent) {
    
        currentDialog.dispose();
    }
}

ErrorDialog.java 用于显示出错信息. 有了 GUI 我们可以不再使用 System.out.println.

package MachineLearning.gui;

import java.awt.*;

/**
 * @description:错误窗口
 * @learner: Qing Zhang
 * @time: 07
 */
public class ErrorDialog extends Dialog {
    

    //Serial uid. 不一定有用
    private static final long serialVersionUID = 124535235L;

    //唯一的错误窗口
    public static ErrorDialog errorDialog = new ErrorDialog();


    //用于显示信息的标签文本
    private TextArea messageTextArea;


    /**
    * @Description: 错误窗口
     * 该窗口与其他窗口一样也只存在一个,这样可以节省内存,
     * 当出现许多错误时,一个错误窗口即可解决
    * @Param: []
    * @return:
    */
    private ErrorDialog() {
    
        //模型窗口
        super(GUICommon.mainFrame, "Error", true);

        //初始化该窗口的内容
        messageTextArea = new TextArea();

        Button okButton = new Button("OK");
        okButton.setSize(20, 10);
        okButton.addActionListener(new DialogCloser(this));
        Panel okPanel = new Panel();
        okPanel.setLayout(new FlowLayout());
        okPanel.add(okButton);

        //添加文本域和按钮
        setLayout(new BorderLayout());
        add(BorderLayout.CENTER, messageTextArea);
        add(BorderLayout.SOUTH, okPanel);

        setLocation(200, 200);
        setSize(500, 200);
        addWindowListener(new DialogCloser());
        setVisible(false);
    }

    /**
    * @Description: 设置信息
    * @Param: [paramMessage]
    * @return: void
    */
    public void setMessageAndShow(String paramMessage) {
    
        messageTextArea.setText(paramMessage);
        setVisible(true);
    }
}

GUICommon.java 存储一些公用变量.

package MachineLearning.gui;

import javax.swing.*;
import java.awt.*;

/**
 * @description:公共变量
 * @learner: Qing Zhang
 * @time: 07
 */
public class GUICommon extends Object {
    

    //仅一个主窗口
    public static Frame mainFrame = null;


    //一个主布局
    public static JTabbedPane mainPane = null;

    //默认数量
    public static int currentProjectNumber = 0;

    //默认文字
    public static final Font MY_FONT = new Font("Times New Romans", Font.PLAIN, 12);

    //默认颜色
    public static final Color MY_COLOR = Color.lightGray;


    /** 
    * @Description: 设置主窗口。这一步骤仅在开始时执行一次
    * @Param: [paraFrame]
    * @return: void
    */
    public static void setFrame(Frame paraFrame) throws Exception {
    
        if (mainFrame == null) {
    
            mainFrame = paraFrame;
        } else {
    
            throw new Exception("The main frame can be set only ONCE!");
        }
    }

    
    
    /** 
    * @Description: 设置主布局。这一步骤仅在开始时执行一次
    * @Param: [paramPane]
    * @return: void
    */
    public static void setPane(JTabbedPane paramPane) throws Exception {
    
        if (mainPane == null) {
    
            mainPane = paramPane;
        } else {
    
            throw new Exception("The main panel can be set only ONCE!");
        }
    }
}

HelpDialog.java 显示帮助信息, 这样, 在主界面点击 Help 按钮时, 就会显示相关参数的说明. 其目的在于提高软件的易用性、可维护性.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.io.RandomAccessFile;

/**
 * @description:帮助框
 * @learner: Qing Zhang
 * @time: 07
 */
public class HelpDialog extends Dialog implements ActionListener {
    
    /**
     * Serial uid. Not quite useful.
     */
    private static final long serialVersionUID = 3869415040299264995L;


    /**
    * @Description: 显示帮助窗口
    * @Param: [paraTitle, paraFilename]
    * @return:
    */
    public HelpDialog(String paraTitle, String paraFilename) {
    
        super(GUICommon.mainFrame, paraTitle, true);
        setBackground(GUICommon.MY_COLOR);

        TextArea displayArea = new TextArea("", 10, 10, TextArea.SCROLLBARS_VERTICAL_ONLY);
        displayArea.setEditable(false);
        String textToDisplay = "";
        try {
    
            RandomAccessFile helpFile = new RandomAccessFile(paraFilename, "r");
            String tempLine = helpFile.readLine();
            while (tempLine != null) {
    
                textToDisplay = textToDisplay + tempLine + "\n";
                tempLine = helpFile.readLine();
            }
            helpFile.close();
        } catch (IOException ee) {
    
            dispose();
            ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
        }
        // 如果需要显示中文就使用这个。用这个方法
        // method.
        // textToDisplay = SimpleTools.GB2312ToUNICODE(textToDisplay);
        displayArea.setText(textToDisplay);
        displayArea.setFont(new Font("Times New Romans", Font.PLAIN, 14));

        Button okButton = new Button("OK");
        okButton.setSize(20, 10);
        okButton.addActionListener(new DialogCloser(this));
        Panel okPanel = new Panel();
        okPanel.setLayout(new FlowLayout());
        okPanel.add(okButton);

        // OK 按钮
        setLayout(new BorderLayout());
        add(BorderLayout.CENTER, displayArea);
        add(BorderLayout.SOUTH, okPanel);

        setLocation(120, 70);
        setSize(500, 400);
        addWindowListener(new DialogCloser());
        setVisible(false);
    }
    
    /** 
    * @Description: 简单的激活使它可视化
    * @Param: [ee]
    * @return: void
    */
    public void actionPerformed(ActionEvent ee) {
    
        setVisible(true);
    }
}

第78天:GUI (2. 数据读取控件)

DoubleField.java 用于接受实型值, 如果不能解释成实型值会报错. 这样可以把用户的低级错误扼杀在摇篮中.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;

/**
 * @description:用于接收double值
 * @learner: Qing Zhang
 * @time: 07
 */
public class DoubleField extends TextField implements FocusListener {
    

    //Serial uid. 不一定有用
    private static final long serialVersionUID = 363634723L;

    //值
    protected double doubleValue;

    //赋予默认值
    public DoubleField() {
    
        this("5.13", 10);
    }// Of the first constructor

    //只指定内容
    public DoubleField(String paraString) {
    
        this(paraString, 10);
    }// Of the second constructor

    //只指定宽
    public DoubleField(int paraWidth) {
    
        this("5.13", paraWidth);
    }// Of the third constructor

    /**
    * @Description: 指定宽和长
    * @Param: [paraString, paraWidth]
    * @return:
    */
    public DoubleField(String paraString, int paraWidth) {
    
        super(paraString, paraWidth);
        addFocusListener(this);
    }

    /**
    * @Description:获得焦点事件
    * @Param: [paraEvent]
    * @return: void
    */
    public void focusGained(FocusEvent paraEvent) {
    
    }

    /**
    * @Description: 执行焦点的监听事件
    * @Param: [paraEvent]
    * @return: void
    */
    public void focusLost(FocusEvent paraEvent) {
    
        try {
    
            doubleValue = Double.parseDouble(getText());
        } catch (Exception ee) {
    
            ErrorDialog.errorDialog
                    .setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
            requestFocus();
        }
    }

    /**
    * @Description: 获取值
    * @Param: []
    * @return: double
    */
    public double getValue() {
    
        try {
    
            doubleValue = Double.parseDouble(getText());
        } catch (Exception ee) {
    
            ErrorDialog.errorDialog
                    .setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
            requestFocus();
        } 
        return doubleValue;
    }
}

IntegerField.java 同理.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;

/**
 * @description: 用于接收int值
 * @learner: Qing Zhang
 * @time: 07
 */
public class IntegerField extends TextField implements FocusListener {
    

    //Serial uid. 不一定有用
    private static final long serialVersionUID = -2462338973265150779L;

    //只指定内容
    public IntegerField() {
    
        this("513");
    }// Of constructor

    /** 
    * @Description: 指定宽和长
    * @Param: [paraString, paraWidth]
    * @return: 
    */
    public IntegerField(String paraString, int paraWidth) {
    
        super(paraString, paraWidth);
        addFocusListener(this);
    }

    //只指定内容
    public IntegerField(String paraString) {
    
        super(paraString);
        addFocusListener(this);
    }

    //只指定宽
    public IntegerField(int paraWidth) {
    
        super(paraWidth);
        setText("513");
        addFocusListener(this);
    }

    /**
     * @Description:获得焦点事件
     * @Param: [paraEvent]
     * @return: void
     */
    public void focusGained(FocusEvent paraEvent) {
    
    }

    /**
     * @Description: 执行焦点的监听事件
     * @Param: [paraEvent]
     * @return: void
     */
    public void focusLost(FocusEvent paraEvent) {
    
        try {
    
            Integer.parseInt(getText());
            // System.out.println(tempInt);
        } catch (Exception ee) {
    
            ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
                    + "\"Not an integer. Please check.");
            requestFocus();
        }
    }

    /**
     * @Description: 获取值
     * @Param: []
     * @return: int
     */
    public int getValue() {
    
        int tempInt = 0;
        try {
    
            tempInt = Integer.parseInt(getText());
        } catch (Exception ee) {
    
            ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
                    + "\" Not an int. Please check.");
            requestFocus();
        }
        return tempInt;
    }
}

FilenameField.java 则需要借助于系统提供的 FileDialog.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
import java.io.File;

/**
 * @description:
 * @learner: Qing Zhang
 * @time: 07
 */
public class FilenameField extends TextField implements ActionListener, FocusListener {
    

    //Serial uid. 不一定有用
    private static final long serialVersionUID = 4572287941606065298L;

    /** 
    * @Description: 初始化
    * @Param: []
    * @return: 
    */
    public FilenameField() {
    
        super();
        setText("");
        addFocusListener(this);
    }

    /** 
    * @Description: 初始化
    * @Param: [paraWidth]
    * @return: 
    */
    public FilenameField(int paraWidth) {
    
        super(paraWidth);
        setText("");
        addFocusListener(this);
    }

    /** 
    * @Description: 初始化
    * @Param: [paraWidth, paraText]
    * @return: 
    */
    public FilenameField(int paraWidth, String paraText) {
    
        super(paraWidth);
        setText(paraText);
        addFocusListener(this);
    }

    /** 
    * @Description: 初始化
    * @Param: [paraText, paraWidth]
    * @return: 
    */
    public FilenameField(String paraText, int paraWidth) {
    
        super(paraWidth);
        setText(paraText);
        addFocusListener(this);
    }

    /** 
    * @Description: 避免null或者空串
    * @Param: [paraText]
    * @return: void
    */
    public void setText(String paraText) {
    
        if (paraText.trim().equals("")) {
    
            super.setText("unspecified");
        } else {
    
            super.setText(paraText.replace('\\', '/'));
        }
    }

    /** 
    * @Description: 执行活动监听
    * @Param: [paraEvent]
    * @return: void
    */
    public void actionPerformed(ActionEvent paraEvent) {
    
        FileDialog tempDialog = new FileDialog(GUICommon.mainFrame,
                "Select a file");
        tempDialog.setVisible(true);
        if (tempDialog.getDirectory() == null) {
    
            setText("");
            return;
        }

        String directoryName = tempDialog.getDirectory();

        String tempFilename = directoryName + tempDialog.getFile();
        //System.out.println("tempFilename = " + tempFilename);

        setText(tempFilename);
    }

    /** 
    * @Description: 执行焦点监听事件
    * @Param: [paraEvent]
    * @return: void
    */
    public void focusGained(FocusEvent paraEvent) {
    
    }

    /**
     * @Description: 执行焦点监听事件
     * @Param: [paraEvent]
     * @return: void
     */
    public void focusLost(FocusEvent paraEvent) {
    
        // System.out.println("Focus lost exists.");
        String tempString = getText();
        if ((tempString.equals("unspecified"))
                || (tempString.equals("")))
            return;
        File tempFile = new File(tempString);
        if (!tempFile.exists()) {
    
            ErrorDialog.errorDialog.setMessageAndShow("File \"" + tempString
                    + "\" not exists. Please check.");
            requestFocus();
            setText("");
        }
    }
}

第79天:GUI (3. 总体布局)

  1. 用了 GridLayout 和 BorderLayout 来组织控件.
  2. 按下 OK 执行 actionPerformed. 前两天已经有类似代码了.
package MachineLearning.gui;

import MachineLearning.ann.FullAnn;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Date;

/**
 * @description:
 * @learner: Qing Zhang
 * @time: 07
 */
public class AnnMain implements ActionListener {
    

    //选择arff文件
    private FilenameField arffFilenameField;


    //设置α
    private DoubleField alphaField;


    //设置β
    private DoubleField betaField;

    //设置γ
    private DoubleField gammaField;


    //每层节点,如 "4, 8, 8, 3".
    private TextField layerNodesField;


    //激活函数的选择,例如 "ssa".
    private TextField activatorField;

    //训练次数
    private IntegerField roundsField;

    //学习率
    private DoubleField learningRateField;

    //mobp
    private DoubleField mobpField;


    //信息区域
    private TextArea messageTextArea;

    /**
    * @Description: 唯一的构造函数
    * @Param: []
    * @return:
    */
    public AnnMain() {
    
        //一个简单的窗口包含对话框
        Frame mainFrame = new Frame();
        mainFrame.setTitle("ANN. [email protected]");
        //顶部:选择arff文件
        arffFilenameField = new FilenameField(30);
        arffFilenameField.setText("d:/data/iris.arff");
        Button browseButton = new Button(" Browse ");
        browseButton.addActionListener(arffFilenameField);

        Panel sourceFilePanel = new Panel();
        sourceFilePanel.add(new Label("The .arff file:"));
        sourceFilePanel.add(arffFilenameField);
        sourceFilePanel.add(browseButton);

        //设置面板
        Panel settingPanel = new Panel();
        settingPanel.setLayout(new GridLayout(3, 6));

        settingPanel.add(new Label("alpha"));
        alphaField = new DoubleField("0.01");
        settingPanel.add(alphaField);

        settingPanel.add(new Label("beta"));
        betaField = new DoubleField("0.02");
        settingPanel.add(betaField);

        settingPanel.add(new Label("gamma"));
        gammaField = new DoubleField("0.03");
        settingPanel.add(gammaField);

        settingPanel.add(new Label("layer nodes"));
        layerNodesField = new TextField("4, 8, 8, 3");
        settingPanel.add(layerNodesField);

        settingPanel.add(new Label("activators"));
        activatorField = new TextField("sss");
        settingPanel.add(activatorField);

        settingPanel.add(new Label("training rounds"));
        roundsField = new IntegerField("5000");
        settingPanel.add(roundsField);

        settingPanel.add(new Label("learning rate"));
        learningRateField = new DoubleField("0.01");
        settingPanel.add(learningRateField);

        settingPanel.add(new Label("mobp"));
        mobpField = new DoubleField("0.5");
        settingPanel.add(mobpField);

        Panel topPanel = new Panel();
        topPanel.setLayout(new BorderLayout());
        topPanel.add(BorderLayout.NORTH, sourceFilePanel);
        topPanel.add(BorderLayout.CENTER, settingPanel);

        messageTextArea = new TextArea(80, 40);

        //底部:ok和exit
        Button okButton = new Button(" OK ");
        okButton.addActionListener(this);
        // DialogCloser dialogCloser = new DialogCloser(this);
        Button exitButton = new Button(" Exit ");
        // cancelButton.addActionListener(dialogCloser);
        exitButton.addActionListener(ApplicationShutdown.applicationShutdown);
        Button helpButton = new Button(" Help ");
        helpButton.setSize(20, 10);
        helpButton.addActionListener(new HelpDialog("ANN", "src/machinelearning/gui/help.txt"));
        Panel okPanel = new Panel();
        okPanel.add(okButton);
        okPanel.add(exitButton);
        okPanel.add(helpButton);

        mainFrame.setLayout(new BorderLayout());
        mainFrame.add(BorderLayout.NORTH, topPanel);
        mainFrame.add(BorderLayout.CENTER, messageTextArea);
        mainFrame.add(BorderLayout.SOUTH, okPanel);

        mainFrame.setSize(600, 500);
        mainFrame.setLocation(100, 100);
        mainFrame.addWindowListener(ApplicationShutdown.applicationShutdown);
        mainFrame.setBackground(GUICommon.MY_COLOR);
        mainFrame.setVisible(true);
    }


    /**
    * @Description: 读入arff文件
    * @Param: [ae]
    * @return: void
    */
    public void actionPerformed(ActionEvent ae) {
    
        String tempFilename = arffFilenameField.getText();

        // Read the layers nodes.
        String tempString = layerNodesField.getText().trim();

        int[] tempLayerNodes = null;
        try {
    
            tempLayerNodes = stringToIntArray(tempString);
        } catch (Exception ee) {
    
            ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
            return;
        }

        double tempLearningRate = learningRateField.getValue();
        double tempMobp = mobpField.getValue();
        String tempActivators = activatorField.getText().trim();
        FullAnn tempNetwork = new FullAnn(tempFilename, tempLayerNodes, tempLearningRate, tempMobp,
                tempActivators);
        int tempRounds = roundsField.getValue();

        long tempStartTime = new Date().getTime();
        for (int i = 0; i < tempRounds; i++) {
    
            tempNetwork.train();
        }
        long tempEndTime = new Date().getTime();
        messageTextArea.append("\r\nSummary:\r\n");
        messageTextArea.append("Trainng time: " + (tempEndTime - tempStartTime) + "ms.\r\n");

        double tempAccuray = tempNetwork.test();
        messageTextArea.append("Accuracy: " + tempAccuray + "\r\n");
        messageTextArea.append("End.");
    }

    /** 
    * @Description: 将带逗号的字符串转换为int数组。
    * @Param: [paraString]
    * @return: int[]
    */
    public static int[] stringToIntArray(String paraString) throws Exception {
    
        int tempCounter = 1;
        for (int i = 0; i < paraString.length(); i++) {
    
            if (paraString.charAt(i) == ',') {
    
                tempCounter++;
            }
        }

        int[] resultArray = new int[tempCounter];

        String tempRemainingString = new String(paraString) + ",";
        String tempString;
        for (int i = 0; i < tempCounter; i++) {
    
            tempString = tempRemainingString.substring(0, tempRemainingString.indexOf(",")).trim();
            if (tempString.equals("")) {
    
                throw new Exception("Blank is unsupported");
            }

            resultArray[i] = Integer.parseInt(tempString);

            tempRemainingString = tempRemainingString
                    .substring(tempRemainingString.indexOf(",") + 1);
        }

        return resultArray;
    }


    public static void main(String args[]) {
    
        new AnnMain();
    }// Of main
}

在这里插入图片描述

第80天:GUI (4. 各种监听机制)

  1. 从监听机制、接口等角度, 分析在 GUI 上的各种操作分别会触发哪些代码;

    由于之前用C#写过winform程序,所以对于GUI上的事件响应,监听机制还是比较熟悉的,这里主要是使用了观察者设计模式,事件源注册事件监听器后,当事件源上发生某个动作时,事件源就会调用事件监听的一个方法,并将事件对象传递进去,开发者可以利用事件对象操作事件源。比如当操作鼠标点击某个部件时,可以将鼠标的点击事件触发,从而传递消息,比如是否点击以及鼠标在窗体上的位置等信息。

  2. 总结基础的人工神经网络.

迭代算法,随机设定参数的初始值,计算当前网络的输出,根据当前输出与样本决策标签的误差再反向传播,改变参数值,不断循环往复直至收敛至某一阈值。

缺点:

  1. 不知道你的神经网络将会如何产出结果,更不知道为什么会产生这种结果。
  2. 比较耗时;
  3. 难以找到大量有标签的数据;
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Z_122113/article/details/119026781

智能推荐

C#连接OPC C#上位机链接PLC程序源码 1.该程序是通讯方式是CSharp通过OPC方式连接PLC_c#opc通信-程序员宅基地

文章浏览阅读565次。本文主要介绍如何使用C#通过OPC方式连接PLC,并提供了相应的程序和学习资料,以便读者学习和使用。OPC服务器是一种软件,可以将PLC的数据转换为标准的OPC格式,允许其他软件通过标准接口读取或控制PLC的数据。此外,本文还提供了一些学习资料,包括OPC和PLC的基础知识,C#编程语言的教程和实例代码。这些资料可以帮助读者更好地理解和应用本文介绍的程序。1.该程序是通讯方式是CSharp通过OPC方式连接PLC,用这种方式连PLC不用考虑什么种类PLC,只要OPC服务器里有的PLC都可以连。_c#opc通信

Hyper-V内的虚拟机复制粘贴_win10 hyper-v ubuntu18.04 文件拷贝-程序员宅基地

文章浏览阅读1.6w次,点赞3次,收藏10次。实践环境物理机:Windows10教育版,操作系统版本 17763.914虚拟机:Ubuntu18.04.3桌面版在Hyper-V中的刚安装好Ubuntu虚拟机之后,会发现鼠标滑动很不顺畅,也不能向虚拟机中拖拽文件或者复制内容。在VMware中,可以通过安装VMware tools来使物理机和虚拟机之间达到更好的交互。在Hyper-V中,也有这样的工具。这款工具可以完成更好的鼠标交互,我的..._win10 hyper-v ubuntu18.04 文件拷贝

java静态变量初始化多线程,持续更新中_类初始化一个静态属性 为线程池-程序员宅基地

文章浏览阅读156次。前言互联网时代,瞬息万变。一个小小的走错,就有可能落后于别人。我们没办法去预测任何行业、任何职业未来十年会怎么样,因为未来谁都不能确定。只能说只要有互联网存在,程序员依然是个高薪热门行业。只要跟随着时代的脚步,学习新的知识。程序员是不可能会消失的,或者说不可能会没钱赚的。我们经常可以听到很多人说,程序员是一个吃青春饭的行当。因为大多数人认为这是一个需要高强度脑力劳动的工种,而30岁、40岁,甚至50岁的程序员身体机能逐渐弱化,家庭琐事缠身,已经不能再进行这样高强度的工作了。那么,这样的说法是对的么?_类初始化一个静态属性 为线程池

idea 配置maven,其实不用单独下载Maven的。以及设置新项目配置,省略每次创建新项目都要配置一次Maven_安装idea后是不是不需要安装maven了?-程序员宅基地

文章浏览阅读1w次,点赞13次,收藏43次。说来也是惭愧,一直以来,在装环境的时候都会从官网下载Maven。然后再在idea里配置Maven。以为从官网下载的Maven是必须的步骤,直到今天才得知,idea有捆绑的 Maven 我们只需要搞一个配置文件就行了无需再官网下载Maven包以后再在新电脑装环境的时候,只需要下载idea ,网上找一个Maven的配置文件 放到 默认的 包下面就可以了!也省得每次创建项目都要重新配一次Maven了。如果不想每次新建项目都要重新配置Maven,一种方法就是使用默认的配置,另一种方法就是配置 .._安装idea后是不是不需要安装maven了?

奶爸奶妈必看给宝宝摄影大全-程序员宅基地

文章浏览阅读45次。家是我们一生中最重要的地方,小时候,我们在这里哭、在这里笑、在这里学习走路,在这里有我们最真实的时光,用相机把它记下吧。  很多家庭在拍摄孩子时有一个看法,认为儿童摄影团购必须是在风景秀丽的户外,即便是室内那也是像大酒店一样...

构建Docker镜像指南,含实战案例_rocker/r-base镜像-程序员宅基地

文章浏览阅读429次。Dockerfile介绍Dockerfile是构建镜像的指令文件,由一组指令组成,文件中每条指令对应linux中一条命令,在执行构建Docker镜像时,将读取Dockerfile中的指令,根据指令来操作生成指定Docker镜像。Dockerfile结构:主要由基础镜像信息、维护者信息、镜像操作指令、容器启动时执行指令。每行支持一条指令,每条指令可以携带多个参数。注释可以使用#开头。指令说明FROM 镜像 : 指定新的镜像所基于的镜像MAINTAINER 名字 : 说明新镜像的维护(制作)人,留下_rocker/r-base镜像

随便推点

毕设基于微信小程序的小区管理系统的设计ssm毕业设计_ssm基于微信小程序的公寓生活管理系统-程序员宅基地

文章浏览阅读223次。该系统将提供便捷的信息发布、物业报修、社区互动等功能,为小区居民提供更加便利、高效的服务。引言: 随着城市化进程的加速,小区管理成为一个日益重要的任务。因此,设计一个基于微信小程序的小区管理系统成为了一项具有挑战性和重要性的毕设课题。本文将介绍该小区管理系统的设计思路和功能,以期为小区提供更便捷、高效的管理手段。四、总结与展望: 通过本次毕设项目,我们实现了一个基于微信小程序的小区管理系统,为小区居民提供了更加便捷、高效的服务。通过该系统的设计与实现,能够提高小区管理水平,提供更好的居住环境和服务。_ssm基于微信小程序的公寓生活管理系统

如何正确的使用Ubuntu以及安装常用的渗透工具集.-程序员宅基地

文章浏览阅读635次。文章来源i春秋入坑Ubuntu半年多了记得一开始学的时候基本一星期重装三四次=-= 尴尬了 觉得自己差不多可以的时候 就吧Windows10干掉了 c盘装Ubuntu 专心学习. 这里主要来说一下使用Ubuntu的正确姿势Ubuntu(友帮拓、优般图、乌班图)是一个以桌面应用为主的开源GNU/Linux操作系统,Ubuntu 是基于DebianGNU/Linux,支..._ubuntu安装攻击工具包

JNI参数传递引用_jni引用byte[]-程序员宅基地

文章浏览阅读335次。需求:C++中将BYTE型数组传递给Java中,考虑到内存释放问题,未采用通过返回值进行数据传递。public class demoClass{public native boolean getData(byte[] tempData);}JNIEXPORT jboolean JNICALL Java_com_core_getData(JNIEnv *env, jobject thisObj, jbyteArray tempData){ //resultsize为s..._jni引用byte[]

三维重建工具——pclpy教程之点云分割_pclpy.pcl.pointcloud.pointxyzi转为numpy-程序员宅基地

文章浏览阅读2.1k次,点赞5次,收藏30次。本教程代码开源:GitHub 欢迎star文章目录一、平面模型分割1. 代码2. 说明3. 运行二、圆柱模型分割1. 代码2. 说明3. 运行三、欧几里得聚类提取1. 代码2. 说明3. 运行四、区域生长分割1. 代码2. 说明3. 运行五、基于最小切割的分割1. 代码2. 说明3. 运行六、使用 ProgressiveMorphologicalFilter 分割地面1. 代码2. 说明3. 运行一、平面模型分割在本教程中,我们将学习如何对一组点进行简单的平面分割,即找到支持平面模型的点云中的所有._pclpy.pcl.pointcloud.pointxyzi转为numpy

以NFS启动方式构建arm-linux仿真运行环境-程序员宅基地

文章浏览阅读141次。一 其实在 skyeye 上移植 arm-linux 并非难事,网上也有不少资料, 只是大都遗漏细节, 以致细微之处卡壳,所以本文力求详实清析, 希望能对大家有点用处。本文旨在将 arm-linux 在 skyeye 上搭建起来,并在 arm-linux 上能成功 mount NFS 为目标, 最终我们能在 arm-linux 里运行我们自己的应用程序. 二 安装 Sky..._nfs启动 arm

攻防世界 Pwn 进阶 第二页_pwn snprintf-程序员宅基地

文章浏览阅读598次,点赞2次,收藏5次。00为了形成一个体系,想将前面学过的一些东西都拉来放在一起总结总结,方便学习,方便记忆。攻防世界 Pwn 新手攻防世界 Pwn 进阶 第一页01 4-ReeHY-main-100超详细的wp1超详细的wp203 format2栈迁移的两种作用之一:栈溢出太小,进行栈迁移从而能够写入更多shellcode,进行更多操作。栈迁移一篇搞定有个陌生的函数。C 库函数 void *memcpy(void *str1, const void *str2, size_t n) 从存储区 str2 _pwn snprintf

推荐文章

热门文章

相关标签