千家信息网

dl4j如何使用遗传神经网络完成手写数字识别

发表于:2024-11-26 作者:千家信息网编辑
千家信息网最后更新 2024年11月26日,今天就跟大家聊聊有关dl4j如何使用遗传神经网络完成手写数字识别,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。实现步骤1.随机初始化若干个智能
千家信息网最后更新 2024年11月26日dl4j如何使用遗传神经网络完成手写数字识别

今天就跟大家聊聊有关dl4j如何使用遗传神经网络完成手写数字识别,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。

实现步骤

  • 1.随机初始化若干个智能体(神经网络),并让智能体识别训练数据,并对识别结果进行排序

  • 2.随机在排序结果中选择一个作为母本,并在比母本识别率更高的智能体中随机选择一个作为父本

  • 3.随机选择母本或父本同位的神经网络超参组成新的智能体

  • 4.按照母本的排序对智能体进行超参调整,排序越靠后调整幅度越大(1%~10%)之间

  • 5.让新的智能体识别训练集并放入排行榜,并移除排行榜最后一位

  • 6.重复2~5过程,让识别率越来越高

这个过程就类似于自然界的优胜劣汰,将神经网络超参看作dna,超参的调整看作dna的突变;当然还可以把拥有不同隐藏层的神经网络看作不同的物种,让竞争过程更加多样化.当然我们这里只讨论一种神经网络的情况

优势: 可以解决很多没有头绪的问题 劣势: 训练效率极低

gitee地址:

https://gitee.com/ichiva/gnn.git

实现步骤 1.进化接口

public interface Evolution {    /**     * 遗传     * @param mDna     * @param fDna     * @return     */    INDArray inheritance(INDArray mDna,INDArray fDna);    /**     * 突变     * @param dna     * @param v     * @param r 突变范围     * @return     */    INDArray mutation(INDArray dna,double v, double r);    /**     * 置换     * @param dna     * @param v     * @return     */    INDArray substitution(INDArray dna,double v);    /**     * 外源     * @param dna     * @param v     * @return     */    INDArray other(INDArray dna,double v);    /**     * DNA 是否同源     * @param mDna     * @param fDna     * @return     */    boolean iSogeny(INDArray mDna, INDArray fDna);}

一个比较通用的实现

public class MnistEvolution implements Evolution {    private static final MnistEvolution instance = new MnistEvolution();    public static MnistEvolution getInstance() {        return instance;    }    @Override    public INDArray inheritance(INDArray mDna, INDArray fDna) {        if(mDna == fDna) return mDna;        long[] mShape = mDna.shape();        if(!iSogeny(mDna,fDna)){            throw new RuntimeException("非同源dna");        }        INDArray nDna = Nd4j.create(mShape);        NdIndexIterator it = new NdIndexIterator(mShape);        while (it.hasNext()){            long[] next = it.next();            doubleval;            if(Math.random() > 0.5){                val = fDna.getDouble(next);            }else {                val = mDna.getDouble(next);            }            nDna.putScalar(next,val);        }        return nDna;    }    @Override    public INDArray mutation(INDArray dna, double v, double r) {        long[] shape = dna.shape();        INDArray nDna = Nd4j.create(shape);        NdIndexIterator it = new NdIndexIterator(shape);        while (it.hasNext()) {            long[] next = it.next();            if(Math.random() < v){                dna.putScalar(next,dna.getDouble(next) + ((Math.random() - 0.5) * r * 2));            }else {                nDna.putScalar(next,dna.getDouble(next));            }        }        return nDna;    }    @Override    public INDArray substitution(INDArray dna, double v) {        long[] shape = dna.shape();        INDArray nDna = Nd4j.create(shape);        NdIndexIterator it = new NdIndexIterator(shape);        while (it.hasNext()) {            long[] next = it.next();            if(Math.random() > v){                long[] tag = new long[shape.length];                for (int i = 0; i < shape.length; i++) {                    tag[i] = (long) (Math.random() * shape[i]);                }                nDna.putScalar(next,dna.getDouble(tag));            }else {                nDna.putScalar(next,dna.getDouble(next));            }        }        return nDna;    }    @Override    public INDArray other(INDArray dna, double v) {        long[] shape = dna.shape();        INDArray nDna = Nd4j.create(shape);        NdIndexIterator it = new NdIndexIterator(shape);        while (it.hasNext()) {            long[] next = it.next();            if(Math.random() > v){                nDna.putScalar(next,Math.random());            }else {                nDna.putScalar(next,dna.getDouble(next));            }        }        return nDna;    }    @Override    public boolean iSogeny(INDArray mDna, INDArray fDna) {        long[] mShape = mDna.shape();        long[] fShape = fDna.shape();        if (mShape.length == fShape.length) {            for (int i = 0; i < mShape.length; i++) {                if (mShape[i] != fShape[i]) {                    return false;                }            }            return true;        }        return false;    }}

