Tensor Nets (compressing neural networks)


In this notebook we provide an example of how to build a simple Tensor Net (see https://arxiv.org/abs/1509.06569).

The main ingredient is the so-called TT-Matrix, a generalization of Kronecker-product matrices, i.e. matrices of the form

\[A = A_1 \otimes A_2 \otimes \cdots \otimes A_n\]

In t3f, TT-Matrices are represented by the TensorTrain class.
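To make the link with the Kronecker product concrete, here is a minimal NumPy sketch (not t3f code) that materializes a TT-Matrix from its cores. It assumes each core \(G_k\) has shape \((r_{k-1}, m_k, n_k, r_k)\), i.e. (rank, row index, column index, rank), which is our understanding of the layout t3f uses for TT-Matrix cores, with boundary ranks \(r_0 = r_d = 1\). The function name `tt_matrix_to_full` is hypothetical. When all TT-ranks equal 1, the result coincides with `np.kron` of the factor matrices.

```python
import numpy as np


def tt_matrix_to_full(cores):
    """Materialize a TT-Matrix from cores of shape (r_{k-1}, m_k, n_k, r_k).

    Entry W[(i_1..i_d), (j_1..j_d)] = G_1[:, i_1, j_1, :] @ ... @ G_d[:, i_d, j_d, :],
    with the row and column multi-indices flattened in row-major order.
    """
    # Running result of shape (rows so far, columns so far, current rank).
    full = np.ones((1, 1, 1))
    for core in cores:
        _, m, n, r_next = core.shape
        # Contract over the shared rank index p; result is (M, m, N, n, r).
        full = np.einsum('MNp,pmnr->MmNnr', full, core)
        full = full.reshape(full.shape[0] * m, full.shape[2] * n, r_next)
    # The last rank is 1, so drop that axis.
    return full[:, :, 0]


rng = np.random.default_rng(0)
A1 = rng.standard_normal((2, 3))
A2 = rng.standard_normal((4, 5))

# Rank-1 cores: each core is a factor matrix with two trivial rank axes.
rank1_cores = [A1[None, :, :, None], A2[None, :, :, None]]
W_kron = tt_matrix_to_full(rank1_cores)
print(np.allclose(W_kron, np.kron(A1, A2)))  # True
```

For two cores and an inner rank \(r\), the same construction gives a sum of \(r\) Kronecker products, which is the sense in which TT-Matrices generalize them.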

[1]:
# Import TF 2.
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K

# Fix seed so that the results are reproducible.
tf.random.set_seed(0)
np.random.seed(0)

try:
    import t3f
except ImportError:
    # Install T3F if it's not already installed.
    !git clone https://github.com/Bihaqo/t3f.git
    !cd t3f; pip install .
    import t3f
TensorFlow 2.x selected.
Cloning into 't3f'...
remote: Enumerating objects: 321, done.
remote: Counting objects: 100% (321/321), done.
remote: Compressing objects: 100% (182/182), done.
remote: Total 4715 (delta 209), reused 226 (delta 139), pack-reused 4394
Receiving objects: 100% (4715/4715), 1.52 MiB | 1.26 MiB/s, done.
Resolving deltas: 100% (3203/3203), done.
Processing /content/t3f
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from t3f==1.1.0) (1.18.1)
Building wheels for collected packages: t3f
  Building wheel for t3f (setup.py) ... done
  Created wheel for t3f: filename=t3f-1.1.0-cp36-none-any.whl size=75051 sha256=a20c22745abcbe82d9a467cf607135da9d5399940712bfbf134bbf7e40ac53b3
  Stored in directory: /tmp/pip-ephem-wheel-cache-vnw71g5i/wheels/66/f2/16/8d2b16c34f7e12d446db3584514f9e34e681f4c602325d175c
Successfully built t3f
Installing collected packages: t3f
Successfully installed t3f-1.1.0
[3]:
W = t3f.random_matrix([[4, 7, 4, 7], [5, 5, 5, 5]], tt_rank=2)

print(W)
A TT-Matrix of size 784 x 625, underlying tensor shape: (4, 7, 4, 7) x (5, 5, 5, 5), TT-ranks: (1, 2, 2, 2, 1)

Using TT-Matrices, we can compactly represent densely connected layers in neural networks, which greatly reduces the number of parameters. Matrix multiplication is handled by the t3f.matmul method, which can multiply dense (ordinary) matrices and TT-Matrices. A very simple neural network could look as follows (for initialization, several options such as t3f.glorot_initializer, t3f.he_initializer or t3f.random_matrix are available):

[ ]:
class Learner:
  def __init__(self):
    # 784 x 625 TT-Matrix with TT-rank 2, Glorot-style initialization.
    initializer = t3f.glorot_initializer([[4, 7, 4, 7], [5, 5, 5, 5]], tt_rank=2)
    self.W1 = t3f.get_variable('W1', initializer=initializer)
    # Create the first-layer bias once, not on every forward pass.
    self.b1 = tf.Variable(tf.zeros([625]))
    self.W2 = tf.Variable(tf.random.normal([625, 10]))
    self.b2 = tf.Variable(tf.random.normal([10]))

  def predict(self, x):
    # Dense input (batch x 784) times TT-Matrix (784 x 625).
    h1 = t3f.matmul(x, self.W1) + self.b1
    h1 = tf.nn.relu(h1)
    return tf.matmul(h1, self.W2) + self.b2

  def loss(self, x, y):
    y_ = tf.one_hot(y, 10)
    logits = self.predict(x)
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
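t3f.matmul can multiply a dense batch by a TT-Matrix without forming the full 784 x 625 matrix. The NumPy sketch below illustrates one way to do this, contracting the input with one core at a time. It is our own illustration, not t3f's implementation: the function name `dense_tt_matmul` is hypothetical, and it assumes cores of shape \((r_{k-1}, m_k, n_k, r_k)\) and row-major multi-indices.

```python
import numpy as np


def dense_tt_matmul(x, cores):
    """Compute x @ W for dense x of shape (B, m_1*...*m_d) and a TT-Matrix W.

    cores[k] has shape (r_{k-1}, m_k, n_k, r_k) with r_0 = r_d = 1.
    W is never materialized.
    """
    batch = x.shape[0]
    # State axes: (batch, output columns done, current rank, input rows left).
    state = x.reshape(batch, 1, 1, -1)
    for core in cores:
        r_prev, m, n, r_next = core.shape
        done, rest = state.shape[1], state.shape[3]
        # Split the leading input index i_k off the remaining rows.
        state = state.reshape(batch, done, r_prev, m, rest // m)
        # Sum over the rank index p and the input index i_k.
        state = np.einsum('bdpmq,pmnr->bdnrq', state, core)
        state = state.reshape(batch, done * n, r_next, rest // m)
    # Remaining axes are (batch, N, 1, 1).
    return state.reshape(batch, -1)


rng = np.random.default_rng(0)
# Same shape as W1 above: 784 x 625 with TT-ranks (1, 2, 2, 2, 1).
input_dims, output_dims, ranks = [4, 7, 4, 7], [5, 5, 5, 5], [1, 2, 2, 2, 1]
cores = [rng.standard_normal((ranks[k], input_dims[k], output_dims[k], ranks[k + 1]))
         for k in range(4)]
x = rng.standard_normal((3, 784))
print(dense_tt_matmul(x, cores).shape)  # (3, 625)
```

Each step only touches one small core, which is why the layer stays cheap in both memory and parameters.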

For convenience, we have implemented a layer analogous to the Keras Dense layer but with a TT-Matrix instead of an ordinary matrix. An example of a fully trainable network is provided below.

[ ]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import optimizers
[9]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

Some preprocessing: rescale the pixel values to \([-1, 1]\) and one-hot encode the labels.

[ ]:
x_train = x_train / 127.5 - 1.0
x_test = x_test / 127.5 - 1.0

y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
[ ]:
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
tt_layer = t3f.nn.KerasDense(input_dims=[7, 4, 7, 4], output_dims=[5, 5, 5, 5],
                             tt_rank=4, activation='relu',
                             bias_initializer=1e-3)
model.add(tt_layer)
model.add(Dense(10))
model.add(Activation('softmax'))
[68]:
model.summary()
Model: "sequential_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_12 (Flatten)         (None, 784)               0
_________________________________________________________________
tt_dense_1 (KerasDense)      (None, 625)               1725
_________________________________________________________________
dense_8 (Dense)              (None, 10)                6260
_________________________________________________________________
activation_7 (Activation)    (None, 10)                0
=================================================================
Total params: 7,985
Trainable params: 7,985
Non-trainable params: 0
_________________________________________________________________

Note that the TT layer has only \(1725\) parameters (\(1100\) in the TT-cores plus \(625\) biases) instead of the \(784 \cdot 625 = 490000\) weights in the kernel of an ordinary dense layer.
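The count of 1725 can be reproduced by hand: core \(k\) of a TT-Matrix stores \(r_{k-1} m_k n_k r_k\) numbers. The helper below is hypothetical (not part of t3f) and assumes boundary ranks \(r_0 = r_d = 1\) with all inner ranks equal to `tt_rank`.

```python
def tt_matrix_num_params(input_dims, output_dims, tt_rank):
    """Number of entries in the TT-cores of a TT-Matrix.

    Assumes r_0 = r_d = 1 and all inner ranks equal to tt_rank,
    so core k holds r_{k-1} * m_k * n_k * r_k numbers.
    """
    d = len(input_dims)
    ranks = [1] + [tt_rank] * (d - 1) + [1]
    return sum(ranks[k] * input_dims[k] * output_dims[k] * ranks[k + 1]
               for k in range(d))


cores_params = tt_matrix_num_params([7, 4, 7, 4], [5, 5, 5, 5], tt_rank=4)
print(cores_params, cores_params + 625)  # 1100 1725
```

With `tt_rank=16` the same formula gives \(14960 + 625 = 15585\), which matches the TT layer of the compressed model later in this notebook.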

[ ]:
optimizer = optimizers.Adam(learning_rate=1e-2)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
[70]:
model.fit(x_train, y_train, epochs=3, batch_size=64, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/3
60000/60000 [==============================] - 4s 69us/sample - loss: 0.2549 - accuracy: 0.9248 - val_loss: 0.1195 - val_accuracy: 0.9638
Epoch 2/3
60000/60000 [==============================] - 4s 62us/sample - loss: 0.1448 - accuracy: 0.9574 - val_loss: 0.1415 - val_accuracy: 0.9585
Epoch 3/3
60000/60000 [==============================] - 4s 62us/sample - loss: 0.1308 - accuracy: 0.9619 - val_loss: 0.1198 - val_accuracy: 0.9638
[70]:
<tensorflow.python.keras.callbacks.History at 0x7fd5263629b0>

Compression of Dense layers

Let us now train an ordinary DNN (without TT-Matrices) and show how it can be compressed using the TT decomposition, in contrast to the example above, where the TT layer was trained from scratch.

[ ]:
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(625, activation='relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
[72]:
model.summary()
Model: "sequential_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_13 (Flatten)         (None, 784)               0
_________________________________________________________________
dense_9 (Dense)              (None, 625)               490625
_________________________________________________________________
dense_10 (Dense)             (None, 10)                6260
_________________________________________________________________
activation_8 (Activation)    (None, 10)                0
=================================================================
Total params: 496,885
Trainable params: 496,885
Non-trainable params: 0
_________________________________________________________________
[ ]:
optimizer = optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
[74]:
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [==============================] - 3s 57us/sample - loss: 0.2779 - accuracy: 0.9158 - val_loss: 0.1589 - val_accuracy: 0.9501
Epoch 2/5
60000/60000 [==============================] - 3s 52us/sample - loss: 0.1297 - accuracy: 0.9610 - val_loss: 0.1632 - val_accuracy: 0.9483
Epoch 3/5
60000/60000 [==============================] - 3s 53us/sample - loss: 0.0991 - accuracy: 0.9692 - val_loss: 0.1083 - val_accuracy: 0.9674
Epoch 4/5
60000/60000 [==============================] - 3s 54us/sample - loss: 0.0835 - accuracy: 0.9742 - val_loss: 0.1191 - val_accuracy: 0.9619
Epoch 5/5
60000/60000 [==============================] - 3s 55us/sample - loss: 0.0720 - accuracy: 0.9771 - val_loss: 0.0918 - val_accuracy: 0.9714
[74]:
<tensorflow.python.keras.callbacks.History at 0x7fd5260c8240>

Let us convert the matrix used in the first Dense layer into a TT-Matrix with TT-ranks equal to 16 (since we trained the network without assuming a low-rank structure, we may wish to start with high rank values).

[75]:
W = model.trainable_weights[0]
print(W)
Wtt = t3f.to_tt_matrix(W, shape=[[7, 4, 7, 4], [5, 5, 5, 5]], max_tt_rank=16)
print(Wtt)
<tf.Variable 'dense_9/kernel:0' shape=(784, 625) dtype=float32, numpy=
array([[-0.03238887,  0.06103956,  0.03255948, ..., -0.02577683,
         0.06993102, -0.00263362],
       [-0.05367032, -0.0324776 , -0.04441883, ...,  0.0338573 ,
         0.01554517,  0.04145934],
       [ 0.03441307,  0.04183276,  0.05157001, ...,  0.00082603,
         0.03731582, -0.01392014],
       ...,
       [ 0.03070629,  0.02113252,  0.01526976, ..., -0.00541451,
         0.03794012,  0.04027091],
       [-0.01376432, -0.0064889 , -0.03118961, ...,  0.06237663,
        -0.000577  , -0.02628548],
       [-0.01680673,  0.00364697,  0.01722438, ...,  0.01579029,
        -0.00826585,  0.03203061]], dtype=float32)>
A TT-Matrix of size 784 x 625, underlying tensor shape: (7, 4, 7, 4) x (5, 5, 5, 5), TT-ranks: (1, 16, 16, 16, 1)
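A standard way to obtain such an approximation is TT-SVD: interleave the row and column indices, then split off one core at a time with a truncated SVD. Below is a minimal NumPy sketch of that procedure, assuming cores of shape \((r_{k-1}, m_k, n_k, r_k)\). The function name `tt_svd_matrix` is hypothetical, and this is not t3f's code; t3f.to_tt_matrix may differ in details such as how ranks are clipped.

```python
import numpy as np


def tt_svd_matrix(W, input_dims, output_dims, max_tt_rank):
    """Approximate W (prod(input_dims) x prod(output_dims)) by TT-cores.

    Returns cores of shape (r_{k-1}, m_k, n_k, r_k) with r_k <= max_tt_rank.
    """
    d = len(input_dims)
    # (m_1..m_d, n_1..n_d) -> (m_1, n_1, m_2, n_2, ..., m_d, n_d)
    tensor = W.reshape(list(input_dims) + list(output_dims))
    tensor = tensor.transpose([ax for k in range(d) for ax in (k, d + k)])
    cores, r_prev = [], 1
    rest = tensor.reshape(-1)
    for k in range(d - 1):
        m, n = input_dims[k], output_dims[k]
        # Unfolding: (r_{k-1} * m_k * n_k) x (all remaining indices).
        unfolding = rest.reshape(r_prev * m * n, -1)
        U, S, Vt = np.linalg.svd(unfolding, full_matrices=False)
        r = min(max_tt_rank, S.size)  # truncate to at most max_tt_rank
        cores.append(U[:, :r].reshape(r_prev, m, n, r))
        # Carry the kept singular values into the remainder.
        rest = S[:r, None] * Vt[:r]
        r_prev = r
    cores.append(rest.reshape(r_prev, input_dims[-1], output_dims[-1], 1))
    return cores


rng = np.random.default_rng(0)
W = rng.standard_normal((784, 625))
cores = tt_svd_matrix(W, [7, 4, 7, 4], [5, 5, 5, 5], max_tt_rank=16)
print([c.shape for c in cores])
# [(1, 7, 5, 16), (16, 4, 5, 16), (16, 7, 5, 16), (16, 4, 5, 1)]
```

The resulting ranks (1, 16, 16, 16, 1) agree with the TT-ranks printed for Wtt above. Truncating the SVDs is what discards information, which is why accuracy drops right after compression.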

We extract the TT-cores of Wtt. We also store the other parameters (the bias of the first layer and the kernel and bias of the second Dense layer) for later.

[ ]:
cores = Wtt.tt_cores
other_params = model.get_weights()[1:]

Now we can construct a tensor network in which the first Dense layer is replaced by a TT layer initialized with the previously computed cores of Wtt.

[ ]:
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
tt_layer = t3f.nn.KerasDense(input_dims=[7, 4, 7, 4], output_dims=[5, 5, 5, 5],
                             tt_rank=16, activation='relu')
model.add(tt_layer)
model.add(Dense(10))
model.add(Activation('softmax'))
[ ]:
optimizer = optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
[ ]:
model.set_weights(list(cores) + other_params)
[97]:
print("new accuracy: ", model.evaluate(x_test, y_test)[1])
10000/10000 [==============================] - 1s 91us/sample - loss: 1.0276 - accuracy: 0.6443
new accuracy:  0.6443
[98]:
model.summary()
Model: "sequential_16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_16 (Flatten)         (None, 784)               0
_________________________________________________________________
tt_dense_2 (KerasDense)      (None, 625)               15585
_________________________________________________________________
dense_13 (Dense)             (None, 10)                6260
_________________________________________________________________
activation_11 (Activation)   (None, 10)                0
=================================================================
Total params: 21,845
Trainable params: 21,845
Non-trainable params: 0
_________________________________________________________________

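We see that even though we now have only about 4.4% of the original number of parameters (21,845 vs 496,885), the compressed model still reaches 64.4% test accuracy without any retraining. This is far above chance level (10%), although well below the 97.1% of the uncompressed model.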

Finetuning the model

We can now finetune this tensor network.

[99]:
model.fit(x_train, y_train, epochs=2, batch_size=64, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
60000/60000 [==============================] - 5s 81us/sample - loss: 0.1349 - accuracy: 0.9594 - val_loss: 0.0982 - val_accuracy: 0.9703
Epoch 2/2
60000/60000 [==============================] - 5s 75us/sample - loss: 0.0822 - accuracy: 0.9750 - val_loss: 0.0826 - val_accuracy: 0.9765
[99]:
<tensorflow.python.keras.callbacks.History at 0x7fd526574198>

After two epochs of finetuning, the validation accuracy (97.65%) is slightly higher than that of the plain DNN (97.14%), while the number of parameters stays very small (21,845 vs 496,885 in the uncompressed model).
