summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/brain/neural_network.h2
-rw-r--r--src/brain/neural_network.cpp26
-rw-r--r--test/neural_network.cpp20
3 files changed, 47 insertions, 1 deletions
diff --git a/include/brain/neural_network.h b/include/brain/neural_network.h
index bb53079..308cb56 100644
--- a/include/brain/neural_network.h
+++ b/include/brain/neural_network.h
@@ -13,7 +13,7 @@ public:
NeuralNetwork(std::istream &&networkConfigFile, int numberOfSensors, int numberOfOutputs);
void setInput(int inputIndex, double activation);
- int findMaxOutputIndex();
+ int findMaxOutputIndex() const;
int numberOfSensors() const { return _sensors.size(); }
int numberOfOutputs() const { return _outputs.size(); }
diff --git a/src/brain/neural_network.cpp b/src/brain/neural_network.cpp
index c3e9f33..2d3b902 100644
--- a/src/brain/neural_network.cpp
+++ b/src/brain/neural_network.cpp
@@ -13,3 +13,29 @@ NeuralNetwork::NeuralNetwork(std::istream &&networkConfigFile, int numberOfSenso
}
}
+void NeuralNetwork::setInput(int inputIndex, double activation)
+{
+ for (auto sensor : _sensors)
+ {
+ if (sensor->id() == inputIndex)
+ {
+ sensor->setActivation(activation);
+ }
+ }
+}
+
+int NeuralNetwork::findMaxOutputIndex() const
+{
+ double currentMaxActivation = 0;
+ int currentMaxIndex = 0;
+ for (auto output : _outputs)
+ {
+ double activation = output->activation();
+ if (activation >= currentMaxActivation)
+ {
+ currentMaxActivation = activation;
+ currentMaxIndex = output->id();
+ }
+ }
+ return currentMaxIndex;
+}
diff --git a/test/neural_network.cpp b/test/neural_network.cpp
index ba0df4d..6040c83 100644
--- a/test/neural_network.cpp
+++ b/test/neural_network.cpp
@@ -21,4 +21,24 @@ SCENARIO("network is read from istream")
}
}
}
+
+ GIVEN("a valid config file")
+ {
+ std::stringstream file;
+ file << "s0 n2 0.5" << std::endl;
+ file << "n2 n1 1" << std::endl;
+ file << "b0 n0 0.4" << std::endl;
+ file << "b0 n2 0.5" << std:: endl;
+
+ WHEN("the network is initialized")
+ {
+ NeuralNetwork network(std::move(file), 1, 2);
+
+ THEN("the network is constructed correctly")
+ {
+ network.setInput(0, 1);
+ REQUIRE(network.findMaxOutputIndex() == 1);
+ }
+ }
+ }
}