#export
from exp.nb_01 import *
def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train,y_train,x_valid,y_valid))
The forward and backward passes
get_data
1:23:03 - how to download and prepare the MNIST dataset and wrap the process into a function called get_data
x_train,y_train,x_valid,y_valid = get_data()
normalize(x, m, s)
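The normalize cell itself isn't preserved in these notes; a minimal sketch of what it presumably looks like (standardize with the training-set statistics, and reuse them for the validation set):
def normalize(x, m, s): return (x-m)/s
train_mean,train_std = x_train.mean(),x_train.std()
x_train = normalize(x_train, train_mean, train_std)
# NB: use the *training* mean/std for the validation set too
x_valid = normalize(x_valid, train_mean, train_std)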
test_near_zero and assert
1:24:52 - how to check that the mean and std are close to 0 and 1, using a test_near_zero helper built on assert
train_mean,train_std = x_train.mean(),x_train.std()
train_mean,train_std
#export
def test_near_zero(a,tol=1e-3): assert a.abs()<tol, f"Near zero: {a}"
test_near_zero(x_train.mean())
test_near_zero(1-x_train.std())
getting dimensions of weights of different layers
1:25:16 - how to get n (number of input rows, i.e. training examples), m (number of input columns, i.e. pixels), and c (number of target classes) from the shapes of x_train and y_train
n,m = x_train.shape
c = y_train.max()+1
n,m,c
Foundations version
Basic architecture
initialize the weights and biases using standard Xavier init so that the first layer's activations have mean 0 and std 1
# standard xavier init
nh = 50  # hidden layer size (nh isn't defined earlier in these notes; 50 is assumed)
w1 = torch.randn(m,nh)/math.sqrt(m)
b1 = torch.zeros(nh)
w2 = torch.randn(nh,1)/math.sqrt(nh)
b2 = torch.zeros(1)
test_near_zero(w1.mean())
test_near_zero(w1.std()-1/math.sqrt(m))
x_valid has already been normalized to have mean 0 and std 1
# This should be ~ (0,1) (mean,std)...
x_valid.mean(),x_valid.std()
write linear layer from scratch
def lin(x, w, b): return x@w + b
check mean and std of activations of first layer
t = lin(x_valid, w1, b1)
#...so should this, because we used xavier init, which is designed to do this
t.mean(),t.std()
writing a linear layer with relu from scratch
1:30:09 - what does the first layer look like; how to write the relu function to maximize speed, x.clamp_min(0.), and how to write functions in PyTorch to maximize speed in general
def relu(x): return x.clamp_min(0.)
t = relu(lin(x_valid, w1, b1))
1:30:50 - but relu does not return activations with mean 0 and std 1 (it roughly halves the std, and after many layers/ReLUs the activations and gradients shrink towards zero); Jeremy explains why this is so
#...actually it really should be this!
t.mean(),t.std()
1:31:47 - Jeremy introduces Delving Deep into Rectifiers by He et al. and why we should read papers from competition winners rather than other papers; 1:32:43 - Jeremy explains the rectifier initialization in the paper and why randomly initialized weights/biases won't train well, using He's paper and Xavier's paper (Xavier's initialization didn't account for the impact of ReLU, which is where He's paper comes in); the homework is to read section 2.2 of He's paper.
From pytorch docs: a: the negative slope of the rectifier used after this layer (0 for ReLU by default)
\[\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}\]
This formula was introduced in Delving Deep into Rectifiers by He et al., the first paper to claim “super-human performance” on Imagenet (the same group later won Imagenet and introduced ResNets, in the separate paper Deep Residual Learning for Image Recognition).
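As a quick illustrative check (not a cell from the notebook), plugging a = 0 (plain ReLU) into the formula above reproduces the sqrt(2/m) scaling used in the kaiming init cell below:
import math
a = 0                                       # negative slope: 0 for plain ReLU
fan_in = m                                  # number of inputs feeding the layer
std = math.sqrt(2 / ((1 + a**2) * fan_in))
std, math.sqrt(2/m)                         # the two values are identical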
# kaiming init / he init for relu
w1 = torch.randn(m,nh)*math.sqrt(2/m)
w1.mean(),w1.std()
t = relu(lin(x_valid, w1, b1))
t.mean(),t.std()
1:36:26 - Jeremy gives guidance on how to read the ResNet paper
1:39:44 - how to use the PyTorch function torch.nn.init.kaiming_normal_ to do He init, and how to dig into the PyTorch source code to figure out how to use such functions correctly, e.g. why we need mode='fan_out' in init.kaiming_normal_(w1, mode='fan_out')
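An illustrative peek (using a private PyTorch helper, just to see the convention): nn.Linear stores its weight as (out_features, in_features), whereas our w1 is laid out as (m, nh), so the dimension PyTorch reports as fan_out of w1 is m, the number of inputs, which is exactly what He init needs to divide by:
from torch.nn.init import _calculate_fan_in_and_fan_out
_calculate_fan_in_and_fan_out(torch.zeros(m, nh))  # -> (fan_in=nh, fan_out=m)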
#export
from torch.nn import init
w1 = torch.zeros(m,nh)
init.kaiming_normal_(w1, mode='fan_out')
t = relu(lin(x_valid, w1, b1))
init.kaiming_normal_??
w1.mean(),w1.std()
t.mean(),t.std()
w1.shape
import torch.nn
torch.nn.Linear(m,nh).weight.shape
torch.nn.Linear.forward??
torch.nn.functional.linear??
1:42:56 - how to find and read the source code of the convolutional layer in PyTorch, and why it is a good idea to put the URL of the paper you are implementing in the source code
torch.nn.Conv2d??
torch.nn.modules.conv._ConvNd.reset_parameters??
1:38:55 - Jeremy notices a problem with the mean after He initialization (ReLU outputs are all non-negative, so the mean is no longer 0) and explains why; he then tries a simple but natural tweak to bring the mean back towards 0, with seemingly very good results (this segment ends at 1:39:44); 1:44:35 - how much better Jeremy's tweaked ReLU works, bringing the activation mean towards 0 and the std to about 0.8 rather than the previous 0.7
# what if...?
def relu(x): return x.clamp_min(0.) - 0.5
# kaiming init / he init for relu
w1 = torch.randn(m,nh)*math.sqrt(2./m )
t1 = relu(lin(x_valid, w1, b1))
t1.mean(),t1.std()
1:45:28 - how to build our model using the lin and relu functions we built above; how to test how fast it runs; and how to verify that the shape of the model output is correct
def model(xb):
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    l3 = lin(l2, w2, b2)
    return l3
# timing cell assumed from the 1:45:28 note: %timeit -n 10 _=model(x_valid)
assert model(x_valid).shape==torch.Size([x_valid.shape[0],1])
Loss function: MSE
1:46:15 - how to write mean squared error as our loss function (MSE isn't a sensible loss for classification; we only use it as a simple starting point); how to squeeze a tensor of shape [n, 1] down to shape [n] using output.squeeze(-1)
model(x_valid).shape
We need squeeze() to get rid of that trailing (,1), in order to use mse. (Of course, mse is not a suitable loss function for multi-class classification; we’ll use a better loss function soon. We’ll use mse for now to keep things simple.)
#export
def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()
y_train,y_valid = y_train.float(),y_valid.float()
preds = model(x_train)
preds.shape
mse(preds, y_train)
Gradients and backward pass
1:48:00 - Jeremy gives us all the matrix calculus we need for deep learning, for total beginners; 1:49:12 - how to understand and apply the chain rule to get the gradients with respect to the parameters of our model, and how to understand a derivative in plain language; 1:52:56 - how to calculate the derivative or gradient with respect to the output of the previous layer for each of mse, lin and relu
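Written out (a sketch of the chain rule as applied here): with l1 = lin(x, w1, b1), l2 = relu(l1), out = lin(l2, w2, b2) and loss = mse(out, y), the gradient of the loss with respect to w1 is
\[\frac{\partial \text{loss}}{\partial w_1} = \frac{\partial \text{loss}}{\partial \text{out}} \cdot \frac{\partial \text{out}}{\partial l_2} \cdot \frac{\partial l_2}{\partial l_1} \cdot \frac{\partial l_1}{\partial w_1}\]
mse_grad starts the chain by storing d(loss)/d(out) in out.g; each subsequent *_grad function multiplies the gradient stored in its output's .g by its own local derivative and stores the result in its input's .g (and, for lin_grad, in w.g and b.g).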
def mse_grad(inp, targ): 
    # grad of loss with respect to output of previous layer
    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]
def relu_grad(inp, out):
    # grad of relu with respect to input activations
    inp.g = (inp>0).float() * out.g
def lin_grad(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)
1:55:22 - how to put the forward pass and backward pass into one function forward_and_backward; the backward pass is just the chain rule (people who say otherwise are kidding), plus saving the gradients along the way
def forward_and_backward(inp, targ):
    # forward pass:
    l1 = inp @ w1 + b1
    l2 = relu(l1)
    out = l2 @ w2 + b2
    # we don't actually need the loss in backward!
    loss = mse(out, targ)
    
    # backward pass:
    mse_grad(out, targ)
    lin_grad(l2, out, w2, b2)
    relu_grad(l1, l2)
    lin_grad(inp, l1, w1, b1)
forward_and_backward(x_train, y_train)
1:56:41 - how to use PyTorch's own gradient calculation to test whether our gradients are computed correctly
# Save for testing against later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()
We cheat a little bit and use PyTorch autograd to check our results.
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)
def forward(inp, targ):
    # forward pass:
    l1 = inp @ w12 + b12
    l2 = relu(l1)
    out = l2 @ w22 + b22
    # we don't actually need the loss in backward!
    return mse(out, targ)
loss = forward(xt2, y_train)
loss.backward()
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )
Refactor model
Layers as classes
1:58:16 - how to refactor the previous functions into classes; after Jeremy's refactoring, the code becomes almost identical to the PyTorch API
class Relu():
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)-0.5
        return self.out
    
    def backward(self): self.inp.g = (self.inp>0).float() * self.out.g
class Lin():
    def __init__(self, w, b): self.w,self.b = w,b
        
    def __call__(self, inp):
        self.inp = inp
        self.out = inp@self.w + self.b
        return self.out
    
    def backward(self):
        self.inp.g = self.out.g @ self.w.t()
        # Creating a giant outer product, just to sum it, is inefficient!
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.b.g = self.out.g.sum(0)
class Mse():
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = (inp.squeeze() - targ).pow(2).mean()
        return self.out
    
    def backward(self):
        self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) / self.targ.shape[0]
class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers): l.backward()
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model(w1, b1, w2, b2)
loss = model(x_train, y_train)  # run the forward pass (step assumed; needed to populate the .g attributes checked below)
model.backward()                # and the backward pass
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)
Module.forward()
2:02:36 - how to remove duplicated code by adding a Module base class and using einsum; as a result, our refactored code becomes essentially identical to the PyTorch API, and this step really helps the PyTorch API make sense
class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out
    
    def forward(self): raise Exception('not implemented')
    def backward(self): self.bwd(self.out, *self.args)
class Relu(Module):
    def forward(self, inp): return inp.clamp_min(0.)-0.5
    def bwd(self, out, inp): inp.g = (inp>0).float() * out.g
class Lin(Module):
    def __init__(self, w, b): self.w,self.b = w,b
        
    def forward(self, inp): return inp@self.w + self.b
    
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = torch.einsum("bi,bj->ij", inp, out.g)
        self.b.g = out.g.sum(0)
class Mse(Module):
    def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()
    def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]
class Model():
    def __init__(self):
        self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers): l.backward()
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model()
loss = model(x_train, y_train)  # forward pass (step assumed, as before)
model.backward()                # backward pass
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)
Without einsum
2:04:44 - how to replace einsum with a plain matrix multiplication using @; as a result, our from-scratch code is about as fast as PyTorch; 2:05:44 - plan for the next lesson
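A quick sanity check (an illustrative sketch, not a cell from the notebook) that the outer-product-and-sum, einsum, and plain-matmul forms of the weight gradient all agree:
# three equivalent ways of computing sum over the batch of inp[b,i] * out.g[b,j]
a, b = torch.randn(5,3), torch.randn(5,4)
g1 = (a.unsqueeze(-1) * b.unsqueeze(1)).sum(0)
g2 = torch.einsum("bi,bj->ij", a, b)
g3 = a.t() @ b
assert torch.allclose(g1, g2) and torch.allclose(g2, g3)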
class Lin(Module):
    def __init__(self, w, b): self.w,self.b = w,b
        
    def forward(self, inp): return inp@self.w + self.b
    
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = inp.t() @ out.g
        self.b.g = out.g.sum(0)
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model()
loss = model(x_train, y_train)  # forward pass (step assumed)
model.backward()                # backward pass
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)
nn.Linear and nn.Module
#export
from torch import nn
class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]
        self.loss = mse
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x.squeeze(), targ)
model = Model(m, nh, 1)
Export
!./notebook2script.py 02_fully_connected.ipynb