CNN Confusion Matrix with PyTorch – Neural Network Programming

Welcome to this neural network programming series. In this episode, we're going to analyze the training results of a neural network by building, plotting, and interpreting a confusion matrix. Without further ado, let's get started.

In this episode, we will build some functions that will allow us to get a prediction tensor for every sample in our training set. We'll see how we can take this prediction tensor and create a confusion matrix. This confusion matrix will allow us to see which categories our network is confusing with one another.

Let's get started with some code. We're here in the notebook that we've been working in over the past couple of episodes. I'm not going to go over the steps it took to get here; if you want to know those steps, you can look at the previous videos. We're starting out here with a trained network. We used PyTorch to train our network, but the framework doesn't matter: what we're going to cover applies to any framework. All we need to create a confusion matrix is a tensor of predictions and a tensor with the corresponding true values, or labels. If we have those two things, then we can create a confusion matrix in the way we'll see here. So if you want to see how we created and trained our network up to this point, be sure to check out the previous videos in the series.

Before we jump into the code, I want to take a quick look at what we're trying to achieve. We're trying to build this confusion matrix. It's available on deeplizard.com on the corresponding blog post for this video, so if you want to check it out, go there; the link is in the description. This confusion matrix has two axes: on the x-axis, we have the predicted labels, and on the y-axis, we have the true labels. Inside the confusion matrix, the colors build a heat map for us. The color corresponds to the values, so the higher the value, the darker the color. This dark set of squares going down the diagonal is what we want
to see when we're training our network, because the diagonal is where the predicted label is equal to the true label. For example, take the first category on the x-axis, which is T-shirt, and go all the way up to the top, where you see the corresponding true label T-shirt. The number there, 5,431, means that the network predicted T-shirt 5,431 times when the true label for that prediction was actually T-shirt. In other words, the network was correct 5,431 times for that particular category.

Now, if we go down from there and start looking at the other categories, we can see that four times the predicted label was T-shirt and the actual label was Trouser, and 92 times the network predicted T-shirt when the actual value was Pullover. We don't want any positive values outside of the diagonal. We want all the values to be on the diagonal; that would mean the network was 100% accurate.

So what this allows us to do is tell where our network is getting confused. Large values outside of the diagonal are where we can really start to see the network's confusion. If we look at the T-shirt column, for example, we see a large number, 1,159. This means the network predicted T-shirt 1,159 times when the actual value was Shirt, and we can kind of understand why the network would confuse those two. Typically, what we'll see in a confusion matrix is the network getting confused on categories that are similar to each other.

To build this, we take all of our categories and lay them out along both the x and y axes. Then we iterate over all of our predictions and corresponding true labels, and for each square, we count up the answer to the question: how many times did
this occur? That will give us the confusion matrix. To do this, we'll build a rank-2 tensor that holds this data, and then we'll be able to plot that data to generate the image that we see here. So let's jump back over to the code and see how this is done.

We're going to start out by checking the two things that we need: our training set and our training set targets. We'll look at the length of both. They're both 60,000 in length, because the Fashion-MNIST dataset that we're working with has 60,000 training samples. So the training set has 60,000 images, and the targets attribute of the training set has 60,000 corresponding labels. With PyTorch datasets, we access the labels via the targets attribute.

What we need is a predictions tensor with 60,000 predictions, one for each sample in our training set. We can compare this predictions tensor with the targets tensor, and that will enable us to generate the confusion matrix. To get it, we're going to build a function that goes through our entire training set and produces a prediction for each sample.

We're going to name the function get_all_preds, short for "get all predictions," and we'll pass it our model, or network, and a data loader. We pass a data loader because we don't want to pass the entire training set of 60,000 samples to our network at one time; our machine may not be able to handle that computation all at once. Instead, we break the training set up into batches, predict on one batch at a time, and gather all of these predictions into a single tensor. Passing in a data loader is what lets us generate those batches.

In the first line of the implementation, we create the tensor that we're going to return, called all_preds, and then we're going to
set it equal to a new, empty PyTorch tensor. Then we go through all of the batches in our data loader with a for loop: for batch in loader. On the next line, we unpack the batch into an images tensor and a labels tensor. We pass the images to our model, get back the predictions for the batch, and concatenate these predictions onto the all_preds tensor. To do this, we use the torch.cat() function along the first dimension (dim=0). This gives us a tensor of predictions for all samples in our training set. Let's go ahead and run this code, and now we have that function defined.

To make the call, all we have to do is create a data loader and pass our network and our data loader to the get_all_preds function we just defined. We create a data loader called prediction_loader, pass in our training set, and set the batch size. I'm going to set mine to 10,000; you may need to set it a bit lower, depending on the resources of your machine. Then we pass the network and this prediction loader into get_all_preds, and what's returned is a prediction tensor that contains predictions for all samples in our training set. We'll call it train_preds. Let's run this code. You'll see that it takes several seconds to do this computation.

Now we have our train_preds tensor, with a prediction for every sample in our training set. Let's take a look at its shape. The shape of this tensor is 60,000 by 10: for each of the 60,000 samples in our training set, we have ten prediction categories, and the network has given us a prediction value for each one of those categories. The category with the highest prediction value is
the one the network is predicting most strongly. To figure out which category has the highest prediction, we can use the argmax() method. But before we do that, I want to show you something about PyTorch really quickly.

Let's take a look at the requires_grad attribute on the train_preds tensor we just created. It's True, which means this tensor was created with PyTorch's gradient tracking feature turned on. In past episodes, we turned PyTorch's gradient tracking feature on and off depending on whether we needed it, that is, whether or not we were training. In this case, we don't actually need gradient tracking, because we're not training, but this attribute tells us that gradient tracking is on for this tensor.

So what exactly does that mean? You might think it means we're going to have a gradient for this tensor, so if we check the grad attribute on train_preds, you might expect a value there. Let's see. There's no value for this attribute. So why does requires_grad say True? Remember that during the training process, we don't get a gradient value for any of our tensors until we call backward() on a tensor. No gradients have been calculated, because we haven't done any backpropagation.

However, if we look at the grad_fn attribute, we're going to see a value. This value is the function that created this tensor, and indeed we do have one: this tensor's graph is being tracked. The problem is that whenever we're doing predictions, also known as inference, we don't want the extra overhead of keeping track of the graph. What we should typically do, which I didn't do before because I wanted to show this example, is get our predictions without tracking gradients, that is, without creating a graph. That's what we're going to
do next. To get the predictions without gradient tracking, we have a couple of options. We've seen in the past that we can turn off gradient tracking globally, but we can also do it locally with what's called a context manager. We use the with keyword and torch.no_grad(), and any computations inside that block of code are done without tracking gradients. For example, we could go through a training process with gradient tracking on, and afterwards, when we want to do some predictions (some inference), we put our calls to the network inside a block where gradient tracking is locally turned off: with torch.no_grad(), do these computations.

So we do the same computation again and recreate our train_preds tensor. I haven't benchmarked this, but it should potentially be a little faster, and it should also use less memory, because it's not keeping track of the graph. Let's verify that by checking the requires_grad attribute on the new tensor. It says False, which means this tensor did not require gradient tracking, because we locally turned tracking off when we created it. Now let's check the gradient; obviously, there's nothing there, because we didn't do any backpropagation. Here's what's different: when we check the gradient function, grad_fn, which is essentially this tensor's link to the graph, there's nothing there either. The value is not set. That's why we're using less memory: we're not tracking the graph.

Anyway, that was a quick example of how to turn gradient tracking off locally in PyTorch. I can also show you, on the website, another option for
turning off gradient tracking: decorate your function with @torch.no_grad(). Any time that function gets called, gradient tracking is locally turned off for the duration of the function's execution. Definitely keep that in mind if you're working with PyTorch.

Now that we have the train_preds tensor for all of our training samples, we can pass it, along with the training set targets, into the get_num_correct() function that we defined in the last video to see how many we got correct. Let's do that. We have 52,813 correct, which is an accuracy of about 88%.

Now we're ready to build our confusion matrix. To build a confusion matrix, we need the targets, which are the labels, and a corresponding predictions tensor. We have the predictions tensor, but it holds 10 values for each sample. We just need the index, or label, with the maximum value, because that's the one the network predicted most strongly. To get it, we call the argmax() method on the train_preds tensor with respect to dimension 1. Let's run both of these cells.

The first sample in the training set has a label of 9, the next two have 0 and 0, and so on. Looking at the argmax values of the predictions tensor, the first prediction's highest value was at index 9, and the second and third were at 0. So for these first three samples, the predicted label matches the target label, meaning these are correct predictions. The last three, 3, 0, and 5, are also correct predictions.

What we want to do now is pair each predicted label with its target label. To do that, we're going to use the torch.stack() function with respect to dimension 1,
and let me show you what that gives us. If we look at the stacked tensor's shape, we have 60,000 rows, and each row has two values. We've essentially taken these two tensors, turned them into columns, and smushed them together side by side. Looking at the result, we can see pairs: the first value in each pair is the true label, and the second value is the predicted label.

Now we can iterate over all of these pairs and count how many times each combination occurs. When the label and the prediction match, the count goes on the diagonal of our confusion matrix; when they don't match, it goes off the diagonal, because that's an incorrect prediction.

Let's see how we'll access a pair when we start iterating. We take the first one and turn it into a Python list with tolist(), which gives us the pair in list form. We can unpack this pair into its first and second values and then increment the value currently sitting in that spot of the confusion matrix.

To actually do this, we first need to create a confusion matrix. We call torch.zeros() to get a 10 by 10 confusion matrix, because we have 10 categories in our training set. For the data type, I was going to use int64, but we'll make it int32. This is our confusion matrix tensor, cmt, and right now all of its values are 0. As we iterate over the pairs, we increment the count in each square.

To remind you what this looks like: all the pairs are shown at the top. If we access one of them and call tolist() on it, we get back a list with the two values, and we can unpack those values by writing j, k = stacked[0].tolist(). That gives us j equal to 9 and k also equal to 9. That's how the for loop works: for each pair in the stacked tensor, we unpack the pair into the variables j and k, then find the value in the j-th row and k-th column of the confusion matrix and set it equal to that value plus 1.

It's probably even better to rename these variables, so we'll call the first one true_label and the second one predicted_label, both in the unpacking and in the indexing. The loop finds the square where the true label and the predicted label meet, gets that value, and sets it to whatever it is plus one. That counts the occurrences of each combination of categories: how many times the network predicted a particular category for a particular true label.

Let's run this and see the result. It takes a second, because it's going over 60,000 pairs. Now, if we take a look at the result, we can see that we've counted up all the occurrences. As we said before, the larger numbers run down the diagonal, where the true label equals the predicted label. Our network was trained up to 88% accuracy, so we expect much larger values on the diagonal than anywhere else in the confusion matrix.

So that's how you build a confusion matrix, and the process is the same whether you're using PyTorch, some other library, or plain Python. If you have a list of labels and a list of predictions, you can pair them up and generate the confusion matrix.

Now let's see how we can plot this. If building the matrix by hand was hard to conceptualize, then I have good news: what we can do
is import confusion_matrix from sklearn.metrics. That function generates the confusion matrix for us, and we'll use it here to show you that the result is the same. We're also going to import pyplot from matplotlib, and we need one more function to actually plot the confusion matrix. That one comes from resources.plotcm, which is a local Python file on my system. I'll show you its contents in just a second, but let's run this code first.

Now we generate the confusion matrix with scikit-learn, and what I want to show you is that it's exactly the same as the one we just built. We initialize a confusion matrix, cm, using the confusion_matrix() function, passing in the training set targets and the train_preds tensor after calling argmax() on it with respect to dimension 1. It returns a NumPy ndarray, and it looks exactly the same as what we built with PyTorch: if you compare the bottom rows, the values match. One is a PyTorch tensor and the other is a NumPy ndarray, but either way, they both work.

Now all we need to do is plot. The category names are just a tuple of strings, listed here. Where did they come from? From understanding our dataset. Let me show you on the website: if we scroll down, there's a table with all of the values, and those values come with the dataset.

We're ready to plot this thing. Let's run this code and see that it plots, and then I'll show you what the plot_confusion_matrix code looks like. It prints the confusion matrix data, and it also plots the confusion matrix. You can see that the data in the bottom row corresponds with what we saw earlier.
Just to show you that this also works with a PyTorch tensor, we pass in cmt, the confusion matrix tensor that we built with PyTorch. We run this code, and this time the printed output shows that it's a tensor, but we get the same result.

So that's how you do it. All you need now is the function that lives in resources.plotcm, so let's take a look at it. For this to work, you need a resources folder in the current directory where the notebook or code is executing. Inside that resources folder, you need a Python file called plotcm.py, and inside that file, a function called plot_confusion_matrix. That's what lets you import the function into your program this way.

Let's take a quick look at that file. I'm here in my deeplizard code PyTorch directory; this is where my notebook is, and right here I have a folder called resources. If we go into that folder, we have plotcm.py, and if we print the contents of this file, we can see that it imports some modules and then defines the plot_confusion_matrix function, which is the code that actually does the plotting. Instead of importing the code this way, you can also copy this function into your notebook, and it'll work all the same; you won't have to worry about importing it. This code is available on the website, so you can copy it from there.

And that's how you create, plot, and interpret a confusion matrix. If you haven't already, be sure to check out deeplizard.com, where there are blog posts for each episode. There are even quizzes now that you can use to test your understanding of the content. And don't forget about the deeplizard hivemind, where you can get exclusive perks and rewards. Thanks for contributing to collective intelligence. I'll see you in the next one.

Greetings again, fellow organism. In the recent past, we trained our
network to tell the difference between ten different articles of clothing, and now we can tell which categories our network is confusing with one another. But I say to you: what the heck? An interesting question is lurking just beneath the surface of this whole ordeal. We say that our network has been trained. We say that our network is confused. Does this mean our network has knowledge? What I say next may shock your brain to the core in an existential way, but I pose this question for a human, and I pose it for a neural network: what does it mean to say that we know something? And, fundamentally, how do we know that we know? If this sounds confusing, maybe we should just stick to building confusion matrices. Or perhaps we should study epistemology, the branch of philosophy concerned with the theory of knowledge.
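The get_all_preds function described in this episode can be sketched as follows. This is a reconstruction from the spoken description, not the notebook's exact code. The tiny linear network and the random TensorDataset are hypothetical stand-ins for the trained CNN and the Fashion-MNIST training set, so the sketch runs without downloading anything. It also applies the @torch.no_grad() decorator mentioned in the video, so no graph is recorded during inference.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # gradient tracking is locally off for every call to this function
def get_all_preds(model, loader):
    """Run `model` on each batch from `loader` and concatenate the outputs along dim 0."""
    all_preds = torch.tensor([])  # empty 1-D tensor; torch.cat accepts it next to 2-D tensors
    for batch in loader:
        images, labels = batch
        preds = model(images)
        all_preds = torch.cat((all_preds, preds), dim=0)
    return all_preds


# Hypothetical stand-ins for the trained CNN and the 60,000-sample training set
torch.manual_seed(0)
network = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
train_set = TensorDataset(torch.randn(250, 1, 28, 28), torch.randint(0, 10, (250,)))

prediction_loader = DataLoader(train_set, batch_size=100)  # the episode used 10,000
train_preds = get_all_preds(network, prediction_loader)
print(train_preds.shape)                               # torch.Size([250, 10])
print(train_preds.requires_grad, train_preds.grad_fn)  # False None
```

Calling torch.cat inside the loop copies the accumulated tensor on every batch. A common alternative is to append each batch's output to a Python list and call torch.cat once after the loop.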
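The episode's point about requires_grad, grad, and grad_fn can be checked directly. In this minimal sketch, a small nn.Linear layer and a random input batch stand in (hypothetically) for the trained network and the training images. With tracking on, the output carries a grad_fn but no gradients exist until backward() is called. Inside a torch.no_grad() block, no graph is recorded at all.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)  # hypothetical tiny network
x = torch.randn(5, 4)          # hypothetical batch of inputs

# Default behavior: the output is recorded in an autograd graph
preds = model(x)
print(preds.requires_grad)         # True
print(preds.grad_fn is not None)   # True: the function that created this tensor
print(model.weight.grad)           # None: no backward() call yet, so no gradients

# Inference inside a no_grad block: no graph is recorded
with torch.no_grad():
    preds_ng = model(x)
print(preds_ng.requires_grad)  # False
print(preds_ng.grad_fn)        # None

# Gradients only appear after backpropagating through the tracked output
preds.sum().backward()
print(model.weight.grad.shape)  # torch.Size([3, 4])
```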
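Turning the 10 prediction values per sample into a single predicted label takes argmax(dim=1), and the video reuses a get_num_correct() helper from an earlier episode whose body isn't shown here. The version below is a hypothetical reimplementation built on that same argmax step, run on made-up scores. The last line reproduces the accuracy arithmetic quoted in the episode (52,813 correct out of 60,000).

```python
import torch


def get_num_correct(preds, labels):
    # Hypothetical reimplementation: compare each row's highest-scoring class with its label
    return preds.argmax(dim=1).eq(labels).sum().item()


# Hypothetical scores for 4 samples over 10 classes: one "winning" class per row
preds = torch.zeros(4, 10)
preds[torch.arange(4), torch.tensor([9, 0, 0, 3])] = 1.0
labels = torch.tensor([9, 0, 0, 5])

print(preds.argmax(dim=1))             # tensor([9, 0, 0, 3])
print(get_num_correct(preds, labels))  # 3

# The accuracy arithmetic quoted in the episode
print(round(52813 / 60000, 4))  # 0.8802
```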
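Here is a small-scale version of the stack-and-count procedure walked through in the video. It pairs true and predicted labels with torch.stack(..., dim=1), then increments cmt[true_label, predicted_label] for each pair. The eight labels are hypothetical; with real data, targets would be the training set's targets and pred_labels would be train_preds.argmax(dim=1).

```python
import torch

# Hypothetical labels: one mistake (true 2 predicted as 6)
targets = torch.tensor([9, 0, 0, 3, 0, 2, 7, 2])
pred_labels = torch.tensor([9, 0, 0, 3, 0, 6, 7, 2])

# Each row becomes a (true_label, predicted_label) pair
stacked = torch.stack((targets, pred_labels), dim=1)
print(stacked.shape)        # torch.Size([8, 2])
print(stacked[0].tolist())  # [9, 9]

# Rows index true labels, columns index predicted labels
cmt = torch.zeros(10, 10, dtype=torch.int32)
for p in stacked:
    true_label, predicted_label = p.tolist()
    cmt[true_label, predicted_label] = cmt[true_label, predicted_label] + 1

print(cmt.diag().sum().item())  # 7 correct predictions on the diagonal
print(cmt[2, 6].item())         # 1: true class 2 confused for class 6
```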
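The Python loop over 60,000 pairs "takes a second," as noted in the video. One vectorized alternative, which is not the approach used in the episode, is to encode each (true, predicted) pair as a single index and count the indices with torch.bincount. The sketch below compares that against a loop on random hypothetical labels.

```python
import torch


def confusion_matrix_bincount(targets, pred_labels, num_classes=10):
    # Encode each (true, predicted) pair as one index: true * num_classes + predicted
    idx = targets * num_classes + pred_labels
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)  # rows: true, columns: predicted


def confusion_matrix_loop(targets, pred_labels, num_classes=10):
    # Reference implementation: the same counting loop as in the episode
    cmt = torch.zeros(num_classes, num_classes, dtype=torch.int64)
    for t, p in zip(targets.tolist(), pred_labels.tolist()):
        cmt[t, p] += 1
    return cmt


torch.manual_seed(0)
targets = torch.randint(0, 10, (1000,))      # hypothetical true labels
pred_labels = torch.randint(0, 10, (1000,))  # hypothetical predicted labels

fast = confusion_matrix_bincount(targets, pred_labels)
slow = confusion_matrix_loop(targets, pred_labels)
print(torch.equal(fast, slow))  # True
```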
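The interpretation steps from the episode, reading the diagonal for correct predictions and scanning off-diagonal squares for confusion, can also be computed from the matrix. This is a minimal sketch on a hypothetical 3-class matrix, with rows as true labels and columns as predicted labels. The helper names per_class_recall and most_confused_pair are illustrative, not from the video.

```python
import torch


def per_class_recall(cmt):
    # Rows are true labels: diagonal / row sum = fraction of each true class predicted correctly
    return cmt.diag().float() / cmt.sum(dim=1).float()


def most_confused_pair(cmt):
    # Zero out the diagonal, then find the largest remaining (off-diagonal) count
    off = cmt.clone()
    off.fill_diagonal_(0)
    flat_index = off.argmax().item()
    return divmod(flat_index, cmt.shape[1])  # (true_label, predicted_label)


names = ('T-shirt/top', 'Trouser', 'Shirt')  # hypothetical 3-class subset
cmt = torch.tensor([[80, 2, 18],
                    [1, 97, 2],
                    [30, 0, 70]])

recall = per_class_recall(cmt)
print(recall)  # approximately [0.80, 0.97, 0.70]
t, p = most_confused_pair(cmt)
print(f"true {names[t]} predicted as {names[p]}")  # true Shirt predicted as T-shirt/top
```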
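The contents of resources/plotcm.py are shown only briefly in the video, so the function below is a hypothetical stand-in with the same name and purpose, not the original file. It draws an annotated heat map using common matplotlib calls and assumes integer counts, passed as either a NumPy array or a CPU PyTorch tensor. The class-name tuple lists the ten Fashion-MNIST categories, and the Agg backend is selected only so the sketch also runs without a display.

```python
import itertools

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap='Blues'):
    """Hypothetical stand-in for resources.plotcm.plot_confusion_matrix (integer counts assumed)."""
    cm = np.asarray(cm)  # accepts a NumPy ndarray or a CPU PyTorch tensor
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)  # darker color = higher count
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    ticks = np.arange(len(classes))
    ax.set_xticks(ticks)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.set_yticks(ticks)
    ax.set_yticklabels(classes)
    # Write each count in its square; switch to white text on dark squares
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd'), ha='center', va='center',
                color='white' if cm[i, j] > thresh else 'black')
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    return fig


names = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
         'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')
fig = plot_confusion_matrix(np.eye(10, dtype=int) * 5, names)  # hypothetical counts
```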
