diff --git a/src/test/scala/RISCV/DataTypes.scala b/src/test/scala/RISCV/DataTypes.scala index 5627015..14d6c61 100644 --- a/src/test/scala/RISCV/DataTypes.scala +++ b/src/test/scala/RISCV/DataTypes.scala @@ -37,11 +37,11 @@ object Data { case class MemRead(addr: Addr, word: Int) extends ExecutionEvent // addr is the target address - case class PcUpdateJALR(addr: Addr) extends ExecutionEvent - case class PcUpdateJAL(addr: Addr) extends ExecutionEvent - case class PcUpdateBranch(addr: Addr) extends ExecutionEvent - case class PcUpdateNoBranch(addr: Addr) extends ExecutionEvent - case class PcUpdate(addr: Addr) extends ExecutionEvent + case class PcUpdateJALR(addr: Addr) extends ExecutionEvent + case class PcUpdateJAL(addr: Addr) extends ExecutionEvent + case class PcUpdateBranch(addr: Addr, target: Addr) extends ExecutionEvent + case class PcUpdateNoBranch(addr: Addr) extends ExecutionEvent + case class PcUpdate(addr: Addr) extends ExecutionEvent case class ExecutionTraceEvent(pc: Addr, event: ExecutionEvent*){ override def toString(): String = s"$pc: " + event.toList.mkString(", ") } type ExecutionTrace[A] = Writer[List[ExecutionTraceEvent], A] @@ -169,6 +169,17 @@ object Data { } def log2: Int = math.ceil(math.log(i.toDouble)/math.log(2.0)).toInt + + // Discards two lowest bits + def getTag(slots: Int): Int = { + val bitsLeft = 32 - (slots.log2 + 2) + val bitsRight = 32 - slots.log2 + val leftShifted = i << bitsLeft + val rightShifted = leftShifted >>> bitsRight + // say(i) + // say(rightShifted) + rightShifted + } } implicit class StringOps(s: String) { diff --git a/src/test/scala/RISCV/VM.scala b/src/test/scala/RISCV/VM.scala index effaf6d..cf597ba 100644 --- a/src/test/scala/RISCV/VM.scala +++ b/src/test/scala/RISCV/VM.scala @@ -43,7 +43,7 @@ case class VM( val takeBranch = regs.compare(op.rs1, op.rs2, op.comp.run) if(takeBranch){ val nextVM = copy(pc = addr) - jump(nextVM, PcUpdateBranch(nextVM.pc)) + jump(nextVM, PcUpdateBranch(pc, nextVM.pc)) } else { step(this, PcUpdateNoBranch(this.pc + Addr(4))) diff --git a/src/test/scala/RISCV/printUtils.scala b/src/test/scala/RISCV/printUtils.scala index 980e52e..02b76fc 100644 --- a/src/test/scala/RISCV/printUtils.scala +++ b/src/test/scala/RISCV/printUtils.scala @@ -40,10 +40,10 @@ object PrintUtils { case MemRead(addr, word) => fansi.Color.Red(f"M[${addr.show}] -> 0x${word.hs}") // addr is the target address - case PcUpdateJALR(addr) => fansi.Color.Green(s"PC updated to ${addr.show} via JALR") - case PcUpdateJAL(addr) => fansi.Color.Magenta(s"PC updated to ${addr.show} via JAL") - case PcUpdateBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show} via Branch") - case PcUpdateNoBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show}, skipping a Branch") + case PcUpdateJALR(addr) => fansi.Color.Green(s"PC updated to ${addr.show} via JALR") + case PcUpdateJAL(addr) => fansi.Color.Magenta(s"PC updated to ${addr.show} via JAL") + case PcUpdateBranch(from, to) => fansi.Color.Yellow(s"PC updated to ${to.show} via Branch") + case PcUpdateNoBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show}, skipping a Branch") } } diff --git a/src/test/scala/RISCV/testRunner.scala b/src/test/scala/RISCV/testRunner.scala index aa53791..80a582e 100644 --- a/src/test/scala/RISCV/testRunner.scala +++ b/src/test/scala/RISCV/testRunner.scala @@ -111,12 +111,12 @@ object TestRunner { } yield { sealed trait BranchEvent - case class Taken(addr: Int) extends BranchEvent - case class NotTaken(addr: Int) extends BranchEvent + case class Taken(from: Int, to: Int) extends BranchEvent { override def toString = s"Taken ${from.hs}\t${to.hs}" } + case class NotTaken(addr: Int) extends BranchEvent { override def toString = s"Not Taken ${addr.hs}" } val events: List[BranchEvent] = trace.flatMap(_.event).collect{ - case PcUpdateBranch(x) => Taken(x.value) - case PcUpdateNoBranch(x) => NotTaken(x.value) + case PcUpdateBranch(from, to) => Taken(from.value, to.value) + case PcUpdateNoBranch(at) => NotTaken(at.value) } @@ -126,6 +126,9 @@ object TestRunner { */ def OneBitInfiniteSlots(events: List[BranchEvent]): Int = { + // Uncomment to take a look at the event log + // say(events.mkString("\n","\n","\n")) + // Helper inspects the next element of the event list. If the event is a mispredict the prediction table is updated // to reflect this. // As long as there are remaining events the helper calls itself recursively on the remainder @@ -145,24 +148,69 @@ object TestRunner { // `case Constructor(arg1, arg2) :: t => if(p(arg1, arg2))` // means we want to match a list whose first element is of type Constructor while satisfying some predicate p, // called an if guard. - case Taken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable) - case Taken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, true)) - case NotTaken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, false)) - case NotTaken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable) + case Taken(from, to) :: t if( predictionTable(from)) => helper(t, predictionTable) + case Taken(from, to) :: t if(!predictionTable(from)) => 1 + helper(t, predictionTable.updated(from, true)) + case NotTaken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, false)) + case NotTaken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable) case _ => 0 } } // Initially every possible branch is set to false since the initial state of the predictor is to assume branch not taken def initState = events.map{ - case Taken(addr) => (addr, false) - case NotTaken(addr) => (addr, false) + case Taken(from, addr) => (from, false) + case NotTaken(addr) => (addr, false) }.toMap helper(events, initState) } + + def nBitPredictor(events: List[BranchEvent]): Int = { + + case class nBitPredictor( + values : List[Int], + predictionRules : List[Boolean], + transitionRules : Int => Boolean => Int, + ){ + val slots = values.size + + def predict(pc: Int): Boolean = predictionRules(values(pc.getTag(slots))) + + def update(pc: Int, taken: Boolean): nBitPredictor = { + val current = values(pc.getTag(slots)) + copy(values = values.updated(pc.getTag(slots), transitionRules(current)(taken))) + } + } + + val initPredictor = nBitPredictor( + List.fill(4)(0), + List( + false, + false, + true, + true, + ), + r => r match { + case 0 => taken => if(taken) 1 else 0 + case 1 => taken => if(taken) 3 else 0 + case 2 => taken => if(taken) 3 else 0 + case 3 => taken => if(taken) 3 else 2 + } + ) + + events.foldLeft((0, initPredictor)){ case(((acc, bp), event)) => event match { + case Taken(pc, _) if bp.predict(pc) => (acc, bp.update(pc, true)) + case Taken(pc, _) => (acc + 1, bp.update(pc, false)) + case NotTaken(pc) if !bp.predict(pc) => (acc, bp.update(pc, false)) + case NotTaken(pc) => (acc + 1, bp.update(pc, true)) + }}._1 + } + + + say(OneBitInfiniteSlots(events)) + say(nBitPredictor(events)) } diff --git a/theory2.org b/theory2.org index 5b0d0c9..90bf611 100644 --- a/theory2.org +++ b/theory2.org @@ -230,3 +230,5 @@ + Block size is 4 words (128 bits) + Is write-through write no-allocate + Eviction policy is LRU (least recently used) + + Your answer should be the number of cache miss latency cycles when using this cache.