diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index a3adb11..7ea6e01 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -86,7 +86,6 @@ class InstructionDecode extends MultiIOModule { io.r2Address := registers.io.readAddress2 io.ALUOp := decoder.ALUop - io.branchType := decoder.branchType io.writeAddrOut := decoder.instruction.registerRd val stallsRemaining = RegInit(UInt(4.W), 0.U) @@ -97,13 +96,18 @@ class InstructionDecode extends MultiIOModule { Mux( decoder.controlSignals.memRead, 1.U, - 0.U + Mux( + decoder.controlSignals.branch, + 3.U, + 0.U + ) )) - val stall = stallsRemaining > 1.U || decoder.controlSignals.memRead && !stallDelay + val stall = stallsRemaining > 1.U || decoder.controlSignals.memRead && !stallDelay || decoder.controlSignals.branch && !stallDelay io.stall := stall io.jump := Mux(stallDelay, false.B, decoder.controlSignals.jump) + io.branchType := Mux(stallDelay, branchType.DC, decoder.branchType) io.returnAddr := io.pc + 4.U io.writeEnableOut := Mux(stallDelay, false.B, decoder.controlSignals.regWrite) diff --git a/src/main/scala/IF.scala b/src/main/scala/IF.scala index 35b0842..b4fd54a 100644 --- a/src/main/scala/IF.scala +++ b/src/main/scala/IF.scala @@ -48,7 +48,7 @@ class InstructionFetch extends MultiIOModule { */ io.PC := PC IMEM.io.instructionAddress := PC - PC := Mux(io.stall, PC, Mux(io.branch, io.branchAddress, PC + 4.U)) + PC := Mux(io.branch, io.branchAddress, Mux(io.stall, PC, PC + 4.U)) val instruction = Wire(new Instruction) instruction := IMEM.io.instruction.asTypeOf(new Instruction) diff --git a/src/test/resources/tests/branch.s b/src/test/resources/tests/branch.s index 0761396..64d56ef 100644 --- a/src/test/resources/tests/branch.s +++ b/src/test/resources/tests/branch.s @@ -4,4 +4,5 @@ main: loop: addi x2, x2, 1 blt x2, x1, loop + nop done \ No newline at end of file diff --git a/src/test/scala/Manifest.scala b/src/test/scala/Manifest.scala index bc851f7..d10ed8a 100644 --- a/src/test/scala/Manifest.scala +++ b/src/test/scala/Manifest.scala @@ -19,7 +19,7 @@ import LogParser._ object Manifest { - val singleTest = "forward1.s" + val singleTest = "branch.s" val nopPadded = false