summaryrefslogtreecommitdiff
path: root/2015-spacebot/test/neural_network.cpp
blob: 418f5c410e3af7779ce37461f18eccc8fa47657b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include "catch.hpp"
#include <sstream>
#include <utility>
#include <vector>

#include "brain/neural_network.h"

// Tests for building a NeuralNetwork from a textual link description.
//
// Each config line has the form "<from> <to> <weight>". Judging by the
// identifiers used below, "s<i>" names a sensor input, "n<i>" a neuron and
// "b0" a bias source. NOTE(review): confirm these prefixes against
// NeuralNetwork's parser.
//
// The constructor is exercised as NeuralNetwork(stream, sensors, outputs).
// `sensors` is either a count or a std::vector<bool> whose size becomes
// numberOfSensors(). `outputs` becomes numberOfOutputs().
//
// Catch re-runs the enclosing GIVEN/WHEN code from the top for every leaf
// THEN section. Each section therefore gets a freshly filled `file`, and
// moving the stream into the network is safe.
SCENARIO("network is read from istream")
{
    GIVEN("an empty config file")
    {
        std::stringstream file;
        file << '\n';

        WHEN("the network is initialized")
        {
            NeuralNetwork network(std::move(file), 1, 2);

            THEN("the specified number of inputs and outputs are created")
            {
                REQUIRE(network.numberOfSensors() == 1);
                REQUIRE(network.numberOfOutputs() == 2);
            }
        }
    }

    GIVEN("a valid config file")
    {
        std::stringstream file;
        file << "s0 n3 0.5\n"
             << "n3 n1 1\n"
             << "b0 n0 0.4\n"
             << "b0 n3 0.5\n";

        WHEN("the network is initialized")
        {
            NeuralNetwork network(std::move(file), 1, 3);

            THEN("the network is constructed correctly")
            {
                REQUIRE(network.linkExists("s0", "n3", 0.5));
                REQUIRE(network.linkExists("n3", "n1", 1));
                REQUIRE(network.linkExists("b0", "n0", 0.4));
                REQUIRE(network.linkExists("b0", "n3", 0.5));
            }
            THEN("the network evaluates correctly")
            {
                // n1 is driven by s0 (via n3) plus n3's bias, while n0 only
                // sees a 0.4 bias. n1 is therefore expected to win, assuming
                // a monotonic activation. TODO confirm.
                network.setInput(0, 1);
                REQUIRE(network.findMaxOutputIndex() == 1);
            }
        }
    }

    GIVEN("a valid recurrent config file")
    {
        // Same topology as above plus a feedback link n1 -> n3.
        std::stringstream file;
        file << "s0 n3 0.5\n"
             << "n3 n1 1\n"
             << "b0 n0 0.4\n"
             << "b0 n3 0.5\n"
             << "n1 n3 0.5\n";

        WHEN("the network converges")
        {
            NeuralNetwork network(std::move(file), 1, 3);

            THEN("the network evaluates correctly")
            {
                network.setInput(0, 1);
                REQUIRE(network.findMaxOutputIndex() == 1);
            }
        }
    }

    GIVEN("my handcoded config file")
    {
        std::stringstream file;
        file << "b0 n0 20\n"
             << "s55 n3 10\n"
             << "b0 n4 -10\n"
             << "s59 n4 -50\n"
             << "s60 n4 20\n"
             << "b0 n6 10\n"
             << "s51 n6 -10\n"
             << "s53 n6 -10\n"
             << "n3 n0 -20\n"
             << "n4 n0 -20\n"
             << "n6 n0 -20\n"
             << "n3 n4 -20\n"
             << "n6 n3 -20\n"
             << "n6 n4 -20\n";

        WHEN("the network is constructed")
        {
            // Only the vector's size (61) is checked below.
            std::vector<bool> sensors(61);

            NeuralNetwork network(std::move(file), sensors, 7);

            THEN("it is constructed correctly")
            {
                REQUIRE(network.linkExists("b0", "n0", 20));
                REQUIRE(network.linkExists("s55", "n3", 10));
                REQUIRE(network.linkExists("b0", "n4", -10));
                REQUIRE(network.linkExists("s59", "n4", -50));
                REQUIRE(network.linkExists("s60", "n4", 20));
                REQUIRE(network.linkExists("b0", "n6", 10));
                REQUIRE(network.linkExists("s51", "n6", -10));
                REQUIRE(network.linkExists("s53", "n6", -10));
                REQUIRE(network.linkExists("n3", "n0", -20));
                REQUIRE(network.linkExists("n4", "n0", -20));
                REQUIRE(network.linkExists("n6", "n0", -20));
                REQUIRE(network.linkExists("n3", "n4", -20));
                REQUIRE(network.linkExists("n6", "n3", -20));
                REQUIRE(network.linkExists("n6", "n4", -20));
            }

            THEN("it has the right number of nodes and sensors")
            {
                // Every neuron referenced (n0..n6) falls within the 7
                // outputs. Presumably no extra hidden neurons are created,
                // hence 7 neurons in total. TODO confirm.
                REQUIRE(network.numberOfSensors() == 61);
                REQUIRE(network.numberOfOutputs() == 7);
                REQUIRE(network.numberOfNeurons() == 7);
            }
        }
    }
}