Fix branch predictor task being wrong in several orthogonal ways.

This commit is contained in:
peteraa 2019-10-28 09:40:16 +01:00
parent 2944ee9d4e
commit 7394e7a464
5 changed files with 81 additions and 20 deletions

View file

@ -37,11 +37,11 @@ object Data {
case class MemRead(addr: Addr, word: Int) extends ExecutionEvent
// addr is the target address
case class PcUpdateJALR(addr: Addr) extends ExecutionEvent
case class PcUpdateJAL(addr: Addr) extends ExecutionEvent
case class PcUpdateBranch(addr: Addr) extends ExecutionEvent
case class PcUpdateNoBranch(addr: Addr) extends ExecutionEvent
case class PcUpdate(addr: Addr) extends ExecutionEvent
case class PcUpdateJALR(addr: Addr) extends ExecutionEvent
case class PcUpdateJAL(addr: Addr) extends ExecutionEvent
case class PcUpdateBranch(addr: Addr, target: Addr) extends ExecutionEvent
case class PcUpdateNoBranch(addr: Addr) extends ExecutionEvent
case class PcUpdate(addr: Addr) extends ExecutionEvent
case class ExecutionTraceEvent(pc: Addr, event: ExecutionEvent*){ override def toString(): String = s"$pc: " + event.toList.mkString(", ") }
type ExecutionTrace[A] = Writer[List[ExecutionTraceEvent], A]
@ -169,6 +169,17 @@ object Data {
}
def log2: Int = math.ceil(math.log(i.toDouble)/math.log(2.0)).toInt
// Discards two lowest bits
def getTag(slots: Int): Int = {
val bitsLeft = 32 - (slots.log2 + 2)
val bitsRight = 32 - slots.log2
val leftShifted = i << bitsLeft
val rightShifted = leftShifted >>> bitsRight
// say(i)
// say(rightShifted)
rightShifted
}
}
implicit class StringOps(s: String) {

View file

@ -43,7 +43,7 @@ case class VM(
val takeBranch = regs.compare(op.rs1, op.rs2, op.comp.run)
if(takeBranch){
val nextVM = copy(pc = addr)
jump(nextVM, PcUpdateBranch(nextVM.pc))
jump(nextVM, PcUpdateBranch(pc, nextVM.pc))
}
else {
step(this, PcUpdateNoBranch(this.pc + Addr(4)))

View file

@ -40,10 +40,10 @@ object PrintUtils {
case MemRead(addr, word) => fansi.Color.Red(f"M[${addr.show}] -> 0x${word.hs}")
// addr is the target address
case PcUpdateJALR(addr) => fansi.Color.Green(s"PC updated to ${addr.show} via JALR")
case PcUpdateJAL(addr) => fansi.Color.Magenta(s"PC updated to ${addr.show} via JAL")
case PcUpdateBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show} via Branch")
case PcUpdateNoBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show}, skipping a Branch")
case PcUpdateJALR(addr) => fansi.Color.Green(s"PC updated to ${addr.show} via JALR")
case PcUpdateJAL(addr) => fansi.Color.Magenta(s"PC updated to ${addr.show} via JAL")
case PcUpdateBranch(from, to) => fansi.Color.Yellow(s"PC updated to ${to.show} via Branch")
case PcUpdateNoBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show}, skipping a Branch")
}
}

View file

@ -111,12 +111,12 @@ object TestRunner {
} yield {
sealed trait BranchEvent
case class Taken(addr: Int) extends BranchEvent
case class NotTaken(addr: Int) extends BranchEvent
case class Taken(from: Int, to: Int) extends BranchEvent { override def toString = s"Taken ${from.hs}\t${to.hs}" }
case class NotTaken(addr: Int) extends BranchEvent { override def toString = s"Not Taken ${addr.hs}" }
val events: List[BranchEvent] = trace.flatMap(_.event).collect{
case PcUpdateBranch(x) => Taken(x.value)
case PcUpdateNoBranch(x) => NotTaken(x.value)
case PcUpdateBranch(from, to) => Taken(from.value, to.value)
case PcUpdateNoBranch(at) => NotTaken(at.value)
}
@ -126,6 +126,9 @@ object TestRunner {
*/
def OneBitInfiniteSlots(events: List[BranchEvent]): Int = {
// Uncomment to take a look at the event log
// say(events.mkString("\n","\n","\n"))
// Helper inspects the next element of the event list. If the event is a mispredict the prediction table is updated
// to reflect this.
// As long as there are remaining events the helper calls itself recursively on the remainder
@ -145,24 +148,69 @@ object TestRunner {
// `case Constructor(arg1, arg2) :: t => if(p(arg1, arg2))`
// means we want to match a list whose first element is of type Constructor while satisfying some predicate p,
// called an if guard.
case Taken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable)
case Taken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, true))
case NotTaken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, false))
case NotTaken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable)
case Taken(from, to) :: t if( predictionTable(from)) => helper(t, predictionTable)
case Taken(from, to) :: t if(!predictionTable(from)) => 1 + helper(t, predictionTable.updated(from, true))
case NotTaken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, false))
case NotTaken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable)
case _ => 0
}
}
// Initially every possible branch is set to false since the initial state of the predictor is to assume branch not taken
def initState = events.map{
case Taken(addr) => (addr, false)
case NotTaken(addr) => (addr, false)
case Taken(from, addr) => (from, false)
case NotTaken(addr) => (addr, false)
}.toMap
helper(events, initState)
}
def nBitPredictor(events: List[BranchEvent]): Int = {
case class nBitPredictor(
values : List[Int],
predictionRules : List[Boolean],
transitionRules : Int => Boolean => Int,
){
val slots = values.size
def predict(pc: Int): Boolean = predictionRules(values(pc.getTag(slots)))
def update(pc: Int, taken: Boolean): nBitPredictor = {
val current = values(pc.getTag(slots))
copy(values = values.updated(pc.getTag(slots), transitionRules(current)(taken)))
}
}
val initPredictor = nBitPredictor(
List.fill(4)(0),
List(
false,
false,
true,
true,
),
r => r match {
case 0 => taken => if(taken) 1 else 0
case 1 => taken => if(taken) 3 else 0
case 2 => taken => if(taken) 3 else 0
case 3 => taken => if(taken) 3 else 2
}
)
events.foldLeft((0, initPredictor)){ case(((acc, bp), event)) => event match {
case Taken(pc, _) if bp.predict(pc) => (acc, bp.update(pc, true))
case Taken(pc, _) => (acc + 1, bp.update(pc, false))
case NotTaken(pc) if !bp.predict(pc) => (acc, bp.update(pc, false))
case NotTaken(pc) => (acc + 1, bp.update(pc, true))
}}._1
}
say(OneBitInfiniteSlots(events))
say(nBitPredictor(events))
}

View file

@ -230,3 +230,5 @@
+ Block size is 4 words (128 bits)
+ Is write-through write no-allocate
+ Eviction policy is LRU (least recently used)
Your answer should be the number of cache miss latency cycles when using this cache.