Fix branch predictor task being wrong in several orthogonal ways.

This commit is contained in:
peteraa 2019-10-28 09:40:16 +01:00
parent 2944ee9d4e
commit 7394e7a464
5 changed files with 81 additions and 20 deletions

View file

@ -37,11 +37,11 @@ object Data {
case class MemRead(addr: Addr, word: Int) extends ExecutionEvent case class MemRead(addr: Addr, word: Int) extends ExecutionEvent
// addr is the target address // addr is the target address
case class PcUpdateJALR(addr: Addr) extends ExecutionEvent case class PcUpdateJALR(addr: Addr) extends ExecutionEvent
case class PcUpdateJAL(addr: Addr) extends ExecutionEvent case class PcUpdateJAL(addr: Addr) extends ExecutionEvent
case class PcUpdateBranch(addr: Addr) extends ExecutionEvent case class PcUpdateBranch(addr: Addr, target: Addr) extends ExecutionEvent
case class PcUpdateNoBranch(addr: Addr) extends ExecutionEvent case class PcUpdateNoBranch(addr: Addr) extends ExecutionEvent
case class PcUpdate(addr: Addr) extends ExecutionEvent case class PcUpdate(addr: Addr) extends ExecutionEvent
case class ExecutionTraceEvent(pc: Addr, event: ExecutionEvent*){ override def toString(): String = s"$pc: " + event.toList.mkString(", ") } case class ExecutionTraceEvent(pc: Addr, event: ExecutionEvent*){ override def toString(): String = s"$pc: " + event.toList.mkString(", ") }
type ExecutionTrace[A] = Writer[List[ExecutionTraceEvent], A] type ExecutionTrace[A] = Writer[List[ExecutionTraceEvent], A]
@ -169,6 +169,17 @@ object Data {
} }
def log2: Int = math.ceil(math.log(i.toDouble)/math.log(2.0)).toInt def log2: Int = math.ceil(math.log(i.toDouble)/math.log(2.0)).toInt
// Discards two lowest bits
def getTag(slots: Int): Int = {
val bitsLeft = 32 - (slots.log2 + 2)
val bitsRight = 32 - slots.log2
val leftShifted = i << bitsLeft
val rightShifted = leftShifted >>> bitsRight
// say(i)
// say(rightShifted)
rightShifted
}
} }
implicit class StringOps(s: String) { implicit class StringOps(s: String) {

View file

@ -43,7 +43,7 @@ case class VM(
val takeBranch = regs.compare(op.rs1, op.rs2, op.comp.run) val takeBranch = regs.compare(op.rs1, op.rs2, op.comp.run)
if(takeBranch){ if(takeBranch){
val nextVM = copy(pc = addr) val nextVM = copy(pc = addr)
jump(nextVM, PcUpdateBranch(nextVM.pc)) jump(nextVM, PcUpdateBranch(pc, nextVM.pc))
} }
else { else {
step(this, PcUpdateNoBranch(this.pc + Addr(4))) step(this, PcUpdateNoBranch(this.pc + Addr(4)))

View file

@ -40,10 +40,10 @@ object PrintUtils {
case MemRead(addr, word) => fansi.Color.Red(f"M[${addr.show}] -> 0x${word.hs}") case MemRead(addr, word) => fansi.Color.Red(f"M[${addr.show}] -> 0x${word.hs}")
// addr is the target address // addr is the target address
case PcUpdateJALR(addr) => fansi.Color.Green(s"PC updated to ${addr.show} via JALR") case PcUpdateJALR(addr) => fansi.Color.Green(s"PC updated to ${addr.show} via JALR")
case PcUpdateJAL(addr) => fansi.Color.Magenta(s"PC updated to ${addr.show} via JAL") case PcUpdateJAL(addr) => fansi.Color.Magenta(s"PC updated to ${addr.show} via JAL")
case PcUpdateBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show} via Branch") case PcUpdateBranch(from, to) => fansi.Color.Yellow(s"PC updated to ${to.show} via Branch")
case PcUpdateNoBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show}, skipping a Branch") case PcUpdateNoBranch(addr) => fansi.Color.Yellow(s"PC updated to ${addr.show}, skipping a Branch")
} }
} }

View file

