summaryrefslogtreecommitdiff
path: root/2019/src/bin/day_22.rs
diff options
context:
space:
mode:
Diffstat (limited to '2019/src/bin/day_22.rs')
-rw-r--r--2019/src/bin/day_22.rs325
1 files changed, 325 insertions, 0 deletions
diff --git a/2019/src/bin/day_22.rs b/2019/src/bin/day_22.rs
new file mode 100644
index 0000000..5b999a6
--- /dev/null
+++ b/2019/src/bin/day_22.rs
@@ -0,0 +1,325 @@
+use derive_more::Display;
+use num::bigint::BigInt;
+use num::traits::identities::Zero;
+use num::traits::sign::abs;
+use num::traits::Signed;
+use std::fmt;
+use std::io;
+use std::io::prelude::*;
+use std::process;
+use std::str::FromStr;
+use structopt::StructOpt;
+
+#[derive(Debug, StructOpt)]
+#[structopt(name = "Day 22: Slam Shuffle")]
+/// Shuffles some cards.
+///
+/// See https://adventofcode.com/2019/day/22 for details.
+struct Opt {
+ /// The size of the deck
+ deck_size: BigInt,
+ /// At the end, query the position of card
+ card: BigInt,
+ /// Number of repetitions
+ repetitions: BigInt,
+
+ /// Prints the card in position n, rather than the position of card n
+ #[structopt(short = "p")]
+ position_mode: bool,
+}
+
+fn main() {
+ let stdin = io::stdin();
+ let opt = Opt::from_args();
+
+ let instructions = stdin
+ .lock()
+ .lines()
+ .map(|x| exit_on_failed_assertion(x, "Error reading input"))
+ .map(|x| exit_on_failed_assertion(x.parse::<Instruction>(), "Parse error"))
+ .collect::<Vec<Instruction>>();
+
+ //eprintln!("{:?}", instructions);
+
+ if opt.position_mode {
+ println!(
+ "{}",
+ instructions
+ .iter()
+ .rev()
+ .fold(
+ StandardisedInstruction::identity(opt.deck_size.clone()),
+ |acc, next| acc.then(&(next.clone(), opt.deck_size.clone(), false).into())
+ )
+ .repeat(opt.repetitions)
+ .apply(opt.card.clone())
+ );
+ } else {
+ println!(
+ "{}",
+ instructions
+ .iter()
+ .fold(
+ StandardisedInstruction::identity(opt.deck_size.clone()),
+ |acc, next| {
+ eprintln!("{}", acc);
+ acc.then(&(next.clone(), opt.deck_size.clone(), true).into())
+ }
+ )
+ .repeat(opt.repetitions)
+ .apply(opt.card.clone())
+ );
+ }
+}
+
+fn exit_on_failed_assertion<A, E: std::error::Error>(data: Result<A, E>, message: &str) -> A {
+ match data {
+ Ok(data) => data,
+ Err(e) => {
+ eprintln!("{}: {}", message, e);
+ process::exit(1);
+ }
+ }
+}
+
+fn mod_plus(a: BigInt, b: BigInt, modulus: BigInt) -> BigInt {
+ mod_normalize(a + b, modulus)
+}
+
+fn mod_sub(a: BigInt, b: BigInt, modulus: BigInt) -> BigInt {
+ mod_normalize(a - b, modulus)
+}
+
+fn mod_times(a: BigInt, b: BigInt, modulus: BigInt) -> BigInt {
+ mod_normalize(a * b, modulus)
+}
+
+fn mod_divide(a: BigInt, b: BigInt, modulus: BigInt) -> BigInt {
+ mod_times(a, mod_inverse(b, modulus.clone()), modulus)
+}
+
+fn mod_pow(a: BigInt, b: BigInt, modulus: BigInt) -> BigInt {
+ a.modpow(&b, &modulus)
+}
+
+fn mod_normalize(a: BigInt, modulus: BigInt) -> BigInt {
+ if a.is_negative() {
+ a.clone() + modulus.clone() * (1 + abs(a) / modulus)
+ } else {
+ a % modulus
+ }
+}
+
+// NB: This may give nonsense if modulus isn't coprime with a
+fn mod_inverse(a: BigInt, modulus: BigInt) -> BigInt {
+ mod_normalize(euclid_gcd_coefficients(a, modulus.clone()).0, modulus)
+}
+
+fn euclid_gcd_coefficients(a: BigInt, b: BigInt) -> (BigInt, BigInt) {
+ fn euclid_gcd_coefficients_inner(
+ r: BigInt,
+ old_r: BigInt,
+ s: BigInt,
+ old_s: BigInt,
+ t: BigInt,
+ old_t: BigInt,
+ ) -> (BigInt, BigInt) {
+ if r.is_zero() {
+ (old_s, old_t)
+ } else {
+ euclid_gcd_coefficients_inner(
+ old_r.clone() - (old_r.clone() / r.clone()) * r.clone(),
+ r.clone(),
+ old_s - (old_r.clone() / r.clone()) * s.clone(),
+ s,
+ old_t - (old_r.clone() / r) * t.clone(),
+ t,
+ )
+ }
+ }
+
+ assert!(a < b);
+
+ euclid_gcd_coefficients_inner(b, a, 0.into(), 1.into(), 1.into(), 0.into())
+}
+
+#[derive(Debug, Clone)]
+enum Instruction {
+ DealIntoNewStack,
+ Cut(BigInt),
+ ReverseCut(BigInt),
+ DealWithIncrement(BigInt),
+}
+
+impl FromStr for Instruction {
+ type Err = ParseErr;
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ if s.starts_with("deal into new stack") {
+ Ok(Instruction::DealIntoNewStack)
+ } else if s.starts_with("cut -") {
+ s.split(' ')
+ .nth(1)
+ .map(|val| {
+ val.parse::<BigInt>()
+ .map_err(|_| ParseErr)
+ .map(|parsed| Instruction::ReverseCut(abs(parsed)))
+ })
+ .unwrap_or(Err(ParseErr))
+ } else if s.starts_with("cut") {
+ s.split(' ')
+ .nth(1)
+ .map(|val| {
+ val.parse::<BigInt>()
+ .map_err(|_| ParseErr)
+ .map(|parsed| Instruction::Cut(parsed))
+ })
+ .unwrap_or(Err(ParseErr))
+ } else if s.starts_with("deal with increment") {
+ s.split(' ')
+ .nth(3)
+ .map(|val| {
+ val.parse::<BigInt>()
+ .map_err(|_| ParseErr)
+ .map(|parsed| Instruction::DealWithIncrement(parsed))
+ })
+ .unwrap_or(Err(ParseErr))
+ } else {
+ Err(ParseErr)
+ }
+ }
+}
+
+// f(x) = ax + b mod c
+#[derive(Display, Clone)]
+#[display(fmt = "f(x) = {} x + {} % {}", a, b, modulus)]
+struct StandardisedInstruction {
+ a: BigInt,
+ b: BigInt,
+ modulus: BigInt,
+}
+
+impl From<(Instruction, BigInt, bool)> for StandardisedInstruction {
+ fn from((instruction, modulus, forward): (Instruction, BigInt, bool)) -> Self {
+ match (instruction, forward) {
+ (Instruction::DealIntoNewStack, _) => StandardisedInstruction {
+ a: BigInt::from(-1),
+ b: BigInt::from(-1),
+ modulus: modulus,
+ },
+ (Instruction::Cut(n), true) => StandardisedInstruction {
+ a: BigInt::from(1),
+ b: BigInt::from(-n),
+ modulus: modulus,
+ },
+ (Instruction::Cut(n), false) => StandardisedInstruction {
+ a: BigInt::from(1),
+ b: BigInt::from(n),
+ modulus: modulus,
+ },
+ (Instruction::ReverseCut(n), true) => StandardisedInstruction {
+ a: BigInt::from(1),
+ b: BigInt::from(n),
+ modulus: modulus,
+ },
+ (Instruction::ReverseCut(n), false) => StandardisedInstruction {
+ a: BigInt::from(1),
+ b: BigInt::from(-n),
+ modulus: modulus,
+ },
+ (Instruction::DealWithIncrement(n), true) => StandardisedInstruction {
+ a: BigInt::from(n),
+ b: BigInt::from(0),
+ modulus: modulus,
+ },
+ (Instruction::DealWithIncrement(n), false) => StandardisedInstruction {
+ a: BigInt::from(mod_inverse(n, modulus.clone())),
+ b: BigInt::from(0),
+ modulus: modulus,
+ },
+ }
+ .normalise()
+ }
+}
+
+impl StandardisedInstruction {
+ fn identity(modulus: BigInt) -> StandardisedInstruction {
+ StandardisedInstruction {
+ a: BigInt::from(1),
+ b: BigInt::from(0),
+ modulus,
+ }
+ }
+ fn normalise(&self) -> StandardisedInstruction {
+ StandardisedInstruction {
+ a: mod_normalize(self.a.clone(), self.modulus.clone()),
+ b: mod_normalize(self.b.clone(), self.modulus.clone()),
+ modulus: self.modulus.clone(),
+ }
+ }
+ fn then(&self, other: &StandardisedInstruction) -> StandardisedInstruction {
+ // g(f(x)) = ga (fa x + fb) + gb =
+ StandardisedInstruction {
+ a: mod_times(self.a.clone(), other.a.clone(), self.modulus.clone()),
+ b: mod_plus(
+ mod_times(self.b.clone(), other.a.clone(), self.modulus.clone()),
+ other.b.clone(),
+ self.modulus.clone(),
+ ),
+ modulus: self.modulus.clone(),
+ }
+ }
+ fn repeat(&self, repetitions: BigInt) -> StandardisedInstruction {
+ StandardisedInstruction {
+ a: mod_pow(self.a.clone(), repetitions.clone(), self.modulus.clone()),
+ b: mod_divide(
+ mod_times(
+ self.b.clone(),
+ mod_sub(
+ BigInt::from(1),
+ mod_pow(self.a.clone(), repetitions.clone(), self.modulus.clone()),
+ self.modulus.clone(),
+ ),
+ self.modulus.clone(),
+ ),
+ mod_sub(BigInt::from(1), self.a.clone(), self.modulus.clone()),
+ self.modulus.clone(),
+ ),
+ modulus: self.modulus.clone(),
+ }
+ }
+
+ fn apply(&self, x: BigInt) -> BigInt {
+ mod_plus(
+ mod_times(self.a.clone(), x, self.modulus.clone()),
+ self.b.clone(),
+ self.modulus.clone(),
+ )
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+struct ParseErr;
+
+impl fmt::Display for ParseErr {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Error parsing input")
+ }
+}
+
+impl std::error::Error for ParseErr {}
+
+#[test]
+fn mod_inverse_of_13() {
+ assert_eq!(mod_inverse(1.into(), 13.into()), 1.into());
+ assert_eq!(mod_inverse(2.into(), 13.into()), 7.into());
+ assert_eq!(mod_inverse(3.into(), 13.into()), 9.into());
+ assert_eq!(mod_inverse(4.into(), 13.into()), 10.into());
+ assert_eq!(mod_inverse(5.into(), 13.into()), 8.into());
+ assert_eq!(mod_inverse(6.into(), 13.into()), 11.into());
+ assert_eq!(mod_inverse(7.into(), 13.into()), 2.into());
+ assert_eq!(mod_inverse(8.into(), 13.into()), 5.into());
+ assert_eq!(mod_inverse(9.into(), 13.into()), 3.into());
+ assert_eq!(mod_inverse(10.into(), 13.into()), 4.into());
+ assert_eq!(mod_inverse(11.into(), 13.into()), 6.into());
+ assert_eq!(mod_inverse(12.into(), 13.into()), 12.into());
+}