定义智能体配置接口

public interface AgentConfig {    /**     * 输入量     * @return     */    int getInput();    /**     * 输出量     * @return     */    int getOutput();    /**     * 神经网络配置     * @return     */    MultiLayerConfiguration getMultiLayerConfiguration();}

按手写数字识别进行配置实现

public class MnistConfig implements AgentConfig {    @Override    public int getInput() {        return 28 * 28;    }    @Override    public int getOutput() {        return 10;    }    @Override    public MultiLayerConfiguration getMultiLayerConfiguration() {        return new NeuralNetConfiguration.Builder()                .seed((long) (Math.random() * Long.MAX_VALUE))                .updater(new Nesterovs(0.006, 0.9))                .l2(1e-4)                .list()                .layer(0, new DenseLayer.Builder()                        .nIn(getInput())                        .nOut(1000)                        .activation(Activation.RELU)                        .weightInit(WeightInit.XAVIER)                        .build())                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer                        .nIn(1000)                        .nOut(getOutput())                        .activation(Activation.SOFTMAX)                        .weightInit(WeightInit.XAVIER)                        .build())                .pretrain(false).backprop(true)                .build();    }}

智能体基类

@Getterpublic class Agent {    private final AgentConfig config;    private final INDArray dna;    private final MultiLayerNetwork multiLayerNetwork;    /**     * 采用默认方法初始化参数     * @param config     */    public Agent(AgentConfig config){        this(config,null);    }    /**     *     * @param config     * @param dna     */    public Agent(AgentConfig config, INDArray dna){        if(dna == null){            this.config = config;            MultiLayerConfiguration conf = config.getMultiLayerConfiguration();            this.multiLayerNetwork = new MultiLayerNetwork(conf);            multiLayerNetwork.init();            this.dna = multiLayerNetwork.params();        }else {            this.config = config;            MultiLayerConfiguration conf = config.getMultiLayerConfiguration();            this.multiLayerNetwork = new MultiLayerNetwork(conf);            multiLayerNetwork.init(dna,true);            this.dna = dna;        }    }}

手写数字智能体实现类

@Getter@Setterpublic class MnistAgent extends Agent {    private static final AtomicInteger index = new AtomicInteger(0);    private String name;    /**     * 环境适应分数     */    private double score;    /**     * 验证分数     */    private double validScore;    public MnistAgent(AgentConfig config) {        this(config,null);    }    public MnistAgent(AgentConfig config, INDArray dna) {        super(config, dna);        name = "agent-" + index.incrementAndGet();    }    public static MnistConfig mnistConfig = new MnistConfig();    public static MnistAgent newInstance(){        return new MnistAgent(mnistConfig);    }    public static MnistAgent create(INDArray dna){        return new MnistAgent(mnistConfig,dna);    }}

手写数字识别环境构建

