/* *****************************************************************************
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package xyz.wbsite.ai;

import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;

public class LeNetMNIST {

    private static final Logger log = LoggerFactory.getLogger(LeNetMNIST.class);

    public static void main(String[] args) throws Exception {
        int nChannels = 1;  // Number of input channels
        int outputNum = 10; // Number of possible outcomes (digit classes 0-9)
        int batchSize = 64; // Batch size, used for both the training and test iterators
        int nEpochs = 1;    // Number of training epochs
        int seed = 123;     // Random seed, for reproducibility

        /*
         * Create iterators that produce one mini-batch of the given size per iteration
         */
        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        /*
         * Construct the neural network
         */
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .l2(0.0005)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(1e-3))
            .list()
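            /*
             * A sketch of the shape arithmetic for the layers built below, worked
             * out from the kernel and stride settings (for orientation only; the
             * actual values are set automatically by setInputType further down):
             * 28x28x1 input -> 5x5 conv, stride 1 -> 24x24x20
             *   -> 2x2 max pool, stride 2 -> 12x12x20
             *   -> 5x5 conv, stride 1 -> 8x8x50
             *   -> 2x2 max pool, stride 2 -> 4x4x50 = 800 activations
             *   -> dense layer of 500 units -> softmax output over 10 classes
             */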
            .layer(new ConvolutionLayer.Builder(5, 5)
                // nIn and nOut specify depth: nIn here is nChannels and nOut is the number of filters to be applied
                .nIn(nChannels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(new ConvolutionLayer.Builder(5, 5)
                // Note that nIn need not be specified in later layers
                .stride(1, 1)
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(new DenseLayer.Builder()
                .activation(Activation.RELU)
                .nOut(500)
                .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1)) // See note below
            .build();

        /*
         * Regarding the .setInputType(InputType.convolutionalFlat(28,28,1)) line: this does a few things.
         * (a) It adds preprocessors, which handle things like the transition between the
         *     convolutional/subsampling layers and the dense layer.
         * (b) It performs some additional configuration validation.
         * (c) Where necessary, it sets the nIn value (number of input neurons, or input depth in the
         *     case of CNNs) for each layer based on the size of the previous layer (but it won't
         *     override values manually set by the user).
         * InputTypes can be used with other layer types too (RNNs, MLPs etc.), not just CNNs.
         * For normal images (when using ImageRecordReader) use InputType.convolutional(height, width, depth).
         * The MNIST record reader is a special case: it outputs 28x28 pixel grayscale (nChannels=1)
         * images in a "flattened" row vector format (i.e., 1x784 vectors), hence the
         * "convolutionalFlat" input type used here.
         */

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model...");
        // Print the score every 10 iterations and evaluate on the test set at the end of every epoch
        model.setListeners(new ScoreIterationListener(10),
            new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END));
        model.fit(mnistTrain, nEpochs);

        String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "lenetmnist.zip");
        log.info("Saving model to tmp folder: " + path);
        model.save(new File(path), true);

        log.info("****************Example finished********************");
    }
}
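
/*
 * A minimal sketch of how the saved model could later be reloaded for inference;
 * MultiLayerNetwork.load is the counterpart of model.save above (the exact
 * Evaluation import depends on the DL4J version in use):
 *
 *   MultiLayerNetwork restored = MultiLayerNetwork.load(new File(path), true);
 *   Evaluation eval = restored.evaluate(new MnistDataSetIterator(64, false, 12345));
 *   System.out.println(eval.stats());
 */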