Cached more values in tree exploration calcs
authorJustin Worthe <justin@worthe-it.co.za>
Sat, 8 Sep 2018 11:01:33 +0000 (13:01 +0200)
committerJustin Worthe <justin@worthe-it.co.za>
Sat, 8 Sep 2018 11:01:33 +0000 (13:01 +0200)
src/strategy/monte_carlo_tree.rs

index fe59e34..2d27b62 100644 (file)
@@ -15,8 +15,10 @@ struct NodeStats {
     wins: f32,
     losses: f32,
     attempts: f32,
+    average: f32,
+    confidence: f32,
     explored: Vec<(Command, NodeStats)>,
-    unexplored: Vec<Command>
+    unexplored: Vec<Command>,
 }
 
 impl NodeStats {
@@ -58,6 +60,8 @@ impl NodeStats {
             wins: 0.,
             losses: 0.,
             attempts: 0.,
+            average: 0.,
+            confidence: 0.,
             explored: Vec::with_capacity(commands.len()),
             unexplored: commands
         }
@@ -67,11 +71,12 @@ impl NodeStats {
         debug_assert!(self.unexplored.is_empty());
         debug_assert!(self.explored.len() > 0);
         let total_attempts = self.explored.iter().map(|(_, n)| n.attempts).sum::<f32>();
+        let sqrt_n = total_attempts.sqrt();
 
         let mut max_position = 0;
-        let mut max_value = self.explored[0].1.ucb(total_attempts);
+        let mut max_value = self.explored[0].1.ucb(sqrt_n);
         for i in 1..self.explored.len() {
-            let value = self.explored[i].1.ucb(total_attempts);
+            let value = self.explored[i].1.ucb(sqrt_n);
             if value > max_value {
                 max_position = i;
                 max_value = value;
@@ -80,8 +85,8 @@ impl NodeStats {
         &mut self.explored[max_position]
     }
 
-    fn ucb(&self, n: f32) -> f32 {
-        self.wins / self.attempts + (2.0 * n / self.attempts).sqrt()
+    fn ucb(&self, sqrt_n: f32) -> f32 {
+        self.average + sqrt_n * self.confidence
     }
 
     fn add_node<'a>(&'a mut self, player: &Player, command: Command) -> &'a mut (Command, NodeStats) {
@@ -94,13 +99,20 @@ impl NodeStats {
     fn add_victory(&mut self) {
         self.attempts += 1.;
         self.wins += 1.;
+        self.update_confidence();
     }
     fn add_defeat(&mut self) {
         self.attempts += 1.;
         self.losses += 1.;
+        self.update_confidence();
     }
     fn add_draw(&mut self) {
         self.attempts += 1.;
+        self.update_confidence();
+    }
+    fn update_confidence(&mut self) {
+        self.average = self.wins / self.attempts;
+        self.confidence = (2.0 / self.attempts).sqrt();
     }
 
     #[cfg(feature = "benchmarking")]