@ -111,12 +111,12 @@ object TestRunner {
} yield { } yield {
sealed trait BranchEvent sealed trait BranchEvent
case class Taken(addr: Int) extends BranchEvent case class Taken(from: Int, to: Int) extends BranchEvent { override def toString = s"Taken ${from.hs}\t${to.hs}" }
case class NotTaken(addr: Int) extends BranchEvent case class NotTaken(addr: Int) extends BranchEvent { override def toString = s"Not Taken ${addr.hs}" }
val events: List[BranchEvent] = trace.flatMap(_.event).collect{ val events: List[BranchEvent] = trace.flatMap(_.event).collect{
case PcUpdateBranch(x) => Taken(x.value) case PcUpdateBranch(from, to) => Taken(from.value, to.value)
case PcUpdateNoBranch(x) => NotTaken(x.value) case PcUpdateNoBranch(at) => NotTaken(at.value)
} }
@ -126,6 +126,9 @@ object TestRunner {
*/ */
def OneBitInfiniteSlots(events: List[BranchEvent]): Int = { def OneBitInfiniteSlots(events: List[BranchEvent]): Int = {
// Uncomment to take a look at the event log
// say(events.mkString("\n","\n","\n"))
// Helper inspects the next element of the event list. If the event is a mispredict the prediction table is updated // Helper inspects the next element of the event list. If the event is a mispredict the prediction table is updated
// to reflect this. // to reflect this.
// As long as there are remaining events the helper calls itself recursively on the remainder // As long as there are remaining events the helper calls itself recursively on the remainder
@ -145,24 +148,69 @@ object TestRunner {
// `case Constructor(arg1, arg2) :: t => if(p(arg1, arg2))` // `case Constructor(arg1, arg2) :: t => if(p(arg1, arg2))`
// means we want to match a list whose first element is of type Constructor while satisfying some predicate p, // means we want to match a list whose first element is of type Constructor while satisfying some predicate p,
// called an if guard. // called an if guard.
case Taken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable) case Taken(from, to) :: t if( predictionTable(from)) => helper(t, predictionTable)
case Taken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, true)) case Taken(from, to) :: t if(!predictionTable(from)) => 1 + helper(t, predictionTable.updated(from, true))
case NotTaken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, false)) case NotTaken(addr) :: t if(!predictionTable(addr)) => 1 + helper(t, predictionTable.updated(addr, false))
case NotTaken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable) case NotTaken(addr) :: t if( predictionTable(addr)) => helper(t, predictionTable)
case _ => 0 case _ => 0
} }
} }
// Initially every possible branch is set to false since the initial state of the predictor is to assume branch not taken // Initially every possible branch is set to false since the initial state of the predictor is to assume branch not taken
def initState = events.map{ def initState = events.map{
case Taken(addr) => (addr, false) case Taken(from, addr) => (from, false)
case NotTaken(addr) => (addr, false) case NotTaken(addr) => (addr, false)
}.toMap }.toMap
helper(events, initState) helper(events, initState)
} }
def nBitPredictor(events: List[BranchEvent]): Int = {
case class nBitPredictor(
values : List[Int],
predictionRules : List[Boolean],
transitionRules : Int => Boolean => Int,
){
val slots = values.size
def predict(pc: Int): Boolean = predictionRules(values(pc.getTag(slots)))
def update(pc: Int, taken: Boolean): nBitPredictor = {
val current = values(pc.getTag(slots))
copy(values = values.updated(pc.getTag(slots), transitionRules(current)(taken)))
}
}
val initPredictor = nBitPredictor(
List.fill(4)(0),
List(
false,
false,
true,
true,
),
r => r match {
case 0 => taken => if(taken) 1 else 0
case 1 => taken => if(taken) 3 else 0
case 2 => taken => if(taken) 3 else 0
case 3 => taken => if(taken) 3 else 2
}
)
events.foldLeft((0, initPredictor)){ case(((acc, bp), event)) => event match {
case Taken(pc, _) if bp.predict(pc) => (acc, bp.update(pc, true))
case Taken(pc, _) => (acc + 1, bp.update(pc, false))
case NotTaken(pc) if !bp.predict(pc) => (acc, bp.update(pc, false))
case NotTaken(pc) => (acc + 1, bp.update(pc, true))
}}._1
}
say(OneBitInfiniteSlots(events)) say(OneBitInfiniteSlots(events))
say(nBitPredictor(events))
} }

View file

@ -230,3 +230,5 @@
+ Block size is 4 words (128 bits) + Block size is 4 words (128 bits)
+ Is write-through write no-allocate + Is write-through write no-allocate
+ Eviction policy is LRU (least recently used) + Eviction policy is LRU (least recently used)
Your answer should be the number of cache miss latency cycles when using this cache.