Referring to the Capsule Network code, I am using only its classification module. The following is the complete classification code that I extracted from the link.
from __future__ import division, print_function, unicode_literals
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Start from a fresh default graph, then seed NumPy and TensorFlow with the
# same constant. The TF seed is set after the reset so that it applies to the
# graph that the rest of the script builds.
tf.reset_default_graph()
np.random.seed(42)
tf.set_random_seed(42)
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data sets through the TF tutorial helper; "/tmp/data/" is the
# directory handed to it. Labels are later fed into an int64 placeholder of
# shape [None], i.e. they are used as class ids rather than one-hot vectors.
mnist = input_data.read_data_sets("/tmp/data/")
# Input placeholder: a batch of 28x28 single-channel float32 images.
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")

# Primary-capsule layout: 32 feature maps of 6x6 capsules, 8 dimensions each.
# With "valid" padding the two convolutions below take 28x28 -> 20x20 -> 6x6,
# which is where the 6 * 6 factor comes from.
caps1_n_maps = 32
caps1_n_caps = caps1_n_maps * 6 * 6  # 1152 primary capsules
caps1_n_dims = 8

# Keyword arguments for the two convolutional layers.
conv1_params = dict(
    filters=256,
    kernel_size=9,
    strides=1,
    padding="valid",
    activation=tf.nn.relu,
)
conv2_params = dict(
    filters=caps1_n_maps * caps1_n_dims,  # 256 convolutional filters
    kernel_size=9,
    strides=2,
    padding="valid",
    activation=tf.nn.relu,
)

conv1 = tf.layers.conv2d(X, name="conv1", **conv1_params)
conv2 = tf.layers.conv2d(conv1, name="conv2", **conv2_params)

# Regroup the conv2 activations into one 8-dimensional vector per capsule.
caps1_raw = tf.reshape(
    conv2,
    [-1, caps1_n_caps, caps1_n_dims],
    name="caps1_raw",
)
def squash(s, axis=-1, epsilon=1e-7, name=None):
    """Apply the capsule "squash" non-linearity to the vectors in `s`.

    Each vector along `axis` keeps its direction while its length is rescaled
    to ||s||^2 / (1 + ||s||^2).

    Args:
        s: Tensor holding the vectors to squash.
        axis: Axis along which the vector norm is computed.
        epsilon: Small constant added under the square root so the division
            by the norm stays finite for zero-length vectors.
        name: Optional name scope; "squash" is used when omitted.

    Returns:
        A tensor with the same shape as `s`.
    """
    with tf.name_scope(name, default_name="squash"):
        norm_sq = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
        # The epsilon-stabilised norm is used only to build the unit vector;
        # the scaling factor below uses the exact squared norm.
        norm = tf.sqrt(norm_sq + epsilon)
        scale = norm_sq / (1. + norm_sq)
        direction = s / norm
        return scale * direction
