diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index 551f62c..230c3af 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -136,6 +136,9 @@ class CPU extends MultiIOModule { IF.io.branchAddress := EXBarrier.out.ALUResult // Stall - IF.io.stall := IDBarrier.stall || ID.io.stall + IF.io.stall := IDBarrier.stall IFBarrier.stall := IDBarrier.stall + + // Flush + IFBarrier.flush := EXBarrier.flush } diff --git a/src/main/scala/EXBarrier.scala b/src/main/scala/EXBarrier.scala index 2c79a95..222f20c 100644 --- a/src/main/scala/EXBarrier.scala +++ b/src/main/scala/EXBarrier.scala @@ -20,9 +20,12 @@ class EXBarrier extends MultiIOModule { new Bundle { val in = Input(new EXBarrierIO) val out = Output(new EXBarrierIO) + val flush = Output(Bool()) }) val delay = Reg(new EXBarrierIO) delay := io.in io.out := delay + + io.flush := io.in.branch } diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index 5627fe4..1027e70 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -39,7 +39,6 @@ class InstructionDecode extends MultiIOModule { val branchType = Output(UInt(3.W)) val jump = Output(Bool()) val returnAddr = Output(UInt(32.W)) - val stall = Output(Bool()) } ) @@ -88,25 +87,11 @@ class InstructionDecode extends MultiIOModule { io.ALUOp := decoder.ALUop io.writeAddrOut := decoder.instruction.registerRd - val stallsRemaining = RegInit(UInt(4.W), 0.U) - val stallDelay = stallsRemaining > 0.U - stallsRemaining := Mux( - stallDelay, - stallsRemaining - 1.U, - Mux( - decoder.controlSignals.branch, - 3.U, - 0.U - )) - - val stall = stallsRemaining > 1.U || decoder.controlSignals.branch && !stallDelay - io.stall := stall - - io.jump := Mux(stallDelay, false.B, decoder.controlSignals.jump) - io.branchType := Mux(stallDelay, branchType.DC, decoder.branchType) + io.jump := decoder.controlSignals.jump + io.branchType := decoder.branchType io.returnAddr := io.pc + 4.U - io.writeEnableOut := Mux(stallDelay, false.B, decoder.controlSignals.regWrite) - io.memRead := Mux(stallDelay, false.B, decoder.controlSignals.memRead) - io.memWrite := Mux(stallDelay, false.B, decoder.controlSignals.memWrite) + io.writeEnableOut := decoder.controlSignals.regWrite + io.memRead := decoder.controlSignals.memRead + io.memWrite := decoder.controlSignals.memWrite } diff --git a/src/main/scala/IFBarrier.scala b/src/main/scala/IFBarrier.scala index 755410b..49555ea 100644 --- a/src/main/scala/IFBarrier.scala +++ b/src/main/scala/IFBarrier.scala @@ -11,6 +11,7 @@ class IFBarrier extends MultiIOModule { val instructionIn = Input(new Instruction) val instructionOut = Output(new Instruction) val stall = Input(Bool()) + val flush = Input(Bool()) }) val PC = RegInit(UInt(32.W), 0.U) @@ -19,11 +20,22 @@ class IFBarrier extends MultiIOModule { val instruction = Reg(new Instruction) val replay = RegInit(Bool(), false.B) + val flushRemaining = RegInit(UInt(2.W), 0.U) + flushRemaining := Mux( + io.flush, + 2.U, + Mux( + flushRemaining === 0.U, + 0.U, + flushRemaining - 1.U + ) + ) + replay := io.stall instruction := io.instructionIn io.instructionOut := Mux( - io.stall, + io.stall || io.flush || flushRemaining > 0.U, Instruction.NOP, Mux( replay, diff --git a/src/test/resources/tests/branch.s b/src/test/resources/tests/branch.s index 0761396..15f6937 100644 --- a/src/test/resources/tests/branch.s +++ b/src/test/resources/tests/branch.s @@ -4,4 +4,7 @@ main: loop: addi x2, x2, 1 blt x2, x1, loop + nop + nop + nop done \ No newline at end of file diff --git a/src/test/resources/tests/jump.s b/src/test/resources/tests/jump.s index cb18ee4..a72a9df 100644 --- a/src/test/resources/tests/jump.s +++ b/src/test/resources/tests/jump.s @@ -1,6 +1,8 @@ main: jal x1, end addi x1, x1, 0 + nop + nop done end: diff --git a/src/test/scala/Manifest.scala b/src/test/scala/Manifest.scala index 79d6989..d10ed8a 100644 --- a/src/test/scala/Manifest.scala +++ b/src/test/scala/Manifest.scala @@ -19,7 +19,7 @@ import LogParser._ object Manifest { - val singleTest = "simpleload.s" + val singleTest = "branch.s" val nopPadded = false