From bb89461843e8d5475e33650ff5a6e8b3aec5cf65 Mon Sep 17 00:00:00 2001 From: Sebastian Bugge Date: Thu, 14 Nov 2024 15:38:03 +0100 Subject: [PATCH] Simple one-bit branch-predictor. --- src/main/scala/CPU.scala | 21 ++++++++++++++------- src/main/scala/EX.scala | 13 +++++++++++-- src/main/scala/EXBarrier.scala | 9 --------- src/main/scala/ID.scala | 9 +++++++++ src/main/scala/IDBarrier.scala | 2 ++ src/test/scala/Manifest.scala | 2 +- 6 files changed, 37 insertions(+), 19 deletions(-) diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index 31530e8..eef93b7 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -91,6 +91,8 @@ class CPU extends MultiIOModule { IDBarrier.in.ALUop := ID.io.ALUOp IDBarrier.in.returnAddr := ID.io.returnAddr IDBarrier.in.branchAddr := ID.io.branchAddr + IDBarrier.in.nextOpAddr := ID.io.nextOpAddr + IDBarrier.in.branchPrediction := ID.io.branchPrediction IDBarrier.in.jump := ID.io.jump IDBarrier.in.branchType := ID.io.branchType IDBarrier.in.writeEnable := ID.io.writeEnableOut @@ -104,13 +106,14 @@ class CPU extends MultiIOModule { EX.io.branchType := IDBarrier.out.branchType EX.io.rs1ValueIn := forward(IDBarrier.out.r1Value, IDBarrier.out.r1Address, true.B, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() EX.io.rs2ValueIn := forward(IDBarrier.out.r2Value, IDBarrier.out.r2Address, true.B, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() + EX.io.branchAddr := IDBarrier.out.branchAddr + EX.io.nextOpAddr := IDBarrier.out.nextOpAddr + EX.io.branchPrediction := IDBarrier.out.branchPrediction EXBarrier.in.r2Value := EX.io.rs2ValueOut.asUInt() EXBarrier.in.ALUResult := EX.io.ALUResult.asUInt() - EXBarrier.branchIn := EX.io.branch EXBarrier.in.jump := IDBarrier.out.jump EXBarrier.in.returnAddr := IDBarrier.out.returnAddr - EXBarrier.branchAddrIn := IDBarrier.out.branchAddr EXBarrier.in.writeEnable := IDBarrier.out.writeEnable EXBarrier.in.writeAddr := IDBarrier.out.writeAddr EXBarrier.in.memWrite := IDBarrier.out.memWrite @@ -140,13 +143,17 @@ class CPU extends MultiIOModule { ID.io.writeAddrIn := MEMBarrier.out.writeAddr // Branching - IF.io.branch := EXBarrier.branchOut - IF.io.branchAddress := EXBarrier.branchAddrOut + IF.io.branch := ID.io.branchPrediction || EX.io.misprediction + IF.io.branchAddress := Mux(EX.io.misprediction, EX.io.actualBranchAddr, ID.io.branchAddr) + + // Flush + IFBarrier.flush := EX.io.misprediction + + // Update branch prediction + ID.io.hasBranchResult := EX.io.branchOperation + ID.io.branchResult := EX.io.branchTaken // Stall IF.io.stall := IDBarrier.stall IFBarrier.stall := IDBarrier.stall - - // Flush - IFBarrier.flush := EXBarrier.flush } diff --git a/src/main/scala/EX.scala b/src/main/scala/EX.scala index 539ae2d..b988558 100644 --- a/src/main/scala/EX.scala +++ b/src/main/scala/EX.scala @@ -14,7 +14,13 @@ class Execute extends MultiIOModule { val rs2ValueIn = Input(SInt(32.W)) val rs2ValueOut = Output(SInt(32.W)) val branchType = Input(UInt(3.W)) - val branch = Output(Bool()) + val branchPrediction = Input(Bool()) + val branchAddr = Input(UInt(32.W)) + val nextOpAddr = Input(UInt(32.W)) + val actualBranchAddr = Output(UInt(32.W)) + val branchTaken = Output(Bool()) + val branchOperation = Output(Bool()) + val misprediction = Output(Bool()) val ALUOp = Input(UInt(4.W)) val ALUResult = Output(SInt(32.W)) } @@ -46,6 +52,9 @@ class Execute extends MultiIOModule { ) io.rs2ValueOut := io.rs2ValueIn - io.branch := MuxLookup(io.branchType, false.B, BranchALUOpsMap) + io.branchTaken := MuxLookup(io.branchType, false.B, BranchALUOpsMap) + io.misprediction := io.branchOperation && (io.branchTaken =/= io.branchPrediction) + io.actualBranchAddr := Mux(io.branchTaken, io.branchAddr, io.nextOpAddr) + io.branchOperation := io.branchType =/= branchType.DC io.ALUResult := MuxLookup(io.ALUOp, 0.S(32.W), ALUOpsMap) } \ No newline at end of file diff --git a/src/main/scala/EXBarrier.scala b/src/main/scala/EXBarrier.scala index 45f933e..c0ea41b 100644 --- a/src/main/scala/EXBarrier.scala +++ b/src/main/scala/EXBarrier.scala @@ -19,11 +19,6 @@ class EXBarrier extends MultiIOModule { new Bundle { val in = Input(new EXBarrierIO) val out = Output(new EXBarrierIO) - val flush = Output(Bool()) - val branchIn = Input(Bool()) - val branchOut = Output(Bool()) - val branchAddrIn = Input(UInt(32.W)) - val branchAddrOut = Output(UInt(32.W)) val forwardEx = Output(new Forwarding) }) @@ -31,10 +26,6 @@ class EXBarrier extends MultiIOModule { delay := io.in io.out := delay - io.flush := io.branchIn - io.branchOut := io.branchIn - io.branchAddrOut := io.branchAddrIn - io.forwardEx.write := io.in.writeEnable io.forwardEx.writeAddr := io.in.writeAddr io.forwardEx.writeData := Mux(io.in.jump, io.in.returnAddr, io.in.ALUResult) diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index 8d39e95..c689f33 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -19,6 +19,8 @@ class InstructionDecode extends MultiIOModule { val io = IO( new Bundle { val instruction = Input(new Instruction) + val hasBranchResult = Input(Bool()) + val branchResult = Input(Bool()) val pc = Input(UInt(32.W)) val op1 = Output(SInt(32.W)) val isOp1RValue = Output(Bool()) @@ -37,9 +39,11 @@ class InstructionDecode extends MultiIOModule { val memWrite = Output(Bool()) val memRead = Output(Bool()) val branchType = Output(UInt(3.W)) + val branchPrediction = Output(Bool()) val jump = Output(Bool()) val returnAddr = Output(UInt(32.W)) val branchAddr = Output(UInt(32.W)) + val nextOpAddr = Output(UInt(32.W)) val forwardEx = Input(new Forwarding) val forwardMem = Input(new Forwarding) @@ -127,4 +131,9 @@ class InstructionDecode extends MultiIOModule { val op1 = forward(io.op1.asUInt(), io.r1Address, io.isOp1RValue, ex = io.forwardEx, mem = io.forwardMem, wb = io.forwardWb, id = io.forwardId).asSInt() io.branchAddr := ((op1 + io.op2) & -2.S).asUInt() + io.nextOpAddr := io.pc + 4.U + + val lastBranchWasTaken = RegInit(Bool(), true.B) + lastBranchWasTaken := Mux(io.hasBranchResult, io.branchResult, lastBranchWasTaken) + io.branchPrediction := Mux(io.branchType =/= branchType.DC, lastBranchWasTaken, false.B) } diff --git a/src/main/scala/IDBarrier.scala b/src/main/scala/IDBarrier.scala index ec5ecfc..23cdf90 100644 --- a/src/main/scala/IDBarrier.scala +++ b/src/main/scala/IDBarrier.scala @@ -14,6 +14,8 @@ class IDBarrierIO extends Bundle { val r2Address = UInt(5.W) val returnAddr = UInt(32.W) val branchAddr = UInt(32.W) + val nextOpAddr = UInt(32.W) + val branchPrediction = Bool() val jump = Bool() val ALUop = UInt(4.W) val branchType = UInt(3.W) diff --git a/src/test/scala/Manifest.scala b/src/test/scala/Manifest.scala index d10ed8a..6a3c731 100644 --- a/src/test/scala/Manifest.scala +++ b/src/test/scala/Manifest.scala @@ -19,7 +19,7 @@ import LogParser._ object Manifest { - val singleTest = "branch.s" + val singleTest = "square.s" val nopPadded = false