TensorFlowでU-Netを動かしたメモ
U-NetのTensorFlow実装（tf_unet）は下記で公開されている。
https://github.com/jakeret/tf_unet
インストール方法などはリポジトリに記載された手順に従う。
以下はJupyter Notebookで書いたコード（学習フェーズとテストフェーズの2セル）。
=== Train phase ===
from __future__ import division, print_function
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import glob
plt.rcParams['image.cmap'] = 'gist_earth'
# import tf_unet
from tf_unet import unet, util, image_util
#preparing data loading
data_provider = image_util.ImageDataProvider("./images/train/*.png",data_suffix='.png',mask_suffix='_mask.png')
#setup & training
net = unet.Unet(layers=4,features_root=128, channels=3, n_class=2)
trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))
output_path = './check_points'
path = trainer.train(data_provider, output_path, training_iters=100, epochs=10
, dropout=0.75
, display_step=10
, write_graph=True
)
===
=== Test phase ===
from __future__ import division, print_function
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import glob
plt.rcParams['image.cmap'] = 'gist_earth'
from tf_unet import unet, util, image_util
# set model parameters :: should be same as those in train phase
net = unet.Unet(layers=4,features_root=128, channels=3, n_class=2)
# prediction with trained model
tmp_path='./check_points/model.cpkt'
data_test = image_util.ImageDataProvider("./images/test/*.png",data_suffix='.png',mask_suffix='_mask.png')
x_test, y_test = data_test(1)
prediction = net.predict(tmp_path, x_test)
# Output #
fig, ax = plt.subplots(1,3, figsize=(12,4))
# Convert float to int64 for rgb color output
rgb_test = util.to_rgb(x_test)
rgb_test = rgb_test.astype(np.int)
ax[0].imshow(rgb_test[0,:,:,:], aspect="auto")
ax[1].imshow(y_test[0,...,1], aspect="auto")
ax[2].imshow(prediction[0,...,1], aspect="auto")
===