|
|
|
|
|
|
|
|
package xyz.wbsite.ai;
|
|
|
|
|
|
|
|
|
|
import org.apache.commons.io.FilenameUtils;
|
|
|
|
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
|
|
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
|
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
|
|
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
|
|
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
|
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
|
|
|
import org.deeplearning4j.optimize.api.InvocationType;
|
|
|
|
|
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
|
|
|
|
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
|
|
|
|
import org.nd4j.linalg.activations.Activation;
|
|
|
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
|
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
|
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
|
import org.slf4j.Logger;
|
|
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
|
|
|
|
|
|
import java.io.File;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 构建手写数字识别模型
|
|
|
|
|
* <p>
|
|
|
|
|
* <p>
|
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
* See the NOTICE file distributed with this work for additional
|
|
|
|
|
* information regarding copyright ownership.
|
|
|
|
|
* <p>
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
|
* under the License.
|
|
|
|
|
* <p>
|
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
* <p>
|
|
|
|
|
* *本计划和随附材料可在
|
|
|
|
|
* Apache许可证2.0版的条款,可在
|
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
* 有关更多信息,请参阅随此作品分发的通知文件
|
|
|
|
|
* 关于版权所有权的信息。
|
|
|
|
|
* <p>
|
|
|
|
|
* 除非适用法律要求或书面同意,否则软件
|
|
|
|
|
* 根据许可证分发的内容是按“原样”分发的,没有
|
|
|
|
|
* 任何明示或暗示的保证或条件。请参阅
|
|
|
|
|
* 特定语言的许可证管理权限和限制
|
|
|
|
|
* 根据许可证。
|
|
|
|
|
* <p>
|
|
|
|
|
* SPDX许可证标识符:Apache-2.0
|
|
|
|
|
******************************************************************************/
|
|
|
|
|
public class Dl4j_LeNetMNIST {
|
|
|
|
|
private static final Logger log = LoggerFactory.getLogger(Dl4j_LeNetMNIST.class);
|
|
|
|
|
|
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
|
|
|
int nChannels = 1; // Number of input channels 输入通道数量
|
|
|
|
|
int outputNum = 10; // The number of possible outcomes 可能结果的数量
|
|
|
|
|
int batchSize = 64; // Test batch size 试验批量
|
|
|
|
|
int nEpochs = 1; // Number of training epochs 训练周期数
|
|
|
|
|
int seed = 123; //
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* Create an iterator using the batch size for one iteration
|
|
|
|
|
* 使用一次迭代的批大小创建迭代器
|
|
|
|
|
*/
|
|
|
|
|
log.info("Load data....");
|
|
|
|
|
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
|
|
|
|
|
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* Construct the neural network
|
|
|
|
|
* 构建神经网络
|
|
|
|
|
*/
|
|
|
|
|
log.info("Build model....");
|
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
|
|
|
.seed(seed)
|
|
|
|
|
.l2(0.0005)
|
|
|
|
|
.weightInit(WeightInit.XAVIER)
|
|
|
|
|
.updater(new Adam(1e-3))
|
|
|
|
|
.list()
|
|
|
|
|
.layer(new ConvolutionLayer.Builder(5, 5)
|
|
|
|
|
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
|
|
|
|
|
//nIn和nOut指定深度。n这里是nChannel,nOut是要应用的过滤器数量
|
|
|
|
|
.nIn(nChannels)
|
|
|
|
|
.stride(1, 1)
|
|
|
|
|
.nOut(20)
|
|
|
|
|
.activation(Activation.IDENTITY)
|
|
|
|
|
.build())
|
|
|
|
|
.layer(new SubsamplingLayer.Builder(PoolingType.MAX)
|
|
|
|
|
.kernelSize(2, 2)
|
|
|
|
|
.stride(2, 2)
|
|
|
|
|
.build())
|
|
|
|
|
.layer(new ConvolutionLayer.Builder(5, 5)
|
|
|
|
|
//Note that nIn need not be specified in later layers
|
|
|
|
|
//请注意,不需要在后面的层中指定nIn
|
|
|
|
|
.stride(1, 1)
|
|
|
|
|
.nOut(50)
|
|
|
|
|
.activation(Activation.IDENTITY)
|
|
|
|
|
.build())
|
|
|
|
|
.layer(new SubsamplingLayer.Builder(PoolingType.MAX)
|
|
|
|
|
.kernelSize(2, 2)
|
|
|
|
|
.stride(2, 2)
|
|
|
|
|
.build())
|
|
|
|
|
.layer(new DenseLayer.Builder().activation(Activation.RELU)
|
|
|
|
|
.nOut(500).build())
|
|
|
|
|
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
|
|
|
.nOut(outputNum)
|
|
|
|
|
.activation(Activation.SOFTMAX)
|
|
|
|
|
.build())
|
|
|
|
|
.setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* Regarding the .setInputType(InputType.convolutionalFlat(28,28,1)) line: This does a few things.
|
|
|
|
|
* (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling layers
|
|
|
|
|
* and the dense layer
|
|
|
|
|
* (b) Does some additional configuration validation
|
|
|
|
|
* (c) Where necessary, sets the nIn (number of input neurons, or input depth in the case of CNNs) values for each
|
|
|
|
|
* layer based on the size of the previous layer (but it won't override values manually set by the user)
|
|
|
|
|
* InputTypes can be used with other layer types too (RNNs, MLPs etc) not just CNNs.
|
|
|
|
|
* For normal images (when using ImageRecordReader) use InputType.convolutional(height,width,depth).
|
|
|
|
|
* MNIST record reader is a special case, that outputs 28x28 pixel grayscale (nChannels=1) images, in a "flattened"
|
|
|
|
|
* row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here.
|
|
|
|
|
*
|
|
|
|
|
* 关于.setInputType(InputType.convolutionalFlat(28,28,1))行:这做了一些事情。
|
|
|
|
|
* (a) 它添加了预处理器,处理卷积/子采样层之间的转换等事情
|
|
|
|
|
* 致密层
|
|
|
|
|
* (b) 是否进行了一些额外的配置验证
|
|
|
|
|
* (c) 必要时,为每个神经元设置nIn(输入神经元数量,或CNN情况下的输入深度)值
|
|
|
|
|
* 基于上一层大小的层(但它不会覆盖用户手动设置的值)
|
|
|
|
|
* InputTypes也可以与其他层类型(RNN、MLP等)一起使用,而不仅仅是CNN。
|
|
|
|
|
* 对于普通图像(使用ImageRecordReader时),请使用InputType.convolutional(高度、宽度、深度)。
|
|
|
|
|
* MNIST记录读取器是一种特殊情况,它以“平坦”的方式输出28x28像素灰度(nCannels=1)图像
|
|
|
|
|
* 行向量格式(即1x784向量),因此这里使用的是“卷积平面”输入类型。
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
MultiLayerNetwork model = new MultiLayerNetwork(conf);
|
|
|
|
|
model.init();
|
|
|
|
|
|
|
|
|
|
log.info("Train model...");
|
|
|
|
|
model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END)); //Print score every 10 iterations and evaluate on test set every epoch
|
|
|
|
|
model.fit(mnistTrain, nEpochs);
|
|
|
|
|
|
|
|
|
|
String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "lenetmnist.zip");
|
|
|
|
|
|
|
|
|
|
log.info("Saving model to tmp folder: " + path);
|
|
|
|
|
model.save(new File(path), true);
|
|
|
|
|
|
|
|
|
|
log.info("****************Example finished********************");
|
|
|
|
|
}
|
|
|
|
|
}
|