//BackpropagationNode.cpp // // //This file defines functions for the BackpropagationNode class. // #ifndef BNODE_C #define BNODE_C #include "BackpropagationNode.h" //------default constructor-------- BackpropagationNode::BackpropagationNode () {} //general error calculation function - useful for all hidden nodes double BackpropagationNode::calculateError () { double scratch = 0; for (int counter = 0; counter < numberOutputConnections; ++counter) { scratch += (outputWeight[counter].readNodePointer())->readError() * outputWeight[counter].readWeight(); } error = scratch * output * (1 - output); return error; } //error calculation for output nodes inline double BackpropagationNode::calculateOutputNodeError (double actualValue) { error = (actualValue - output) * output * (1 - output); return error; } //adjust the weights on this node to minimize the error in output //usesful for all except the first hidden layer nodes void BackpropagationNode::adjustWeight (double eta = 0.5, double alpha = 0.5) { for (int counter = 0; counter < numberInputConnections; ++counter) { double scratch = inputWeight[counter].readWeightChange(); inputWeight[counter].writeWeightChange (eta * error * (inputWeight[counter].readNodePointer())->readActivation() + alpha * scratch); inputWeight[counter].writeWeight (inputWeight[counter].readWeight() + inputWeight[counter].readWeightChange()); //now adjust the weight on the corresponding output connections of the node //feeding into this one (inputWeight[counter].readNodePointer())->writeOutputConnectionWeight(*this, inputWeight[counter].readWeight()); } } //adust the weights on this first hidden layer node to minimize the error in output // must be different than for other layers because we are using input data values // instead of the outputs from lower layer nodes void BackpropagationNode::adjustFirstHiddenLayerNodeWeight (double *inputVector, double eta = 0.7, double alpha = 0.5) { for (int counter = 0; counter < numberInputConnections; ++counter) { double scratch = inputWeight[counter].readWeightChange(); inputWeight[counter].writeWeightChange (eta * error * *(inputVector+counter) + alpha*scratch); inputWeight[counter].writeWeight (inputWeight[counter].readWeight() + inputWeight[counter].readWeightChange()); } } #endif