summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Worthe <justin@worthe-it.co.za>2018-09-06 21:51:50 +0200
committerJustin Worthe <justin@worthe-it.co.za>2018-09-06 21:51:50 +0200
commit90a7c7d34def7e5f92f2cd521fdc014e0cbd9906 (patch)
tree65b45bbf4bc7204b6189d6fb39180a39527b14eb
parent4ad0035f4f11b41e400a1f567fdcd3541fa3f21e (diff)
Added benchmarking for number of explored nodes
-rw-r--r--src/strategy/monte_carlo_tree.rs10
-rw-r--r--tests/monte_carlo_test.rs13
2 files changed, 23 insertions, 0 deletions
diff --git a/src/strategy/monte_carlo_tree.rs b/src/strategy/monte_carlo_tree.rs
index 4efded8..7d688f2 100644
--- a/src/strategy/monte_carlo_tree.rs
+++ b/src/strategy/monte_carlo_tree.rs
@@ -66,6 +66,7 @@ impl NodeStats {
fn node_with_highest_ucb<'a>(&'a mut self) -> &'a mut (Command, NodeStats) {
debug_assert!(self.unexplored.is_empty());
+ debug_assert!(self.explored.len() > 0);
let total_attempts = self.explored.iter().map(|(_, n)| n.attempts).sum::<f32>();
let mut max_position = 0;
@@ -102,6 +103,10 @@ impl NodeStats {
fn add_draw(&mut self) {
self.attempts += 1.;
}
+
+ fn count_explored(&self) -> usize {
+ 1 + self.explored.iter().map(|(_, n)| n.count_explored()).sum::<usize>()
+ }
}
pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time: Duration) -> Command {
@@ -113,6 +118,11 @@ pub fn choose_move(state: &BitwiseGameState, start_time: PreciseTime, max_time:
tree_search(&state, &mut root, &mut rng);
}
+ #[cfg(feature = "benchmarking")]
+ {
+ println!("Explored nodes: {}", root.count_explored());
+ }
+
let (command, _) = root.node_with_highest_ucb();
command.clone()
}
diff --git a/tests/monte_carlo_test.rs b/tests/monte_carlo_test.rs
index 1fb4238..470c92d 100644
--- a/tests/monte_carlo_test.rs
+++ b/tests/monte_carlo_test.rs
@@ -19,3 +19,16 @@ fn it_does_a_normal_turn_successfully() {
assert!(start_time.to(PreciseTime::now()) < max_time + Duration::milliseconds(50))
}
+
+#[test]
+fn it_does_a_normal_tree_serach_turn_successfully() {
+ let start_time = PreciseTime::now();
+ let state = match input::json::read_bitwise_state_from_file(STATE_PATH) {
+ Ok(ok) => ok,
+ Err(error) => panic!("Error while parsing JSON file: {}", error)
+ };
+ let max_time = Duration::milliseconds(20000);
+ strategy::monte_carlo_tree::choose_move(&state, start_time, max_time);
+
+ assert!(start_time.to(PreciseTime::now()) < max_time + Duration::milliseconds(50))
+}