一个偷偷写的svm库
2019-07-13 15:13发布
生成海报
今早刚接触一个新的库——dlib(http://dlib.net),讲真,真的很好用。按照官方的介绍,就是:These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written
using them can be compiled under POSIX or MS Windows platforms without changing the code. 也就是说,dlib是一个C++库,用于开发可移植的应用程序,涵盖网络处理,线程,图形界面,数据结构,线性代数,机器学习,XML和文本解析,数值优化,贝叶斯网络,以及许多其他任务,几乎涉及到数据分析的方方面面了。更重要的是,类似于OpenCV,它提供了很多很详细的example,因此学习起来应该不难。由于今天是第一天接触,我就根据svm分类的example把它改写成了一个二类分类库,多类分类器以后再慢慢加进去。不过,功能应该还不太完善。总之,先放上来吧,以后再慢慢改。目前涉及到nu和C两种训练器的参数调整,默认是对gamma和C调参,因为这两个参数对结果影响最大。
#include
#include
#include
#include "dlib/rand/rand_kernel_abstract.h"
using namespace std;
using namespace dlib;
//svm二类分类器,调用前请修改nFeatures值;
namespace SVM{
// Number of features per sample; edit this value to match your data before building.
#define nFeatures 2
// One sample: a fixed-size nFeatures x 1 column vector of doubles.
// (Template arguments restored; dlib::matrix cannot be named without them.)
typedef matrix<double, nFeatures, 1> sample_type;
// Radial basis function kernel operating on sample_type.
typedef radial_basis_kernel<sample_type> kernel_type;
// Decision function over kernel_type, wrapped by dlib's probabilistic decision function.
typedef probabilistic_decision_function<kernel_type> probabilistic_funct_type;
// The probabilistic decision function wrapped by dlib's normalized_function.
typedef normalized_function<probabilistic_funct_type> pfunct_type;
// Selects which trainer the parameter search uses (C-SVM or nu-SVM).
enum Trainer{CTrainer = 1, NUTrainer = 2};
// Selects whether loadData fills the training set or the test set.
enum LoadType {LoadSamples = 1, LoadTestData = 2};
// Two-class SVM helper built on dlib trainers, with Qt used for file input.
// NOTE(review): this listing was damaged by text extraction. Several spans of
// code are missing (marked below), so the class does not compile as shown.
class SVMClassification{
public:
// Starts with empty training samples and labels.
SVMClassification(){
samples.clear();
labels.clear();
}
~SVMClassification(){}
// Reads comma-separated rows from the file fn.
// opt == LoadSamples clears and refills samples/labels; any other value clears testData.
// Returns false when the file does not exist or cannot be opened.
// NOTE(review): the end of this function is missing from the listing (see below).
bool loadData(const char* fn, int opt = LoadSamples)
{
if(! QFile::exists(fn))
{
cout << fn << "does not exist!
";
return false;
}
QFile infile(fn);
if (!infile.open(QIODevice::ReadOnly))
{
cout << fn << "open error!
";
return false;
}
QTextStream _in(&infile);
// The first line is read here and never parsed (presumably a header row - confirm).
QString smsg = _in.readLine();
QStringList slist;
if(opt == LoadSamples)
{
samples.clear();
labels.clear();
}
else
testData.clear();
while(! _in.atEnd())
{
sample_type samp;
smsg = _in.readLine();
slist = smsg.split(",");
// Features come from fields 1..nFeatures; field 0 is not read in the visible code.
// No check is made that the row has enough fields.
for (int i = 0; i < nFeatures; i ++)
{
samp(i) = slist[i+1].trimmed().toDouble();
//cout << samp(i)<<" ";
}
if(opt == LoadSamples)
{
samples.push_back(samp);
// Label is taken from the last field: integer 1 maps to +1.0, anything else to -1.0.
labels.push_back(slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0);
// NOTE(review): the listing is truncated inside the next line. The rest of
// loadData and the start of the following parameter-search method (including
// the declarations of opt, max_nu and best_result) are missing.
//cout << (slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0)< best_result(1, 2);
// Initial best score and default parameter values before the search.
best_result = 0;
best_gamma = 0.0001, best_nu = 0.0001, best_c = 5;
switch(opt)
{
case NUTrainer:
// Grid search: gamma starts at 0.00001 and is multiplied by 5 while <= 1.
for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
{
// nu starts at 0.00001 and is multiplied by 5 while below max_nu
// (max_nu is defined in the missing part of the listing).
for (double nu = 0.00001; nu < max_nu; nu *= 5)
{
trainer.set_kernel(kernel_type(gamma));
trainer.set_nu(nu);
cout << "gamma: " << gamma << " nu: " << nu;
// Cross validation on the loaded samples; 10 is passed as the fold count.
matrix result = cross_validate_trainer(trainer, samples, labels, 10);
cout << " cross validation accuracy: " << result;
// Keep the parameters whose result has the largest element sum so far.
if (sum(result) > sum(best_result))
{
best_result = result;
best_gamma = gamma;
best_nu = nu;
}
}
}
// NOTE(review): the listing is truncated inside the next statement. The end of
// the NUTrainer case and the loop headers of the c_trainer search (which
// define gamma and _c) are missing.
cout << "
best gamma: " << best_gamma <<" best nu: " << best_nu<< " best score: "< result = cross_validate_trainer(c_trainer, samples, labels, 10);
cout << " cross validation accuracy: " << result;
// Same selection rule as above, recording best_c instead of best_nu.
if (sum(result) > sum(best_result))
{
best_result = result;
best_gamma = gamma;
best_c = _c;
}
}
}
// NOTE(review): the listing is truncated inside the next two statements. The
// remainder of the search method and any further methods up to the member
// declarations are missing; only a fragment mentioning learned_pfunct and a
// loaded-function message survives.
cout << "
best gamma: " << best_gamma <<" best c: " << best_c<< " best score: "<> learned_pfunct;
cout <<"loaded learned function from "<< fn< samples;
// Training labels, pushed as +1.0 / -1.0 by loadData (element type stripped from the listing).
std::vector labels;
// Container cleared by loadData when opt is not LoadSamples (element type stripped from the listing).
std::vector testData;
// Trainer used by the NUTrainer search branch (template argument stripped from the listing).
svm_nu_trainer trainer;
// Trainer used by the C-parameter search (template argument stripped from the listing).
svm_c_trainer c_trainer;
// Not referenced in the visible code (template argument stripped from the listing).
vector_normalizer normalizer;
// Best parameters found by the search.
double best_gamma;
double best_nu;
double best_c;
// Learned decision function; only referenced in the truncated fragments above.
pfunct_type learned_pfunct;
protected:
};
}
打开微信“扫一扫”,打开网页后点击屏幕右上角分享按钮