Added benchmarking for number of explored nodes
authorJustin Worthe <justin@worthe-it.co.za>
Thu, 6 Sep 2018 19:51:50 +0000 (21:51 +0200)
committerJustin Worthe <justin@worthe-it.co.za>
Thu, 6 Sep 2018 19:51:50 +0000 (21:51 +0200)
src/strategy/monte_carlo_tree.rs
tests/monte_carlo_test.rs

index 4efded8..7d688f2 100644 (file)
@@ -66,6 +66,7 @@ impl NodeStats {
     
     fn node_with_highest_ucb<'a>(&'a mut self) -> &'a mut (Command, NodeStats) {
         debug_assert!(self.unexplored.is_empty());
+        debug_assert!(self.explored.len() > 0);
         let total_attempts = self.explored.iter().map(|(_, n)| n.attempts).sum::<f32>();
 
         let mut max_position = 0;
@@ -102,6 +103,10 @@ impl NodeStats {
     fn add_draw(&mut self) {
         self.attempts += 1.;
     }
+
+    fn count_explored(&self) -> usize {
+        1 + self.explored.iter().map(|(_, n)| n.count_explored()).sum::<usize>()
+    }
 }
 
 pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time: Duration) -> Command {
@@ -113,6 +118,11 @@ pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time:
         tree_search(&state, &mut root, &mut rng);
     }
 
+    #[cfg(feature = "benchmarking")]
+    {
+        println!("Explored nodes: {}", root.count_explored());
+    }
+
     let (command, _) = root.node_with_highest_ucb();
     command.clone()
 }
index 1fb4238..470c92d 100644 (file)
@@ -19,3 +19,16 @@ fn it_does_a_normal_turn_successfully() {
 
     assert!(start_time.to(PreciseTime::now()) < max_time + Duration::milliseconds(50))
 }
+
+#[test]
+fn it_does_a_normal_tree_serach_turn_successfully() {
+    let start_time = PreciseTime::now();
+    let state = match input::json::read_bitwise_state_from_file(STATE_PATH) {
+        Ok(ok) => ok,
+        Err(error) => panic!("Error while parsing JSON file: {}", error)
+    };
+    let max_time = Duration::milliseconds(20000);
+    strategy::monte_carlo_tree::choose_move(&state, start_time, max_time);
+
+    assert!(start_time.to(PreciseTime::now()) < max_time + Duration::milliseconds(50))
+}