# Primary capsule outputs: squash each raw 8-dimensional vector (last axis).
caps1_output = squash(caps1_raw, name="caps1_output")
# Second capsule layer: 10 capsules of 16 dimensions each.
caps2_n_caps = 10
caps2_n_dims = 16
# Standard deviation of the normal distribution used to initialise W.
init_sigma = 0.1
# One (caps2_n_dims x caps1_n_dims) transformation matrix per
# (primary capsule, second-layer capsule) pair; the leading axis of size 1 is
# tiled to the batch size below.
W_init = tf.random_normal(
shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims),
stddev=init_sigma, dtype=tf.float32, name="W_init")
W = tf.Variable(W_init, name="W")
# Dynamic batch size, read from the input placeholder at run time.
# NOTE: the Python name `batch_size` is rebound to the integer 50 later in the
# script; the graph ops built here keep using this tensor.
batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, [batch_size, 1, 1, 1, 1], name="W_tiled")
# Turn each primary capsule output into a column vector (trailing axis of
# size 1), insert an axis for the second-layer capsules at position 2, and
# repeat it caps2_n_caps times so it lines up with W_tiled for the matmul.
caps1_output_expanded = tf.expand_dims(caps1_output, -1,
name="caps1_output_expanded")
caps1_output_tile = tf.expand_dims(caps1_output_expanded, 2,
name="caps1_output_tile")
caps1_output_tiled = tf.tile(caps1_output_tile, [1, 1, caps2_n_caps, 1, 1],
name="caps1_output_tiled")
# Predicted second-layer output vectors: a batched matrix product over the
# last two axes, one (caps2_n_dims x 1) prediction per capsule pair.
caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,
name="caps2_predicted")
# Routing logits, initialised to zero for every capsule pair.
raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],
dtype=np.float32, name="raw_weights")
# Routing round 1.
# Softmax over axis 2 (the second-layer capsules): each primary capsule spreads
# its routing weight across the caps2_n_caps output capsules. The logits are
# all zero here, so the weights start out uniform.
routing_weights = tf.nn.softmax(raw_weights, dim=2, name="routing_weights")
weighted_predictions = tf.multiply(routing_weights, caps2_predicted,
name="weighted_predictions")
# Sum the weighted predictions over the primary capsules (axis 1), keeping the
# axis so the result can be tiled back in round 2.
weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keep_dims=True,
name="weighted_sum")
# Squash along axis -2, which holds the caps2_n_dims vector components.
caps2_output_round_1 = squash(weighted_sum, axis=-2,
name="caps2_output_round_1")
# Routing round 2.
# Repeat the round-1 outputs once per primary capsule so they can be compared
# with every individual prediction.
caps2_output_round_1_tiled = tf.tile(
caps2_output_round_1, [1, caps1_n_caps, 1, 1, 1],
name="caps2_output_round_1_tiled")
# Agreement = scalar product between each prediction and the round-1 output
# (transpose_a turns the column-vector prediction into a row vector).
agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled,
transpose_a=True, name="agreement")
# Add the agreement to the routing logits, then repeat the same
# softmax / weighted sum / squash sequence as in round 1.
raw_weights_round_2 = tf.add(raw_weights, agreement,
name="raw_weights_round_2")
routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2,
dim=2,
name="routing_weights_round_2")
weighted_predictions_round_2 = tf.multiply(routing_weights_round_2,
caps2_predicted,
name="weighted_predictions_round_2")
weighted_sum_round_2 = tf.reduce_sum(weighted_predictions_round_2,
axis=1, keep_dims=True,
name="weighted_sum_round_2")
caps2_output_round_2 = squash(weighted_sum_round_2,
axis=-2,
name="caps2_output_round_2")
# Only two routing rounds are unrolled; the round-2 result is the layer output.
caps2_output = caps2_output_round_2
#ESTIMATE CLASS PROBABILITIES
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):
    """Return the epsilon-stabilised Euclidean norm of `s` along `axis`.

    Args:
        s: Tensor whose vectors are measured.
        axis: Axis (or axes) that are reduced when summing the squares.
        epsilon: Small constant added before the square root, so the result
            (and its gradient) stays finite for zero-length vectors.
        keep_dims: Forwarded to `tf.reduce_sum`; keeps the reduced axis with
            size 1 when true.
        name: Optional name scope; "safe_norm" is used when omitted.

    Returns:
        Tensor equal to sqrt(sum(s ** 2) + epsilon).
    """
    with tf.name_scope(name, default_name="safe_norm"):
        sum_of_squares = tf.reduce_sum(
            tf.square(s), axis=axis, keep_dims=keep_dims)
        return tf.sqrt(sum_of_squares + epsilon)
# Class scores: the length of each output capsule vector (axis -2 holds the
# caps2_n_dims components).
y_proba = safe_norm(caps2_output, axis=-2, name="y_proba")
# Index of the longest output vector along axis 2 (the class axis). This op
# gets its own name; it previously reused "y_proba", which TensorFlow silently
# uniquified to "y_proba_1".
y_proba_argmax = tf.argmax(y_proba, axis=2, name="y_proba_argmax")
# Drop the two remaining size-1 axes so there is one predicted class id per
# example.
y_pred = tf.squeeze(y_proba_argmax, axis=[1, 2], name="y_pred")
# Ground-truth class ids, one int64 per example.
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")

# Margin-loss hyperparameters: target length for the correct class (m_plus),
# tolerated length for the other classes (m_minus), and the down-weighting
# factor applied to the "absent class" term (lambda_).
m_plus = 0.9
m_minus = 0.1
lambda_ = 0.5

# One-hot targets, one column per output capsule.
T = tf.one_hot(y, depth=caps2_n_caps, name="T")

# Length of every output capsule, keeping the reduced axis so the tensor can be
# reshaped to one row per example below.
caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True,
                              name="caps2_output_norm")

# Penalty when the capsule of the true class is shorter than m_plus.
present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm),
                              name="present_error_raw")
# Reshape using caps2_n_caps rather than a literal 10, so the width always
# matches the depth of T.
present_error = tf.reshape(present_error_raw, shape=(-1, caps2_n_caps),
                           name="present_error")

# Penalty when the capsule of a wrong class is longer than m_minus.
absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus),
                             name="absent_error_raw")
absent_error = tf.reshape(absent_error_raw, shape=(-1, caps2_n_caps),
                          name="absent_error")

