diff options
-rw-r--r-- | include/brain/neural_network.h | 2 | ||||
-rw-r--r-- | src/brain/neural_network.cpp | 26 | ||||
-rw-r--r-- | test/neural_network.cpp | 20 |
3 files changed, 47 insertions, 1 deletions
diff --git a/include/brain/neural_network.h b/include/brain/neural_network.h index bb53079..308cb56 100644 --- a/include/brain/neural_network.h +++ b/include/brain/neural_network.h @@ -13,7 +13,7 @@ public: NeuralNetwork(std::istream &&networkConfigFile, int numberOfSensors, int numberOfOutputs); void setInput(int inputIndex, double activation); - int findMaxOutputIndex(); + int findMaxOutputIndex() const; int numberOfSensors() const { return _sensors.size(); } int numberOfOutputs() const { return _outputs.size(); } diff --git a/src/brain/neural_network.cpp b/src/brain/neural_network.cpp index c3e9f33..2d3b902 100644 --- a/src/brain/neural_network.cpp +++ b/src/brain/neural_network.cpp @@ -13,3 +13,29 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, int numberOfSenso } } +void NeuralNetwork::setInput(int inputIndex, double activation) +{ + for (auto sensor : _sensors) + { + if (sensor->id() == inputIndex) + { + sensor->setActivation(activation); + } + } +} + +int NeuralNetwork::findMaxOutputIndex() const +{ + double currentMaxActivation = 0; + int currentMaxIndex = 0; + for (auto output : _outputs) + { + double activation = output->activation(); + if (activation >= currentMaxActivation) + { + currentMaxActivation = activation; + currentMaxIndex = output->id(); + } + } + return currentMaxIndex; +} diff --git a/test/neural_network.cpp b/test/neural_network.cpp index ba0df4d..6040c83 100644 --- a/test/neural_network.cpp +++ b/test/neural_network.cpp @@ -21,4 +21,24 @@ SCENARIO("network is read from istream") } } } + + GIVEN("a valid config file") + { + std::stringstream file; + file << "s0 n2 0.5" << std::endl; + file << "n2 n1 1" << std::endl; + file << "b0 n0 0.4" << std::endl; + file << "b0 n2 0.5" << std:: endl; + + WHEN("the network is initialized") + { + NeuralNetwork network(std::move(file), 1, 2); + + THEN("the network is constructed correctly") + { + network.setInput(0, 1); + REQUIRE(network.findMaxOutputIndex() == 1); + } + } + } } |