博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
用 C++(libtorch)实现 U-Net
阅读量:5878 次
发布时间:2019-06-19

本文共 5402 字,大约阅读时间需要 18 分钟。

#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>
class double_conv:public torch::nn::Module{ public: torch::nn::Conv2d conv1,conv2; torch::nn::BatchNorm bn1,bn2; int in_ch,out_ch; public: double_conv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),conv1(torch::nn::Conv2dOptions(in_ch,out_ch,3).padding(1)),bn1(out_ch), conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn2(out_ch) { register_module("conv1",conv1); register_module("conv2",conv2); register_module("bn1",bn1); register_module("bn2",bn2); } torch::Tensor forward(torch::Tensor x) { x = conv1->forward(x); x = bn1->forward(x); x = torch::relu(x); x = conv2->forward(x); x = bn2->forward(x); x = torch::relu(x); return x; }};class inconv:public torch::nn::Module{ public: int in_ch,out_ch; public: inconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){} torch::Tensor forward(torch::Tensor x) { double_conv dc(in_ch,out_ch); x = dc.forward(x); return x; }};class down:public torch::nn::Module{ public: int in_ch,out_ch; public: down(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){} torch::Tensor forward(torch::Tensor x) { x = torch::max_pool2d(x,2); double_conv dc(in_ch,out_ch); x = dc.forward(x); return x; }};class up:public torch::nn::Module{ public: int in_ch,out_ch; torch::nn::Conv2d upconv; torch::nn::Conv2d conv1,conv2; torch::nn::BatchNorm bn1,bn2; torch::Tensor x; public: up(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),upconv(torch::nn::Conv2dOptions(in_ch,out_ch,4).padding(1).stride(2).transposed(new bool(true))), conv1(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn1(out_ch),conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn2(out_ch) { register_module("upconv",upconv); register_module("conv1",conv2); register_module("conv2",conv2); register_module("bn1",bn1); register_module("bn2",bn2); } torch::Tensor forward(torch::Tensor x1,torch::Tensor x2) { x = upconv->forward(x1); x = torch::cat({x,x2},1); double_conv dc(x.size(1),out_ch); x = dc.forward(x); //x = conv1->forward(x); //x = bn1->forward(x); //x = 
torch::relu(x); //x = conv2->forward(x); //x = bn2->forward(x); //x = torch::relu(x); return x; }};class outconv:public torch::nn::Module{ public: int in_ch,out_ch; torch::nn::Conv2d conv; public: outconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),conv(torch::nn::Conv2dOptions(in_ch,out_ch,1).padding(0)) { register_module("conv",conv); } torch::Tensor forward(torch::Tensor x) { return conv->forward(x); }};class unet:public torch::nn::Module{ public: int n_ch,n_class; inconv *iconv= new inconv(n_ch,64); down *down1= new down(64,256); down *down2= new down(256,512); down *down3= new down(512,512); down *down4= new down(512,512); up *up1= new up(512,256); up *up2= new up(256,128); up *up3= new up(128,64); up *up4= new up(64,64); outconv *oconv= new outconv(64,n_class); torch::Tensor x1,x2,x3,x4,x5; public: unet(int n_ch,int n_class):n_ch(n_ch),n_class(n_class){} torch::Tensor forward(torch::Tensor x) { x1 = iconv->forward(x); x2 = down1->forward(x1); x3 = down2->forward(x2); x4 = down3->forward(x3); x5 = down4->forward(x4); x = up1->forward(x5,x4); x = up2->forward(x,x3); x = up3->forward(x,x2); x = up4->forward(x,x1); x = oconv->forward(x); return x; }};std::vector
// Tokenize: splits `str` on any character contained in `delimiters` and converts each
// token to a number with std::atof (which yields 0.0 for text that is not a number).
// Runs of delimiters, and leading or trailing delimiters, produce no empty tokens, because
// each scan restarts at the first non-delimiter character (find_first_not_of).
// NOTE(review): the template arguments of the return type (the preceding `std::vector`) and
// of `tokens` were stripped when this listing was extracted from HTML. The element type is
// presumably float or double; confirm against the original source.
Tokenize(const std::string& str,const std::string& delimiters){ std::vector
tokens; std::string::size_type lastPos = str.find_first_not_of(delimiters, 0); std::string::size_type pos = str.find_first_of(delimiters, lastPos); while (std::string::npos != pos || std::string::npos != lastPos) { tokens.push_back(std::atof(str.substr(lastPos, pos - lastPos).c_str())); lastPos = str.find_first_not_of(delimiters, pos); pos = str.find_first_of(delimiters, lastPos); } return tokens;}std::vector
> readTxt(std::string file){ std::ifstream infile; infile.open(file.data()); assert(infile.is_open()); std::string s; std::vector
vec; std::vector
> res; while(getline(infile,s)) { std::string tt= static_cast
(s); vec = Tokenize(tt, " "); res.push_back(vec); } infile.close(); std::cout<<"gdood"<
> vec = readTxt("/Users/yanlang/unet/mx-unet/U-Net/LabelData.txt"); int ch = vec.size(); int len = vec[0].size(); for(int i=0;i
> vec = readTxt("/Users/yanlang/unet/mx-unet/U-Net/ImageData.txt"); int ch = vec.size(); int len = vec[0].size(); for(int i=0;i
DataLoader(torch::Tensor data,torch::Tensor label,int batch_size){ int imghight = data.size(1); int imgwidth = data.size(2); int randhight,randwidth; torch::Tensor resdata = torch::zeros({batch_size,7,imgH,imgW}); torch::Tensor reslabel = torch::zeros({batch_size,imgH,imgW}); for(int i=0;i
weights,std::vector
key){ std::ofstream fout("unet.txt"); //std::unordered_map
mp; for(int i=0;i
vecdata; for(int epoch=0;epoch<20;epoch++) { vecdata = DataLoader(data,label,2); std::cout<<"vecdata after done!!"<
vecValue; std::vector
vecKey; torch::nn::ParameterCursor tt = model.parameters(); for(auto it=tt.begin();it!=tt.end();it++) { vecValue.push_back((*it).value); vecKey.push_back((*it).key); } saveModel(vecValue,vecKey); torch::autograd::Variable predData = Get_predData(data); torch::autograd::Variable fl = model.forward(predData); torch::autograd::Variable result = torch::squeeze(fl); torch::autograd::Variable rt = result.argmax(0); std::cout<
<

 

转载于:https://www.cnblogs.com/semen/p/9778300.html

你可能感兴趣的文章
Android网络编程11之源码解析Retrofit
查看>>
esxi主机之添加新用户的访问权限
查看>>
AD账户被锁信息通知脚本
查看>>
数据集市和数据仓库的关系
查看>>
python 断言
查看>>
在CentOS 5.6下利用Instant Client安装Oracle客户端
查看>>
用虚拟环境保存库文件--Python打包
查看>>
NoSQL数据库一MongoDB基本使用
查看>>
/proc/sys/vm/drop_caches用法备忘
查看>>
selinux 常用命令
查看>>
linux中KVM桥接网卡br0
查看>>
Redis的安装和使用之一 -----Redis相关运用
查看>>
snmp安装
查看>>
spring mvc 批量上传+文件上传
查看>>
Asm Instance Parameter Best Practice
查看>>
思科路由器寄存器值
查看>>
发送验证邮件的三种方法
查看>>
如何一键去除域名非80端口,教你如何去除网址后面的端口号
查看>>
rsync的应用实践详解
查看>>
Linux安装Nginx
查看>>