Tensor Nets (compressing neural networks)


In this notebook we provide an example of how to build a simple Tensor Net (see https://arxiv.org/abs/1509.06569).

The main ingredient is the so-called TT-Matrix, a generalization of Kronecker product matrices, i.e. matrices of the form

\[A = A_1 \otimes A_2 \otimes \cdots \otimes A_n\]

In t3f TT-Matrices are represented using the TensorTrain class.
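
For intuition, a pure Kronecker product already gives a large saving: the full matrix is stored only through its small factors. A minimal NumPy illustration (not used by the rest of the notebook; the factor shapes are arbitrary):

[ ]:
import numpy as np

# The Kronecker product of two small factors is a 28 x 25 matrix...
A1 = np.random.randn(4, 5)
A2 = np.random.randn(7, 5)
A = np.kron(A1, A2)
print(A.shape)            # (28, 25): 700 entries
# ...while the factors together hold only 4*5 + 7*5 = 55 parameters.
print(A1.size + A2.size)  # 55

A TT-Matrix relaxes this structure by allowing TT-ranks greater than one, which makes the format more expressive at the cost of a few more parameters.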

[1]:
import numpy as np
import tensorflow as tf
import keras.backend as K

# Fix random seeds for reproducibility.
tf.set_random_seed(0)
np.random.seed(0)
# Use one TF1 session shared between t3f and Keras.
sess = tf.InteractiveSession()
K.set_session(sess)

try:
    import t3f
except ImportError:
    # Install T3F if it's not already installed.
    !git clone https://github.com/Bihaqo/t3f.git
    !cd t3f; pip install .
    import t3f
Using TensorFlow backend.
[2]:
W = t3f.random_matrix([[4, 7, 4, 7], [5, 5, 5, 5]], tt_rank=2)

print(W)
A TT-Matrix of size 784 x 625, underlying tensor shape: (4, 7, 4, 7) x (5, 5, 5, 5), TT-ranks: (1, 2, 2, 2, 1)
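
If needed, a TT-Matrix can be materialized back into an ordinary dense matrix with t3f.full. Note that this explicitly forms all 784 x 625 entries, so it is only useful for inspection on small examples; a minimal sketch:

[ ]:
# Materialize the TT-Matrix above as a dense 784 x 625 matrix (inspection only).
W_full = sess.run(t3f.full(W))
print(W_full.shape)  # (784, 625)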

Using TT-Matrices we can compactly represent densely connected layers in neural networks, which allows us to greatly reduce the number of parameters. Matrix multiplication is handled by the t3f.matmul method, which supports multiplying dense (ordinary) matrices and TT-Matrices. A very simple neural network could look as follows (for initialization, several options are available, such as t3f.glorot_initializer, t3f.he_initializer, or t3f.random_matrix):

[ ]:
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.int64, [None])

# First layer: a 784 x 625 weight matrix stored as a TT-Matrix of TT-rank 2.
initializer = t3f.glorot_initializer([[4, 7, 4, 7], [5, 5, 5, 5]], tt_rank=2)
W1 = t3f.get_variable('W1', initializer=initializer)
b1 = tf.get_variable('b1', shape=[625])
h1 = t3f.matmul(x, W1) + b1
h1 = tf.nn.relu(h1)

# Second layer: an ordinary dense 625 x 10 matrix.
W2 = tf.get_variable('W2', shape=[625, 10])
b2 = tf.get_variable('b2', shape=[10])
h2 = tf.matmul(h1, W2) + b2

y_ = tf.one_hot(y, 10)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=h2))
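
This cell only defines the graph. A minimal sketch of how one might train it in the TF1 style used here (with random stand-in data instead of real MNIST batches, purely to illustrate the feed):

[ ]:
# Minimal training sketch for the graph above (random stand-in data only).
train_step = tf.train.AdamOptimizer(1e-2).minimize(loss)
sess.run(tf.global_variables_initializer())

for _ in range(10):
    batch_x = np.random.randn(64, 784).astype(np.float32)
    batch_y = np.random.randint(0, 10, size=64)
    _, batch_loss = sess.run([train_step, loss],
                             feed_dict={x: batch_x, y: batch_y})
print(batch_loss)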

For convenience we have implemented a layer analogous to the Keras Dense layer, but with a TT-Matrix instead of an ordinary matrix. An example of a fully trainable net is provided below.

[ ]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Flatten
import numpy as np
from keras.utils import to_categorical
from keras import optimizers
[5]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
11493376/11490434 [==============================] - 3s 0us/step

Some preprocessing…

[ ]:
# Scale pixel values from [0, 255] to [-1, 1].
x_train = x_train / 127.5 - 1.0
x_test = x_test / 127.5 - 1.0

# One-hot encode the labels.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
[ ]:
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
# TT layer: maps 784 = 7*4*7*4 inputs to 625 = 5*5*5*5 outputs with TT-rank 4.
tt_layer = t3f.nn.KerasDense(input_dims=[7, 4, 7, 4], output_dims=[5, 5, 5, 5],
                             tt_rank=4, activation='relu',
                             bias_initializer=1e-3)
model.add(tt_layer)
model.add(Dense(10))
model.add(Activation('softmax'))
[8]:
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_1 (Flatten)          (None, 784)               0
_________________________________________________________________
keras_dense_1 (KerasDense)   (None, 625)               1725
_________________________________________________________________
dense_1 (Dense)              (None, 10)                6260
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0
=================================================================
Total params: 7,985
Trainable params: 7,985
Non-trainable params: 0
_________________________________________________________________

Note that the TT layer has only \(1725\) parameters instead of the \(784 \times 625 = 490000\) weights of the corresponding ordinary dense layer.
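
A rough way to double-check this count, assuming the usual TT parameterization in which core \(k\) has shape \((r_{k-1}, m_k, n_k, r_k)\) and the bias contributes another 625 parameters:

[ ]:
# Count the TT layer's parameters by hand: one 4D core per mode, plus the bias.
input_dims, output_dims = [7, 4, 7, 4], [5, 5, 5, 5]
ranks = [1, 4, 4, 4, 1]  # tt_rank=4, as in the layer above
core_params = sum(ranks[k] * input_dims[k] * output_dims[k] * ranks[k + 1]
                  for k in range(len(input_dims)))
print(core_params + np.prod(output_dims))  # 1100 + 625 = 1725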

[ ]:
optimizer = optimizers.Adam(lr=1e-2)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
[10]:
model.fit(x_train, y_train, epochs=2, batch_size=64, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
60000/60000 [==============================] - 9s 151us/step - loss: 0.2311 - acc: 0.9298 - val_loss: 0.1536 - val_acc: 0.9560
Epoch 2/2
60000/60000 [==============================] - 8s 137us/step - loss: 0.1380 - acc: 0.9591 - val_loss: 0.1716 - val_acc: 0.9500
[10]:
<keras.callbacks.History at 0x7f86cb2a4400>

Compression of Dense layers

Let us now train an ordinary DNN (without TT-Matrices) and show how we can compress it using the TT decomposition (in contrast to the example above, where the TT layer was trained from scratch).

[ ]:
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(625, activation='relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
[12]:
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_2 (Flatten)          (None, 784)               0
_________________________________________________________________
dense_2 (Dense)              (None, 625)               490625
_________________________________________________________________
dense_3 (Dense)              (None, 10)                6260
_________________________________________________________________
activation_3 (Activation)    (None, 10)                0
=================================================================
Total params: 496,885
Trainable params: 496,885
Non-trainable params: 0
_________________________________________________________________
[ ]:
optimizer = optimizers.Adam(lr=1e-3)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
[14]:
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [==============================] - 6s 104us/step - loss: 0.2771 - acc: 0.9156 - val_loss: 0.1529 - val_acc: 0.9528
Epoch 2/5
60000/60000 [==============================] - 6s 101us/step - loss: 0.1278 - acc: 0.9613 - val_loss: 0.1079 - val_acc: 0.9680
Epoch 3/5
60000/60000 [==============================] - 6s 101us/step - loss: 0.0960 - acc: 0.9702 - val_loss: 0.1078 - val_acc: 0.9658
Epoch 4/5
60000/60000 [==============================] - 6s 102us/step - loss: 0.0806 - acc: 0.9744 - val_loss: 0.0948 - val_acc: 0.9714
Epoch 5/5
60000/60000 [==============================] - 6s 102us/step - loss: 0.0733 - acc: 0.9770 - val_loss: 0.1072 - val_acc: 0.9664
[14]:
<keras.callbacks.History at 0x7f87102116d8>

Let us convert the matrix used in the Dense layer to a TT-Matrix with TT-ranks equal to 16 (since we trained the network without a low-rank structure assumption, we may wish to start with high rank values).

[15]:
W = model.trainable_weights[0]
print(W)
Wtt = t3f.to_tt_matrix(W, shape=[[7, 4, 7, 4], [5, 5, 5, 5]], max_tt_rank=16)
print(Wtt)
<tf.Variable 'dense_2/kernel:0' shape=(784, 625) dtype=float32_ref>
A TT-Matrix of size 784 x 625, underlying tensor shape: (7, 4, 7, 4) x (5, 5, 5, 5), TT-ranks: (1, 16, 16, 16, 1)
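
Before plugging Wtt into a new network, we can sanity-check how closely the rank-16 TT-Matrix approximates the trained weights. A minimal sketch (the exact value depends on the trained weights):

[ ]:
# Relative Frobenius error of the TT approximation of the trained weight matrix.
rel_err = tf.norm(t3f.full(Wtt) - W) / tf.norm(W)
print(sess.run(rel_err))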

We need to evaluate the TT-cores of Wtt. We also need to store the remaining parameters (the biases and the second dense layer) for later.

[ ]:
cores = sess.run(Wtt.tt_cores)
other_params = model.get_weights()[1:]

Now we can construct a tensor network with the first Dense layer replaced by Wtt initialized using the previously computed cores.

[ ]:
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
tt_layer = t3f.nn.KerasDense(input_dims=[7, 4, 7, 4], output_dims=[5, 5, 5, 5],
                             tt_rank=16, activation='relu')
model.add(tt_layer)
model.add(Dense(10))
model.add(Activation('softmax'))
[ ]:
optimizer = optimizers.Adam(lr=1e-3)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
[ ]:
model.set_weights(list(cores) + other_params)
[20]:
print("new accuracy: ", model.evaluate(x_test, y_test)[1])
10000/10000 [==============================] - 1s 102us/step
new accuracy:  0.6533
[21]:
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_3 (Flatten)          (None, 784)               0
_________________________________________________________________
keras_dense_2 (KerasDense)   (None, 625)               15585
_________________________________________________________________
dense_4 (Dense)              (None, 10)                6260
_________________________________________________________________
activation_5 (Activation)    (None, 10)                0
=================================================================
Total params: 21,845
Trainable params: 21,845
Non-trainable params: 0
_________________________________________________________________

We see that even though we now have only about 4.4% of the original number of parameters, we still achieve a relatively high accuracy.

Finetuning the model

We can now finetune this tensor network.

[22]:
model.fit(x_train, y_train, epochs=2, batch_size=64, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
60000/60000 [==============================] - 12s 196us/step - loss: 0.1353 - acc: 0.9589 - val_loss: 0.0983 - val_acc: 0.9710
Epoch 2/2
60000/60000 [==============================] - 11s 177us/step - loss: 0.0810 - acc: 0.9749 - val_loss: 0.0820 - val_acc: 0.9751
[22]:
<keras.callbacks.History at 0x7f86c9fc22e8>

We see that we were able to achieve a higher validation accuracy than with the plain DNN, while keeping the number of parameters extremely small (21,845 vs. 496,885 parameters in the uncompressed model).
