diff options
author | Justin Worthe <justin@worthe-it.co.za> | 2018-09-06 21:51:50 +0200 |
---|---|---|
committer | Justin Worthe <justin@worthe-it.co.za> | 2018-09-06 21:51:50 +0200 |
commit | 90a7c7d34def7e5f92f2cd521fdc014e0cbd9906 (patch) | |
tree | 65b45bbf4bc7204b6189d6fb39180a39527b14eb | |
parent | 4ad0035f4f11b41e400a1f567fdcd3541fa3f21e (diff) |
Added benchmarking for number of explored nodes
-rw-r--r-- | src/strategy/monte_carlo_tree.rs | 10 | ||||
-rw-r--r-- | tests/monte_carlo_test.rs | 13 |
2 files changed, 23 insertions, 0 deletions
diff --git a/src/strategy/monte_carlo_tree.rs b/src/strategy/monte_carlo_tree.rs index 4efded8..7d688f2 100644 --- a/src/strategy/monte_carlo_tree.rs +++ b/src/strategy/monte_carlo_tree.rs @@ -66,6 +66,7 @@ impl NodeStats { fn node_with_highest_ucb<'a>(&'a mut self) -> &'a mut (Command, NodeStats) { debug_assert!(self.unexplored.is_empty()); + debug_assert!(self.explored.len() > 0); let total_attempts = self.explored.iter().map(|(_, n)| n.attempts).sum::<f32>(); let mut max_position = 0; @@ -102,6 +103,10 @@ impl NodeStats { fn add_draw(&mut self) { self.attempts += 1.; } + + fn count_explored(&self) -> usize { + 1 + self.explored.iter().map(|(_, n)| n.count_explored()).sum::<usize>() + } } pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time: Duration) -> Command { @@ -113,6 +118,11 @@ pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time: tree_search(&state, &mut root, &mut rng); } + #[cfg(feature = "benchmarking")] + { + println!("Explored nodes: {}", root.count_explored()); + } + let (command, _) = root.node_with_highest_ucb(); command.clone() } diff --git a/tests/monte_carlo_test.rs b/tests/monte_carlo_test.rs index 1fb4238..470c92d 100644 --- a/tests/monte_carlo_test.rs +++ b/tests/monte_carlo_test.rs @@ -19,3 +19,16 @@ fn it_does_a_normal_turn_successfully() { assert!(start_time.to(PreciseTime::now()) < max_time + Duration::milliseconds(50)) } + +#[test] +fn it_does_a_normal_tree_serach_turn_successfully() { + let start_time = PreciseTime::now(); + let state = match input::json::read_bitwise_state_from_file(STATE_PATH) { + Ok(ok) => ok, + Err(error) => panic!("Error while parsing JSON file: {}", error) + }; + let max_time = Duration::milliseconds(20000); + strategy::monte_carlo_tree::choose_move(&state, start_time, max_time); + + assert!(start_time.to(PreciseTime::now()) < max_time + Duration::milliseconds(50)) +} |