@Slf4jpublic class MnistEnv {    /**     * 环境数据     */    private static final ThreadLocal tLocal = ThreadLocal.withInitial(() -> {        try {            return new MnistDataSetIterator(128, true, 0);        } catch (IOException e) {            throw new RuntimeException("mnist 文件读取失败");        }    });    private static final ThreadLocal testLocal = ThreadLocal.withInitial(() -> {        try {            return new MnistDataSetIterator(128, false, 0);        } catch (IOException e) {            throw new RuntimeException("mnist 文件读取失败");        }    });    private static final MnistEvolution evolution = MnistEvolution.getInstance();    /**     * 环境承载上限     *     * 超过上限AI会进行激烈竞争     */    private final int max;    private Double maxScore,minScore;    /**     * 环境中的生命体     *     * 新生代与历史代共同排序,选出最适应环境的个体     */    //2个变量,一个队列保存KEY的顺序,一个MAP保存KEY对应的具体对象的数据  线程安全map    private final TreeMap lives = new TreeMap<>();    /**     * 初始化环境     *     * 1.向环境中初始化ai     * 2.将初始化ai进行环境适应性测试,并排序     * @param max     */    public MnistEnv(int max){        this.max = max;        for (int i = 0; i < max; i++) {            MnistAgent agent = MnistAgent.newInstance();            test(agent);            synchronized (lives) {                lives.put(agent.getScore(),agent);            }            log.info("初始化智能体 name = {} , score = {}",i,agent.getScore());        }        synchronized (lives) {            minScore = lives.firstKey();            maxScore = lives.lastKey();        }    }    /**     * 环境适应性评估     * @param ai     */    public void test(MnistAgent ai){        MultiLayerNetwork network = ai.getMultiLayerNetwork();        MnistDataSetIterator dataIterator = tLocal.get();        Evaluation eval = new Evaluation(ai.getConfig().getOutput());        try {            while (dataIterator.hasNext()) {                DataSet data = dataIterator.next();                INDArray output = network.output(data.getFeatures(), false);                eval.eval(data.getLabels(),output);            }        }finally {            dataIterator.reset();        }        ai.setScore(eval.accuracy());    }    /**     * 迁移评估     *     * @param ai     */    public void validation(MnistAgent ai){        MultiLayerNetwork network = ai.getMultiLayerNetwork();        MnistDataSetIterator dataIterator = testLocal.get();        Evaluation eval = new Evaluation(ai.getConfig().getOutput());        try {            while (dataIterator.hasNext()) {                DataSet data = dataIterator.next();                INDArray output = network.output(data.getFeatures(), false);                eval.eval(data.getLabels(),output);            }        }finally {            dataIterator.reset();        }        ai.setValidScore(eval.accuracy());    }    /**     * 进化     *     * 每轮随机创建ai并放入环境中进行优胜劣汰     * @param n 进化次数     */    public void evolution(int n){        BlockThreadPool blockThreadPool=new BlockThreadPool(2);        for (int i = 0; i < n; i++) {            blockThreadPool.execute(() -> contend(newLive()));        }//        for (int i = 0; i < n; i++) {//            contend(newLive());//        }    }    /**     * 竞争     * @param ai     */    public void contend(MnistAgent ai){        test(ai);        quality(ai);        double score = ai.getScore();        if(score <= minScore){            UI.put("无法生存",String.format("name = %s,  score = %s", ai.getName(),ai.getScore()));            return;        }        Map.Entry lastEntry;        synchronized (lives) {            lives.put(score,ai);            if (lives.size() > max) {                MnistAgent lastAI = lives.remove(lives.firstKey());                UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore()));            }            lastEntry = lives.lastEntry();            minScore = lives.firstKey();        }        Double lastScore = lastEntry.getKey();        if(lastScore > maxScore){            maxScore = lastScore;            MnistAgent agent = lastEntry.getValue();            validation(agent);            UI.put("max验证",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore()));            try {                Warehouse.write(agent);            } catch (IOException ex) {                log.error("保存对象失败",ex);            }        }    }    ArrayList scoreList = new ArrayList<>(100);    ArrayList avgList = new ArrayList<>();    private void quality(MnistAgent ai) {        synchronized (scoreList) {            scoreList.add(ai.getScore());            if (scoreList.size() >= 100) {                double avg = scoreList.stream().mapToDouble(e -> e)                        .average().getAsDouble();                avgList.add((int) (avg * 1000));                StringBuffer buffer = new StringBuffer();                avgList.forEach(e -> buffer.append(e).append('\t'));                UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString()));                scoreList.clear();            }        }    }    /**     * 随机生成新智能体     *     * 完全随机产生母本     * 随机从比目标相同或更高评分中选择父本     *     * 基因进化在1%~10%之间进行,评分越高基于越稳定     */    public MnistAgent newLive(){        double r = Math.random();        //基因突变率        double v = r / 11 + 0.01;        //母本        MnistAgent mAgent = getMother(r);        //父本        MnistAgent fAgent = getFather(r);        int i = (int) (Math.random() * 3);        INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());        switch (i){            case 0:                newDNA = evolution.other(newDNA,v);                break;            case 1:                newDNA = evolution.mutation(newDNA,v,0.1);                break;            case 2:                newDNA = evolution.substitution(newDNA,v);                break;        }        return MnistAgent.create(newDNA);    }    /**     * 父本只选择比母本评分高的样本     * @param r     * @return     */    private MnistAgent getFather(double r) {        r += (Math.random() * (1-r));        return getMother(r);    }    private MnistAgent getMother(double r) {        int index = (int) (r * max);        return getMnistAgent(index);    }    private MnistAgent getMnistAgent(int index) {        synchronized (lives) {            Iterator> it = lives.entrySet().iterator();            for (int i = 0; i < index; i++) {                it.next();            }            return it.next().getValue();        }    }}

主函数

@Slf4jpublic class Program {    public static void main(String[] args) {        UI.put("开始时间",new Date().toLocaleString());        MnistEnv env = new MnistEnv(128);        env.evolution(Integer.MAX_VALUE);    }}

运行截图

看完上述内容,你们对dl4j如何使用遗传神经网络完成手写数字识别有进一步的了解吗?如果还想了解更多知识或者相关内容,请关注行业资讯频道,感谢大家的支持。

0