parent
63632a821d
commit
cb98f19589
@ -0,0 +1,78 @@
|
|||||||
|
import org.opencv.core.CvType;
|
||||||
|
import org.opencv.core.Mat;
|
||||||
|
import org.opencv.core.TermCriteria;
|
||||||
|
import org.opencv.ml.ANN_MLP;
|
||||||
|
import org.opencv.ml.Ml;
|
||||||
|
import org.opencv.ml.TrainData;
|
||||||
|
|
||||||
|
import java.net.URL;
|
||||||
|
|
||||||
|
public class ML {
|
||||||
|
|
||||||
|
static {
|
||||||
|
URL systemResource = ClassLoader.getSystemResource("lib/x64/opencv_java460.dll");
|
||||||
|
System.load(systemResource.getPath());
|
||||||
|
// System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
// 训练数据 体重,身高
|
||||||
|
float[] train_data = { 186, 80, 185, 81, 160, 68, 168, 61, 160, 50, 161, 48 };
|
||||||
|
// 训练标签 0f,0f->男,1f,1f->女
|
||||||
|
float[] labels = { 0f, 0f, 0f, 0f, 0f, 0f, 1f, 1f, 1f, 1f };
|
||||||
|
// 测试数据 身高,体重
|
||||||
|
float[] test = { 184, 76, 160, 68, 161, 48 };
|
||||||
|
|
||||||
|
Mat train_mat = new Mat(6, 2, CvType.CV_32FC1);
|
||||||
|
train_mat.put(0, 0, train_data);
|
||||||
|
|
||||||
|
Mat lable_mat = new Mat(6, 2, CvType.CV_32FC1);
|
||||||
|
lable_mat.put(0, 0, labels);
|
||||||
|
|
||||||
|
Mat test_mat = new Mat(3, 2, CvType.CV_32FC1);
|
||||||
|
test_mat.put(0, 0, test);
|
||||||
|
|
||||||
|
ANN(train_mat, lable_mat, test_mat);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpenCV-4.1.0 ANN 人工神经网络
|
||||||
|
* @Author: hyacinth
|
||||||
|
* @Title: ANN
|
||||||
|
* @param : tarin 训练数据
|
||||||
|
* @param : lable 训练标签
|
||||||
|
* @param : test 测试数据
|
||||||
|
* @Description: TODO
|
||||||
|
* @return void
|
||||||
|
* @date: 2019年11月22日20:23:23
|
||||||
|
*/
|
||||||
|
public static void ANN(Mat tarin, Mat lable, Mat test) {
|
||||||
|
TrainData td = TrainData.create(tarin, Ml.ROW_SAMPLE, lable);
|
||||||
|
Mat layer = new Mat(1, 4, CvType.CV_32FC1);
|
||||||
|
// 含有两个隐含层的网络结构,输入、输出层各两个节点,每个隐含层含两个节点
|
||||||
|
layer.put(0, 0, new float[] { 2, 2, 2, 2 });
|
||||||
|
ANN_MLP ann = ANN_MLP.create();
|
||||||
|
ann.setLayerSizes(layer);
|
||||||
|
ann.setTrainMethod(ANN_MLP.BACKPROP);
|
||||||
|
ann.setBackpropWeightScale(0.1);
|
||||||
|
ann.setBackpropMomentumScale(0.1);
|
||||||
|
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 1, 1);
|
||||||
|
TermCriteria criteria=new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER,300,0);
|
||||||
|
ann.setTermCriteria(criteria);
|
||||||
|
ann.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());
|
||||||
|
ann.save("C:\\Users\\Administrator\\Desktop\\number.xml");
|
||||||
|
|
||||||
|
Mat response = new Mat();
|
||||||
|
ann.predict(test, response, 0);
|
||||||
|
System.out.println(response.dump());
|
||||||
|
for (int i = 0; i < response.size().height; i++) {
|
||||||
|
if (response.get(i, 0)[0] + response.get(i, 1)[0] >= 1) {
|
||||||
|
System.out.println("女");
|
||||||
|
}
|
||||||
|
if (response.get(i, 0)[0] + response.get(i, 1)[0] < 1) {
|
||||||
|
System.out.println("男");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue