package xyz.wbsite.ai;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*
******************************************************************************/
public class LeNetMNIST {
    private static final Logger log = LoggerFactory.getLogger(LeNetMNIST.class);

    public static void main(String[] args) throws Exception {
        int nChannels = 1;  // Number of input channels
        int outputNum = 10; // The number of possible outcomes
        int batchSize = 64; // Test batch size
        int nEpochs = 1;    // Number of training epochs
        int seed = 123;     // Random seed for reproducibility

        /*
         * Create an iterator using the batch size for one iteration
         */
        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        /*
         * Construct the neural network
         */
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .l2(0.0005)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(1e-3))
            .list()
            .layer(new ConvolutionLayer.Builder(5, 5)
                //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                .nIn(nChannels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(new ConvolutionLayer.Builder(5, 5)
                //Note that nIn need not be specified in later layers
                .stride(1, 1)
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below
            .build();

        /*
         * Regarding the .setInputType(InputType.convolutionalFlat(28,28,1)) line: This does a few things.
         * (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling layers
         *     and the dense layer
         * (b) Does some additional configuration validation
         * (c) Where necessary, sets the nIn (number of input neurons, or input depth in the case of CNNs) values for each
         *     layer based on the size of the previous layer (but it won't override values manually set by the user)
         *
         * InputTypes can be used with other layer types too (RNNs, MLPs etc) not just CNNs.
         * For normal images (when using ImageRecordReader) use InputType.convolutional(height,width,depth).
         * MNIST record reader is a special case, that outputs 28x28 pixel grayscale (nChannels=1) images, in a "flattened"
         * row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here.
         */
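        // For example (a minimal sketch with hypothetical sizes, not part of this example): non-flattened
        // 32x32 RGB images read via an ImageRecordReader would instead be declared with
        //     .setInputType(InputType.convolutional(32, 32, 3))
        // and DL4J would infer each layer's nIn from that shape, just as it does for convolutionalFlat here.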

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model...");
        model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END)); //Print score every 10 iterations and evaluate on test set every epoch
        model.fit(mnistTrain, nEpochs);
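        // If a one-off evaluation were wanted here as well, a minimal sketch would be the following
        // (the EvaluativeListener above already prints the same statistics at the end of each epoch;
        // depending on the DL4J version, Evaluation lives in org.deeplearning4j.eval or
        // org.nd4j.evaluation.classification):
        //     Evaluation eval = model.evaluate(mnistTest);
        //     log.info(eval.stats());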

        String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "lenetmnist.zip");
        log.info("Saving model to tmp folder: " + path);
        model.save(new File(path), true);
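        // The saved zip can later be restored with the matching load call (a minimal sketch, assuming
        // the same DL4J version that wrote the file; the second argument controls whether the updater
        // state is loaded along with the parameters):
        //     MultiLayerNetwork restored = MultiLayerNetwork.load(new File(path), true);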
log.info("****************Example finished********************");
}
}