# Per-example, per-class loss; summed over classes and averaged over the batch.
L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error,
           name="L")
margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name="margin_loss")

# Only the margin loss is optimised; the constant 0 stands in for a
# reconstruction term that this script does not build.
loss = tf.add(margin_loss, 0, name="loss")
# Accuracy: fraction of examples whose predicted class id equals the label.
correct = tf.equal(y, y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
# Adam with its default hyperparameters, minimising the loss defined above.
optimizer = tf.train.AdamOptimizer()
training_op = optimizer.minimize(loss, name="training_op")
# Created after the optimizer so that the variables it adds are covered by both
# the initializer and the saver.
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# Training configuration.
n_epochs = 10
# NOTE: this rebinds the Python name that earlier held the dynamic batch-size
# tensor; the graph keeps using that tensor, this integer is only the
# mini-batch size requested from the data sets.
batch_size = 50
restore_checkpoint = True

n_iterations_per_epoch = mnist.train.num_examples // batch_size
n_iterations_validation = mnist.validation.num_examples // batch_size
# np.inf instead of the np.infty alias, which was removed in NumPy 2.0.
best_loss_val = np.inf
checkpoint_path = "./my_capsule_network"

with tf.Session() as sess:
    # Resume from an existing checkpoint when allowed, otherwise start from
    # freshly initialised variables.
    # NOTE: best_loss_val is not restored, so after resuming, the first epoch
    # always overwrites the checkpoint, whatever its validation loss.
    if restore_checkpoint and tf.train.checkpoint_exists(checkpoint_path):
        saver.restore(sess, checkpoint_path)
    else:
        init.run()

    for epoch in range(n_epochs):
        for iteration in range(1, n_iterations_per_epoch + 1):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            # Run the training operation and measure the loss:
            _, loss_train = sess.run(
                [training_op, loss],
                feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                           y: y_batch})
            print("\rIteration: {}/{} ({:.1f}%) Loss: {:.5f}".format(
                      iteration, n_iterations_per_epoch,
                      iteration * 100 / n_iterations_per_epoch,
                      loss_train),
                  end="")

        # At the end of each epoch,
        # measure the validation loss and accuracy:
        loss_vals = []
        acc_vals = []
        for iteration in range(1, n_iterations_validation + 1):
            X_batch, y_batch = mnist.validation.next_batch(batch_size)
            loss_val, acc_val = sess.run(
                [loss, accuracy],
                feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                           y: y_batch})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
            print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                      iteration, n_iterations_validation,
                      iteration * 100 / n_iterations_validation),
                  end=" " * 10)
        loss_val = np.mean(loss_vals)
        acc_val = np.mean(acc_vals)
        print("\rEpoch: {} Val accuracy: {:.4f}% Loss: {:.6f}{}".format(
            epoch + 1, acc_val * 100, loss_val,
            " (improved)" if loss_val < best_loss_val else ""))

        # And save the model if it improved:
        if loss_val < best_loss_val:
            save_path = saver.save(sess, checkpoint_path)
            best_loss_val = loss_val
# Final evaluation: restore the saved checkpoint and average the loss and
# accuracy over the test set, one mini-batch at a time.
n_iterations_test = mnist.test.num_examples // batch_size

with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)

    loss_tests = []
    acc_tests = []
    for iteration in range(1, n_iterations_test + 1):
        X_batch, y_batch = mnist.test.next_batch(batch_size)
        # The images are reshaped to the 28x28x1 layout the placeholder expects.
        feed = {X: X_batch.reshape([-1, 28, 28, 1]), y: y_batch}
        loss_test, acc_test = sess.run([loss, accuracy], feed_dict=feed)
        loss_tests.append(loss_test)
        acc_tests.append(acc_test)

        progress_msg = "\rEvaluating the model: {}/{} ({:.1f}%)".format(
            iteration, n_iterations_test,
            iteration * 100 / n_iterations_test)
        print(progress_msg, end=" " * 10)

    # Unweighted mean over the mini-batches.
    loss_test = np.mean(loss_tests)
    acc_test = np.mean(acc_tests)
    print("\rFinal test accuracy: {:.4f}% Loss: {:.6f}".format(
        acc_test * 100, loss_test))
Since the reconstruction module is not required, I have set the reconstruction term of the final loss function to 0, so that the optimizer works only on the classification loss:
loss = tf.add(margin_loss, 0, name="loss")
The classification module now runs, but the accuracy is extremely low even on the MNIST dataset: I get around 10% validation accuracy. Can somebody explain what the problem is here? According to the paper, Capsule Networks perform well at classification on MNIST.
Regards
from Classification Module from Capsule Network Code
No comments:
Post a Comment