Zi 字媒體
2017-07-25T20:27:27+00:00
fanfuhan OpenCV 教學116 ~ opencv-116-決策樹算法介紹與使用
資料來源: https://fanfuhan.github.io/
https://fanfuhan.github.io/2019/05/25/opencv-116/
GITHUB:https://github.com/jash-git/fanfuhan_ML_OpenCV
OpenCV 中機器學習模塊的決策樹算法分為兩個類別:一個是隨機森林(Random Trees),另一個是強化分類(Boosting)。
C++
#include <cstdio>
#include <iostream>
#include <opencv2/opencv.hpp>
using namespace cv;
using namespace cv::ml;
using namespace std;
int main(int argc, char** argv) {
Mat data = imread("D:/projects/opencv_tutorial/data/images/digits.png");
Mat gray;
cvtColor(data, gray, COLOR_BGR2GRAY);
// 分割为5000个cells
Mat images = Mat::zeros(5000, 400, CV_8UC1);
Mat labels = Mat::zeros(5000, 1, CV_8UC1);
int index = 0;
Rect roi;
roi.x = 0;
roi.height = 1;
roi.width = 400;
for (int row = 0; row < 50; row++) {
int label = row / 5;
int offsety = row * 20;
for (int col = 0; col < 100; col++) {
int offsetx = col * 20;
Mat digit = Mat::zeros(Size(20, 20), CV_8UC1);
for (int sr = 0; sr < 20; sr++) {
for (int sc = 0; sc < 20; sc++) {
digit.at(sr, sc) = gray.at(sr + offsety, sc + offsetx);
}
}
Mat one_row = digit.reshape(1, 1);
printf("index : %d \n", index);
roi.y = index;
one_row.copyTo(images(roi));
labels.at(index, 0) = label;
index++;
}
}
printf("load sample hand-writing data...\n");
imwrite("D:/result.png", images);
// 转换为浮点数
images.convertTo(images, CV_32FC1);
labels.convertTo(labels, CV_32SC1);
printf("load sample hand-writing data...\n");
// 开始训练
printf("Start to Random Trees train...\n");
Ptr model = RTrees::create();
/*model->setMaxDepth(10);
model->setMinSampleCount(10);
model->setRegressionAccuracy(0);
model->setUseSurrogates(false);
model->setMaxCategories(15);
model->setPriors(Mat());
model->setCalculateVarImportance(true);
model->setActiveVarCount(4);
*/
TermCriteria tc = TermCriteria(TermCriteria::MAX_ITER + TermCriteria::EPS, 100, 0.01);
model->setTermCriteria(tc);
Ptr tdata = ml::TrainData::create(images, ml::ROW_SAMPLE, labels);
model->train(tdata);
model->save("D:/vcworkspaces/rtrees_knowledge.yml");
printf("Finished Random trees...\n");
waitKey(0);
return true;
}
Python
"""
Decision-tree algorithms: introduction and usage.

Trains an OpenCV Random Trees classifier on the left half of the digits
sheet and prints its accuracy on the right half.
"""
import cv2 as cv
import numpy as np

# Load the digit sheet. cv.imread returns None instead of raising when the
# file is missing or unreadable, so fail here with a clear message.
img = cv.imread('images/digits.png')
if img is None:
    raise FileNotFoundError("could not read 'images/digits.png'")
gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)

# Cut the sheet into 50 row bands, each split into 100 cells.
# NOTE(review): the reshape to 400 features below assumes 20x20-pixel cells,
# i.e. a 2000x1000 sheet - confirm against the image actually used.
cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]
x = np.array(cells)

# Build the training and test sets: the first 50 cells of every band are used
# for training, the remaining 50 for testing; each cell is flattened.
train = x[:, :50].reshape(-1, 400).astype(np.float32)
test = x[:, 50:100].reshape(-1, 400).astype(np.float32)

# 250 consecutive samples per digit 0..9. The labels are made int32 explicitly
# so the response type does not depend on the platform's default integer.
k = np.arange(10)
train_labels = np.repeat(k, 250)[:, np.newaxis].astype(np.int32)
test_labels = train_labels.copy()

# Train the random trees model and predict the test set.
dt = cv.ml.RTrees_create()
dt.train(train, cv.ml.ROW_SAMPLE, train_labels)
retval, results = dt.predict(test)

# Accuracy = fraction of predictions equal to the expected label.
matches = results == test_labels
correct = np.count_nonzero(matches)
accuracy = correct / results.size
print(accuracy)
寫了
5860316篇文章,獲得
23313次喜歡