DSP

一个偷偷写的svm库

2019-07-13 15:13发布

今早刚接触一个新的库——dlib(http://dlib.net),讲真,真的很好用。按照官方的介绍,就是:"These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code." 也就是说,dlib是一个C++库,用于开发可移植的应用程序,涵盖网络处理、线程、图形界面、数据结构、线性代数、机器学习、XML和文本解析、数值优化、贝叶斯网络以及许多其他任务,几乎涉及到数据分析的方方面面了。更重要的是,类似于OpenCV,它提供了很多很详细的example,因此学习起来应该不难。由于今天是第一天接触,就根据svm分类的example把它改写成了一个二类分类库,多类分类器以后再慢慢加进去。不过,功能应该还不太完善。总之,先放上来吧,以后再慢慢改。目前涉及到的是nu训练器(调gamma和nu)和C训练器(调gamma和C)的参数调整,默认是对gamma和C调参,因为这两个对结果影响最大嘛。
#include #include #include #include "dlib/rand/rand_kernel_abstract.h" using namespace std; using namespace dlib; //svm二类分类器,调用前请修改nFeatures值; namespace SVM{ #define nFeatures 2 typedef matrix sample_type;//定义数据类型; typedef radial_basis_kernel kernel_type;//定义核类型; typedef probabilistic_decision_function probabilistic_funct_type; typedef normalized_function pfunct_type; enum Trainer{CTrainer = 1, NUTrainer = 2}; enum LoadType {LoadSamples = 1, LoadTestData = 2}; class SVMClassification{ public: SVMClassification(){ samples.clear(); labels.clear(); } ~SVMClassification(){} bool loadData(const char* fn, int opt = LoadSamples) { if(! QFile::exists(fn)) { cout << fn << "does not exist! "; return false; } QFile infile(fn); if (!infile.open(QIODevice::ReadOnly)) { cout << fn << "open error! "; return false; } QTextStream _in(&infile); QString smsg = _in.readLine(); QStringList slist; if(opt == LoadSamples) { samples.clear(); labels.clear(); } else testData.clear(); while(! _in.atEnd()) { sample_type samp; smsg = _in.readLine(); slist = smsg.split(","); for (int i = 0; i < nFeatures; i ++) { samp(i) = slist[i+1].trimmed().toDouble(); //cout << samp(i)<<" "; } if(opt == LoadSamples) { samples.push_back(samp); labels.push_back(slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0); //cout << (slist[slist.size()-1].trimmed().toInt()==1? 
1.0:-1.0)< best_result(1, 2); best_result = 0; best_gamma = 0.0001, best_nu = 0.0001, best_c = 5; switch(opt) { case NUTrainer: for (double gamma = 0.00001; gamma <= 1; gamma *= 5) { for (double nu = 0.00001; nu < max_nu; nu *= 5) { trainer.set_kernel(kernel_type(gamma)); trainer.set_nu(nu); cout << "gamma: " << gamma << " nu: " << nu; matrix result = cross_validate_trainer(trainer, samples, labels, 10); cout << " cross validation accuracy: " << result; if (sum(result) > sum(best_result)) { best_result = result; best_gamma = gamma; best_nu = nu; } } } cout << " best gamma: " << best_gamma <<" best nu: " << best_nu<< " best score: "< result = cross_validate_trainer(c_trainer, samples, labels, 10); cout << " cross validation accuracy: " << result; if (sum(result) > sum(best_result)) { best_result = result; best_gamma = gamma; best_c = _c; } } } cout << " best gamma: " << best_gamma <<" best c: " << best_c<< " best score: "<> learned_pfunct; cout <<"loaded learned function from "<< fn< samples; std::vector labels; std::vector testData; svm_nu_trainer trainer; svm_c_trainer c_trainer; vector_normalizer normalizer; double best_gamma; double best_nu; double best_c; pfunct_type learned_pfunct; protected: }; }