Can somebody help me to correctly derive the loss function?

Toma Dragos

I'm trying to adapt the example from http://cs231n.github.io/neural-networks-case-study/#together into a neural network for regression, i.e. one that predicts a numeric target variable. I must be doing something wrong in the derivative part, because my loss grows explosively. Here is the code:

h = neurons # size of hidden layer
D = X[0].size
K = 1
W = 0.01 * np.random.randn(D,h)
b = np.zeros((1,h))
W2 = 0.01 * np.random.randn(h,K)
b2 = np.zeros((1,K))

# some hyperparameters
step_size = 1 #learning rate
reg = 0.001 # regularization strength

loss_vec = []
# gradient descent loop
num_examples = X.shape[0]
for i in range(1000):

  # evaluate class scores, [N x K]
  hidden_layer = np.maximum(0, np.dot(X, W) + b) # note, ReLU activation
  scores = np.dot(hidden_layer, W2) + b2

  loss = np.power(y - scores,2)
  #if i % 50 == 0:
  loss_vec.append(np.mean(loss))
  print("iteration %d: loss %f" % (i, np.mean(loss)))

  # compute the gradient on scores
  dscores = 2*(y-scores) # not sure this is correct
    
  # backpropagate the gradient to the parameters
  # first backprop into parameters W2 and b2
  dW2 = np.dot(hidden_layer.T, dscores)
  db2 = np.sum(dscores, axis=0, keepdims=True)
  # next backprop into hidden layer
  dhidden = np.dot(dscores, W2.T)
  # backprop the ReLU non-linearity
  dhidden[hidden_layer <= 0] = 0
  # finally into W,b
  dW = np.dot(X.T, dhidden)
  db = np.sum(dhidden, axis=0, keepdims=True)

  # add regularization gradient contribution
  dW2 += reg * W2
  dW += reg * W

  # perform a parameter update
  W += -step_size * dW
  b += -step_size * db
  W2 += -step_size * dW2
  b2 += -step_size * db2

Code output:

iteration 0: loss 5786.021888

iteration 1: loss 24248543152533318464172949461134213120.000000

iteration 2: loss 388137710832824223006297769344993376570435619092
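A loss that jumps by dozens of orders of magnitude after a single update suggests that the updates are moving uphill, and with a large step. The sketch below illustrates both effects on a hypothetical one-weight model (fitting y = 3x, not the network from this post). It runs the same kind of update, w -= step_size * dw, three ways: with the true gradient, with the sign flipped as in dscores = 2*(y-scores), and with the correct sign but a step size of 1.

```python
import numpy as np

# Hypothetical toy problem (not from the post): fit y = 3x with a single weight w.
x = np.linspace(-2.0, 2.0, 50).reshape(-1, 1)
y = 3.0 * x

def mse(w):
    """Mean squared error of the prediction w * x."""
    return float(np.mean((y - w * x) ** 2))

def run(sign, step_size, steps=5):
    """Run a few gradient-descent-style updates and record the loss after each one.

    sign=+1 uses the true gradient 2*(scores - y)/N; sign=-1 flips it to 2*(y - scores)/N.
    """
    w = 0.0
    history = [mse(w)]
    for _ in range(steps):
        scores = w * x
        dscores = sign * 2.0 * (scores - y) / len(x)
        dw = float(np.sum(x * dscores))  # chain rule: d(scores)/dw = x
        w -= step_size * dw
        history.append(mse(w))
    return history

correct = run(sign=+1, step_size=0.1)  # loss shrinks every step
flipped = run(sign=-1, step_size=0.1)  # wrong sign: loss grows every step
too_big = run(sign=+1, step_size=1.0)  # right sign, step too large: loss still explodes
print(correct[-1], flipped[-1], too_big[-1])
```

Either problem alone is enough to make the loss diverge; the original code has both.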

Maxim

I've noticed several important mistakes:

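  • the learning rate is far too big, so the network has no chance to learn anything. I used 0.0005, but the right value depends on the data, the size of the hidden layer, etc.
  • the sign of the loss derivative dscores is flipped: since d/ds (y - s)^2 = 2(s - y), it should be proportional to scores - y. With y - scores, every update moves uphill on the loss
  • the loss also ignores the regularization term (probably dropped for debugging purposes)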
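Written out for the averaged loss with the L2 penalty used in the complete code below (λ for reg, η for step_size, H for hidden_layer, and s_i for the prediction scores[i]), the derivative behind the second point is:

```latex
L = \frac{1}{N}\sum_{i=1}^{N}(y_i - s_i)^2 + \frac{\lambda}{2}\left(\lVert W\rVert_F^2 + \lVert W_2\rVert_F^2\right)

\frac{\partial L}{\partial s_i} = -\frac{2}{N}(y_i - s_i) = \frac{2}{N}(s_i - y_i)

\frac{\partial L}{\partial W_2} = H^{\top}\frac{\partial L}{\partial s} + \lambda W_2,
\qquad W_2 \leftarrow W_2 - \eta\,\frac{\partial L}{\partial W_2}
```

The question's dscores = 2*(y-scores) is the negative of this gradient (apart from the 1/N), so W += -step_size * dW increases the loss on every step. The positive factor 2/N only rescales the data gradient, which is why the code below can use scores - y and fold that constant into step_size.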

Complete code below:

import numpy as np

# Generate data: learn the sum x[0] + x[1]
np.random.seed(0)
N = 100
D = 2
X = np.zeros([N, D])
y = np.zeros([N, 1])
for i in range(N):
  X[i, :] = np.random.randint(0, 5, size=2)  # integers in [0, 4]
  y[i] = X[i, 0] + X[i, 1]

# Network params
H = 10
W = 0.01 * np.random.randn(D, H)
b = np.zeros([1, H])
W2 = 0.01 * np.random.randn(H, 1)
b2 = np.zeros([1, 1])

# Hyper params
step_size = 0.0005
reg = 0.001

for i in range(100):
  # forward pass: ReLU hidden layer, linear output
  hidden_layer = np.maximum(0, np.dot(X, W) + b)
  scores = np.dot(hidden_layer, W2) + b2

  # mean squared error plus L2 regularization
  reg_loss = 0.5 * reg * np.sum(W * W) + 0.5 * reg * np.sum(W2 * W2)
  loss = np.mean(np.power(y - scores, 2)) + reg_loss

  print("iteration %d: loss %f" % (i, loss))

  # gradient of the squared error w.r.t. scores, up to the constant
  # factor 2/N (that factor is folded into step_size)
  dscores = (scores - y)

  # backprop into W2 and b2
  dW2 = np.dot(hidden_layer.T, dscores)
  db2 = np.sum(dscores, axis=0, keepdims=True)

  # backprop into the hidden layer through the ReLU
  dhidden = np.dot(dscores, W2.T)
  dhidden[hidden_layer <= 0] = 0

  # backprop into W and b
  dW = np.dot(X.T, dhidden)
  db = np.sum(dhidden, axis=0, keepdims=True)

  # regularization gradient
  dW2 += reg * W2
  dW += reg * W

  # gradient descent update
  W += -step_size * dW
  b += -step_size * db
  W2 += -step_size * dW2
  b2 += -step_size * db2

# Test
X_test = np.array([[1, 0], [0, 1], [2, 3], [2, 2]])
y_test = np.array([1, 1, 5, 4]).reshape([-1, 1])
hidden_layer = np.maximum(0, np.dot(X_test, W) + b)
scores = np.dot(hidden_layer, W2) + b2
print('Average absolute test error = %f' % np.mean(np.abs(scores - y_test)))
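One way to confirm that the backpropagation above matches the loss is a finite-difference gradient check. The sketch below is not part of the original answer, and the helper names forward_backward and numerical_grad are hypothetical. It assumes dscores = 2*(scores - y)/N so that the analytic gradient matches the averaged loss exactly; the answer's dscores = (scores - y) is that data gradient scaled by N/2, so it points the same way and the constant is absorbed into step_size (it does, however, make the reg * W terms relatively weaker).

```python
import numpy as np

def forward_backward(params, X, y, reg):
    """Loss and analytic gradients for a two-layer ReLU regression net (hypothetical helper)."""
    W, b, W2, b2 = params["W"], params["b"], params["W2"], params["b2"]
    n = X.shape[0]
    hidden = np.maximum(0, X @ W + b)
    scores = hidden @ W2 + b2
    loss = np.mean((y - scores) ** 2) + 0.5 * reg * (np.sum(W * W) + np.sum(W2 * W2))

    dscores = 2.0 * (scores - y) / n  # exact dL/dscores for the mean squared error
    grads = {"W2": hidden.T @ dscores + reg * W2,
             "b2": np.sum(dscores, axis=0, keepdims=True)}
    dhidden = dscores @ W2.T
    dhidden[hidden <= 0] = 0  # ReLU passes gradient only where it was active
    grads["W"] = X.T @ dhidden + reg * W
    grads["b"] = np.sum(dhidden, axis=0, keepdims=True)
    return loss, grads

def numerical_grad(params, name, X, y, reg, eps=1e-6):
    """Central finite differences for one parameter array (hypothetical helper)."""
    grad = np.zeros_like(params[name])
    for idx in np.ndindex(params[name].shape):
        old = params[name][idx]
        params[name][idx] = old + eps
        loss_plus, _ = forward_backward(params, X, y, reg)
        params[name][idx] = old - eps
        loss_minus, _ = forward_backward(params, X, y, reg)
        params[name][idx] = old  # restore the original value
        grad[idx] = (loss_plus - loss_minus) / (2 * eps)
    return grad

# Small synthetic check in the spirit of the answer (learn x[0] + x[1]).
rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(20, 2)).astype(float)
y = X.sum(axis=1, keepdims=True)
params = {
    "W": 0.1 * rng.standard_normal((2, 10)),
    "b": np.full((1, 10), 0.1),  # positive bias keeps rows with x = 0 away from the ReLU kink
    "W2": 0.1 * rng.standard_normal((10, 1)),
    "b2": np.zeros((1, 1)),
}
reg = 0.001

_, analytic = forward_backward(params, X, y, reg)
max_abs_diff = {}
for name in params:
    numeric = numerical_grad(params, name, X, y, reg)
    max_abs_diff[name] = float(np.max(np.abs(analytic[name] - numeric)))
    print(name, max_abs_diff[name])
```

Each printed value is the largest absolute difference between the analytic and numerical gradient for that parameter; values far below the size of the gradients themselves indicate that the backward pass is consistent with the loss.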
