diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index eef93b7..d5e17c7 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -90,9 +90,6 @@ class CPU extends MultiIOModule { IDBarrier.in.r2Address := ID.io.r2Address IDBarrier.in.ALUop := ID.io.ALUOp IDBarrier.in.returnAddr := ID.io.returnAddr - IDBarrier.in.branchAddr := ID.io.branchAddr - IDBarrier.in.nextOpAddr := ID.io.nextOpAddr - IDBarrier.in.branchPrediction := ID.io.branchPrediction IDBarrier.in.jump := ID.io.jump IDBarrier.in.branchType := ID.io.branchType IDBarrier.in.writeEnable := ID.io.writeEnableOut @@ -106,12 +103,10 @@ class CPU extends MultiIOModule { EX.io.branchType := IDBarrier.out.branchType EX.io.rs1ValueIn := forward(IDBarrier.out.r1Value, IDBarrier.out.r1Address, true.B, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() EX.io.rs2ValueIn := forward(IDBarrier.out.r2Value, IDBarrier.out.r2Address, true.B, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() - EX.io.branchAddr := IDBarrier.out.branchAddr - EX.io.nextOpAddr := IDBarrier.out.nextOpAddr - EX.io.branchPrediction := IDBarrier.out.branchPrediction EXBarrier.in.r2Value := EX.io.rs2ValueOut.asUInt() EXBarrier.in.ALUResult := EX.io.ALUResult.asUInt() + EXBarrier.branchIn := EX.io.branch EXBarrier.in.jump := IDBarrier.out.jump EXBarrier.in.returnAddr := IDBarrier.out.returnAddr EXBarrier.in.writeEnable := IDBarrier.out.writeEnable @@ -131,29 +126,19 @@ class CPU extends MultiIOModule { MEMBarrier.in.writeEnable := EXBarrier.out.writeEnable MEMBarrier.in.writeAddr := EXBarrier.out.writeAddr - // Fast branching forwarding - ID.io.forwardEx := EXBarrier.forwardEx - ID.io.forwardMem := MEMBarrier.forwardMem - ID.io.forwardWb := MEMBarrier.forwardWb - ID.io.forwardId := MEMBarrier.forwardId - // Write back ID.io.writeData := MEMBarrier.out.data ID.io.writeEnableIn := MEMBarrier.out.writeEnable ID.io.writeAddrIn := MEMBarrier.out.writeAddr // Branching - IF.io.branch := ID.io.branchPrediction || EX.io.misprediction - IF.io.branchAddress := Mux(EX.io.misprediction, EX.io.actualBranchAddr, ID.io.branchAddr) - - // Flush - IFBarrier.flush := EX.io.misprediction - - // Update branch prediction - ID.io.hasBranchResult := EX.io.branchOperation - ID.io.branchResult := EX.io.branchTaken + IF.io.branch := EXBarrier.branchOut + IF.io.branchAddress := EXBarrier.branchAddr // Stall IF.io.stall := IDBarrier.stall IFBarrier.stall := IDBarrier.stall + + // Flush + IFBarrier.flush := EXBarrier.flush } diff --git a/src/main/scala/Decoder.scala b/src/main/scala/Decoder.scala index 71d324b..3f6dc46 100644 --- a/src/main/scala/Decoder.scala +++ b/src/main/scala/Decoder.scala @@ -87,7 +87,7 @@ class Decoder() extends Module { // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp JAL -> List(Y, N, N, Y, Y, branchType.jump, PC, imm, ImmFormat.JTYPE, ALUOps.ADD ), - JALR -> List(Y, N, N, Y, Y, branchType.jump, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), + JALR -> List(Y, N, N, Y, Y, branchType.jump, rs1, imm, ImmFormat.ITYPE, ALUOps.ADDR ), ) diff --git a/src/main/scala/EX.scala b/src/main/scala/EX.scala index b988558..a65bb61 100644 --- a/src/main/scala/EX.scala +++ b/src/main/scala/EX.scala @@ -14,13 +14,7 @@ class Execute extends MultiIOModule { val rs2ValueIn = Input(SInt(32.W)) val rs2ValueOut = Output(SInt(32.W)) val branchType = Input(UInt(3.W)) - val branchPrediction = Input(Bool()) - val branchAddr = Input(UInt(32.W)) - val nextOpAddr = Input(UInt(32.W)) - val actualBranchAddr = Output(UInt(32.W)) - val branchTaken = Output(Bool()) - val branchOperation = Output(Bool()) - val misprediction = Output(Bool()) + val branch = Output(Bool()) val ALUOp = Input(UInt(4.W)) val ALUResult = Output(SInt(32.W)) } @@ -37,6 +31,7 @@ class Execute extends MultiIOModule { ALUOps.SRA -> (io.op1 >> io.op2(4, 0)), ALUOps.SRL -> (io.op1.asUInt() >> io.op2(4, 0)).asSInt(), ALUOps.SLL -> (io.op1.asUInt() << io.op2(4, 0)).asSInt(), + ALUOps.ADDR -> ((io.op1 + io.op2) & -2.S), ALUOps.COPY_A -> io.op1, ALUOps.COPY_B -> io.op2, ) @@ -52,9 +47,6 @@ class Execute extends MultiIOModule { ) io.rs2ValueOut := io.rs2ValueIn - io.branchTaken := MuxLookup(io.branchType, false.B, BranchALUOpsMap) - io.misprediction := io.branchOperation && (io.branchTaken =/= io.branchPrediction) - io.actualBranchAddr := Mux(io.branchTaken, io.branchAddr, io.nextOpAddr) - io.branchOperation := io.branchType =/= branchType.DC + io.branch := MuxLookup(io.branchType, false.B, BranchALUOpsMap) io.ALUResult := MuxLookup(io.ALUOp, 0.S(32.W), ALUOpsMap) } \ No newline at end of file diff --git a/src/main/scala/EXBarrier.scala b/src/main/scala/EXBarrier.scala index c0ea41b..5e8687e 100644 --- a/src/main/scala/EXBarrier.scala +++ b/src/main/scala/EXBarrier.scala @@ -19,14 +19,17 @@ class EXBarrier extends MultiIOModule { new Bundle { val in = Input(new EXBarrierIO) val out = Output(new EXBarrierIO) - val forwardEx = Output(new Forwarding) + val flush = Output(Bool()) + val branchAddr = Output(UInt(32.W)) + val branchIn = Input(Bool()) + val branchOut = Output(Bool()) }) val delay = Reg(new EXBarrierIO) delay := io.in io.out := delay - io.forwardEx.write := io.in.writeEnable - io.forwardEx.writeAddr := io.in.writeAddr - io.forwardEx.writeData := Mux(io.in.jump, io.in.returnAddr, io.in.ALUResult) + io.flush := io.branchIn + io.branchOut := io.branchIn + io.branchAddr := io.in.ALUResult } diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index c689f33..1027e70 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -19,8 +19,6 @@ class InstructionDecode extends MultiIOModule { val io = IO( new Bundle { val instruction = Input(new Instruction) - val hasBranchResult = Input(Bool()) - val branchResult = Input(Bool()) val pc = Input(UInt(32.W)) val op1 = Output(SInt(32.W)) val isOp1RValue = Output(Bool()) @@ -39,43 +37,11 @@ class InstructionDecode extends MultiIOModule { val memWrite = Output(Bool()) val memRead = Output(Bool()) val branchType = Output(UInt(3.W)) - val branchPrediction = Output(Bool()) val jump = Output(Bool()) val returnAddr = Output(UInt(32.W)) - val branchAddr = Output(UInt(32.W)) - val nextOpAddr = Output(UInt(32.W)) - - val forwardEx = Input(new Forwarding) - val forwardMem = Input(new Forwarding) - val forwardWb = Input(new Forwarding) - val forwardId = Input(new Forwarding) } ) - def forward(data: UInt, addr: UInt, useForward: Bool, ex: Forwarding, mem: Forwarding, wb: Forwarding, id: Forwarding): UInt = { - Mux( - !useForward, - data, - Mux( - ex.valid && ex.writeAddr === addr, - ex.writeData, - Mux( - mem.valid && mem.writeAddr === addr, - mem.writeData, - Mux( - wb.valid && wb.writeAddr === addr, - wb.writeData, - Mux( - id.valid && id.writeAddr === addr, - id.writeData, - data, - ) - ) - ) - ) - ) - } - val registers = Module(new Registers) val decoder = Module(new Decoder).io @@ -128,12 +94,4 @@ class InstructionDecode extends MultiIOModule { io.writeEnableOut := decoder.controlSignals.regWrite io.memRead := decoder.controlSignals.memRead io.memWrite := decoder.controlSignals.memWrite - - val op1 = forward(io.op1.asUInt(), io.r1Address, io.isOp1RValue, ex = io.forwardEx, mem = io.forwardMem, wb = io.forwardWb, id = io.forwardId).asSInt() - io.branchAddr := ((op1 + io.op2) & -2.S).asUInt() - io.nextOpAddr := io.pc + 4.U - - val lastBranchWasTaken = RegInit(Bool(), true.B) - lastBranchWasTaken := Mux(io.hasBranchResult, io.branchResult, lastBranchWasTaken) - io.branchPrediction := Mux(io.branchType =/= branchType.DC, lastBranchWasTaken, false.B) } diff --git a/src/main/scala/IDBarrier.scala b/src/main/scala/IDBarrier.scala index 23cdf90..c924233 100644 --- a/src/main/scala/IDBarrier.scala +++ b/src/main/scala/IDBarrier.scala @@ -13,9 +13,6 @@ class IDBarrierIO extends Bundle { val r2Value = UInt(32.W) val r2Address = UInt(5.W) val returnAddr = UInt(32.W) - val branchAddr = UInt(32.W) - val nextOpAddr = UInt(32.W) - val branchPrediction = Bool() val jump = Bool() val ALUop = UInt(4.W) val branchType = UInt(3.W) diff --git a/src/main/scala/IF.scala b/src/main/scala/IF.scala index a08d526..b4fd54a 100644 --- a/src/main/scala/IF.scala +++ b/src/main/scala/IF.scala @@ -46,9 +46,9 @@ class InstructionFetch extends MultiIOModule { * * You should expand on or rewrite the code below. */ - io.PC := Mux(io.branch, io.branchAddress, PC) - IMEM.io.instructionAddress := io.PC - PC := Mux(io.branch, io.branchAddress + 4.U, Mux(io.stall, PC, PC + 4.U)) + io.PC := PC + IMEM.io.instructionAddress := PC + PC := Mux(io.branch, io.branchAddress, Mux(io.stall, PC, PC + 4.U)) val instruction = Wire(new Instruction) instruction := IMEM.io.instruction.asTypeOf(new Instruction) diff --git a/src/main/scala/IFBarrier.scala b/src/main/scala/IFBarrier.scala index 789f4f6..4b5d103 100644 --- a/src/main/scala/IFBarrier.scala +++ b/src/main/scala/IFBarrier.scala @@ -20,12 +20,22 @@ class IFBarrier extends MultiIOModule { val instruction = Reg(new Instruction) val replay = RegInit(Bool(), false.B) + val flushRemaining = RegInit(UInt(2.W), 0.U) + flushRemaining := Mux( + io.flush, + 1.U, + Mux( + flushRemaining === 0.U, + 0.U, + flushRemaining - 1.U + ) + ) replay := io.stall instruction := io.instructionIn io.instructionOut := Mux( - io.stall || io.flush, + io.stall || io.flush || flushRemaining > 0.U, Instruction.NOP, Mux( replay, diff --git a/src/main/scala/ToplevelSignals.scala b/src/main/scala/ToplevelSignals.scala index 7fe70f5..e56b572 100644 --- a/src/main/scala/ToplevelSignals.scala +++ b/src/main/scala/ToplevelSignals.scala @@ -119,6 +119,7 @@ object ALUOps { val SRA = 9.U(4.W) val COPY_A = 10.U(4.W) val COPY_B = 11.U(4.W) + val ADDR = 12.U(4.W) val DC = 15.U(4.W) } diff --git a/src/test/scala/Manifest.scala b/src/test/scala/Manifest.scala index 6a3c731..d10ed8a 100644 --- a/src/test/scala/Manifest.scala +++ b/src/test/scala/Manifest.scala @@ -19,7 +19,7 @@ import LogParser._ object Manifest { - val singleTest = "square.s" + val singleTest = "branch.s" val nopPadded = false