diff --git a/src/main.rs b/src/main.rs index 8f67991..53f7c14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,17 @@ use std::collections::HashMap; use std::fs::File; use std::io::{self, BufRead, BufReader}; +fn retrieve( + memory: &mut HashMap, + ops: Vec +) -> i64 { + if ops.len() >= 3 { + return *memory.get(&ops[2]).unwrap_or(&0); + } else { + return *memory.get(&ops[1]).unwrap_or(&0); + } +} + fn process( line: String, accumulator: &mut i64, @@ -12,7 +23,7 @@ fn process( let ops: Vec = line.trim().split_whitespace().map(String::from).collect(); let options: Vec<&str>; - if ops.len() > 1 { + if ops.len() >= 3 { options = vec![ops[0].as_str(), ops[1].as_str()]; } else { options = vec![ops[0].as_str()]; @@ -27,31 +38,35 @@ fn process( } "OUT" => println!("{}", accumulator), "STA" => { - memory.insert(ops[1].clone(), *accumulator); + if ops.len() >= 3 { + memory.insert(ops[2].clone(), *accumulator); + } else { + memory.insert(ops[1].clone(), *accumulator); + } } "LDA" => { - if ops.len() > 1 { - *accumulator = *memory.get(&ops[2]).unwrap_or(&0); - } else { - *accumulator = *memory.get(&ops[1]).unwrap_or(&0); - } + *accumulator = retrieve(memory, ops.clone()); } "ADD" => { if let Ok(value) = ops[1].parse::() { *accumulator += value; } else { - *accumulator += *memory.get(&ops[1]).unwrap_or(&0); + *accumulator += retrieve(memory, ops.clone()); } } "SUB" => { if let Ok(value) = ops[1].parse::() { *accumulator -= value; } else { - *accumulator -= *memory.get(&ops[1]).unwrap_or(&0); + *accumulator -= retrieve(memory, ops.clone()); } } "BRA" => { - *pc = *labels.get(&ops[1]).unwrap(); + if ops.len() >= 3 { + *pc = *labels.get(&ops[2]).unwrap(); + } else { + *pc = *labels.get(&ops[1]).unwrap(); + } } "HLT" => std::process::exit(0), _ => {