神经网络

master
王兵 2 months ago
parent 63632a821d
commit cb98f19589

@ -0,0 +1,78 @@
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Ml;
import org.opencv.ml.TrainData;
import java.net.URL;
public class ML {
static {
URL systemResource = ClassLoader.getSystemResource("lib/x64/opencv_java460.dll");
System.load(systemResource.getPath());
// System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
public static void main(String[] args) {
// 训练数据 体重,身高
float[] train_data = { 186, 80, 185, 81, 160, 68, 168, 61, 160, 50, 161, 48 };
// 训练标签 0f,0f->男,1f,1f->女
float[] labels = { 0f, 0f, 0f, 0f, 0f, 0f, 1f, 1f, 1f, 1f };
// 测试数据 身高,体重
float[] test = { 184, 76, 160, 68, 161, 48 };
Mat train_mat = new Mat(6, 2, CvType.CV_32FC1);
train_mat.put(0, 0, train_data);
Mat lable_mat = new Mat(6, 2, CvType.CV_32FC1);
lable_mat.put(0, 0, labels);
Mat test_mat = new Mat(3, 2, CvType.CV_32FC1);
test_mat.put(0, 0, test);
ANN(train_mat, lable_mat, test_mat);
}
/**
* OpenCV-4.1.0 ANN
* @Author: hyacinth
* @Title: ANN
* @param : tarin
* @param : lable
* @param : test
* @Description: TODO
* @return void
* @date: 2019112220:23:23
*/
public static void ANN(Mat tarin, Mat lable, Mat test) {
TrainData td = TrainData.create(tarin, Ml.ROW_SAMPLE, lable);
Mat layer = new Mat(1, 4, CvType.CV_32FC1);
// 含有两个隐含层的网络结构,输入、输出层各两个节点,每个隐含层含两个节点
layer.put(0, 0, new float[] { 2, 2, 2, 2 });
ANN_MLP ann = ANN_MLP.create();
ann.setLayerSizes(layer);
ann.setTrainMethod(ANN_MLP.BACKPROP);
ann.setBackpropWeightScale(0.1);
ann.setBackpropMomentumScale(0.1);
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
TermCriteria criteria=new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER,300,0);
ann.setTermCriteria(criteria);
ann.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());
ann.save("C:\\Users\\Administrator\\Desktop\\number.xml");
Mat response = new Mat();
ann.predict(test, response, 0);
System.out.println(response.dump());
for (int i = 0; i < response.size().height; i++) {
if (response.get(i, 0)[0] + response.get(i, 1)[0] >= 1) {
System.out.println("女");
}
if (response.get(i, 0)[0] + response.get(i, 1)[0] < 1) {
System.out.println("男");
}
}
}
}
Loading…
Cancel
Save

Powered by TurnKey Linux.