Visualizing learning is a great way to gain a better understanding of your machine learning model's inputs, outputs, and/or model parameters. In this article we discuss how to visualize and monitor PyTorch training using TensorBoard.
TensorBoard is a visualization tool that comes with the automatic differentiation library TensorFlow. To use it with PyTorch code, you will first have to install an extension of TensorBoard for PyTorch called tensorboardX. Note: all the commands listed here, for installation and otherwise, were tested on Debian GNU/Linux 9.6 (stretch).
pip install tensorboardX
To use the latest version, you might have to build from source or run:
pip install tensorboardX --no-cache-dir
To access TensorBoard's web interface, you will have to install TensorFlow.
To install TensorBoard, run:
pip install tensorboard
In case the command above does not work because of a Python version mismatch, you can create a conda environment in which to install the latest TensorFlow; TensorBoard is then accessed through TensorFlow. Run the following command to create a conda environment named 'Tensorflow' with Python version 3.6:
conda create -n Tensorflow anaconda python=3.6
To activate or de-activate the environment, the following commands can be used:
conda activate Tensorflow
conda deactivate
TensorFlow can then be installed in the conda environment using the following command:
conda install -c conda-forge tensorflow
TensorBoard summary writers can be used to summarize various data types: scalars, histograms, images, graphs, etc. I am including examples of scalar and histogram variables here; the reader can refer to this document for logging the other data types.
First, you need to import tensorboardX's SummaryWriter in your code. The import command is:
from tensorboardX import SummaryWriter
Now create a writer variable where the data summaries would be saved using:
writer = SummaryWriter()  # Saves the summaries to the default directory named 'runs' in the current parent directory
writer1 = SummaryWriter('runs/exp-1')  # Saves the summaries to the directory 'runs/exp-1' in the current parent directory
writer2 = SummaryWriter(comment='I hope the model works this time around')  # Saves the summaries to a folder named like 'runs/Jan07-18-13-40-I hope the model works this time around'
The general format to save any type of data is:
add_something(name, object, iteration number) # something = scalar, histogram, image, etc
To write scalar variable 'good_god' to the summary named 'writer' at epoch number 'e', run:
writer.add_scalar('Good_God', good_god, e)
Remember to extract data from a PyTorch tensor using .item() before passing it to the writer. To save histogram summaries, use:
writer.add_histogram(hist_name, f, e)
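Putting these pieces together, here is a minimal, self-contained sketch of a logging loop. The loss values and the weight tensor are made up purely for illustration; the snippet assumes tensorboardX is installed, and otherwise falls back to PyTorch's built-in torch.utils.tensorboard, which exposes the same SummaryWriter interface:

```python
import torch

# tensorboardX and PyTorch's built-in writer share the same interface;
# fall back to the built-in one if tensorboardX is not installed
try:
    from tensorboardX import SummaryWriter
except ImportError:
    from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/demo')  # summaries are written to runs/demo

weights = torch.randn(10, 5)  # stand-in for a layer's weight tensor
for e in range(3):  # e plays the role of the epoch number
    loss = torch.tensor(1.0 / (e + 1))  # made-up scalar "loss"
    writer.add_scalar('Loss', loss.item(), e)  # .item() extracts the Python float
    writer.add_histogram('weights', weights, e)  # full distribution of the tensor
    weights = weights * 0.9  # pretend the weights change each epoch

writer.close()  # flush the event file to disk
```

After running this, starting TensorBoard with --logdir runs should show a 'Loss' curve under Scalars and a 'weights' distribution under Histograms.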
If you are aiming to save a neural network's (NN) weights, biases, or gradients, you can use code similar to the following:
for f in model.parameters():  # model is the NN model; f is one set of parameters of the model
    # Create a dynamic name for the histogram summary;
    # the current parameter's shape identifies the variable
    hist_name = 'hist' + str(list(f.grad.data.size()))
    # Save the entire list of gradients of parameters f
    writer.add_histogram(hist_name, f.grad.data, e)
    # Save the norm of the list of gradients of parameters f
    scalar_name = 'scalar' + str(list(f.grad.data.size()))
    writer.add_scalar(scalar_name, torch.norm(f.grad.data).item(), e)
    # Parameter update step
    f.data.sub_(f.grad.data * LR)
To run the TensorBoard web interface, navigate to the directory that contains the summary folder and simply run:
tensorboard --logdir runs  # runs is the name of the folder that has the saved summaries
# -> TensorBoard 1.10.0 at http://wks-40-817:6006 (Press CTRL+C to quit)
TensorBoard can be accessed from your web browser at the link printed above, where 6006 is the default port used for the web interface. In case you need to explicitly specify a port, for example to run separate web interfaces for different folders, you can choose one with the --port flag:
tensorboard --logdir runs --port 6007
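As an alternative to running one instance per folder, TensorBoard 1.x also accepts a comma-separated list of name:path pairs for --logdir, so several experiment folders can be compared in a single interface. The folder names below are hypothetical:

```shell
# compare two hypothetical experiment folders in one TensorBoard instance;
# each path gets the display name given before the colon
tensorboard --logdir exp1:runs/exp-1,exp2:runs/exp-2 --port 6007
```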
In case you need to access the web interface when you are working remotely, you can forward the port over ssh when logging into the server:
ssh -X -L 16006:127.0.0.1:6007 cosmos@my_server_ip
This forwards port 6007 on the server to port 16006 on your local machine, so the interface can then be opened locally at http://127.0.0.1:16006.