From 6d6474530ca0b92c69720fba840bb33c8a38f01a Mon Sep 17 00:00:00 2001 From: Sebastian Bugge Date: Thu, 7 Nov 2024 23:51:17 +0100 Subject: [PATCH] Simplify IDBarrier. --- src/main/scala/CPU.scala | 31 ++++++++++++---- src/main/scala/IDBarrier.scala | 65 +++++++--------------------------- 2 files changed, 38 insertions(+), 58 deletions(-) diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index a34dfc7..52738b2 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -54,6 +54,25 @@ class CPU extends MultiIOModule { /** TODO: Your code here */ + def forward(data: UInt, addr: UInt, useForward: Bool, mem: Forwarding, wb: Forwarding, id: Forwarding): UInt = { + Mux( + !useForward, + data, + Mux( + mem.valid && mem.writeAddr === addr, + mem.writeData, + Mux( + wb.valid && wb.writeAddr === addr, + wb.writeData, + Mux( + id.valid && id.writeAddr === addr, + id.writeData, + data, + ) + ) + ) + ) + } IFBarrier.PCin := IF.io.PC IFBarrier.instructionIn := IF.io.instruction @@ -63,8 +82,8 @@ class CPU extends MultiIOModule { IDBarrier.op1in := ID.io.op1 IDBarrier.op2in := ID.io.op2 - IDBarrier.isOp1RValue := ID.io.isOp1RValue - IDBarrier.isOp2RValue := ID.io.isOp2RValue + IDBarrier.isOp1RValueIn := ID.io.isOp1RValue + IDBarrier.isOp2RValueIn := ID.io.isOp2RValue IDBarrier.r1ValueIn := ID.io.r1Value IDBarrier.r2ValueIn := ID.io.r2Value IDBarrier.r1AddressIn := ID.io.r1Address @@ -78,12 +97,12 @@ class CPU extends MultiIOModule { IDBarrier.memWriteIn := ID.io.memWrite IDBarrier.memReadIn := ID.io.memRead - EX.io.op1 := IDBarrier.op1out - EX.io.op2 := IDBarrier.op2out + EX.io.op1 := forward(IDBarrier.op1out.asUInt(), IDBarrier.r1AddressOut, IDBarrier.isOp1RValueOut, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() + EX.io.op2 := forward(IDBarrier.op2out.asUInt(), IDBarrier.r2AddressOut, IDBarrier.isOp2RValueOut, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() EX.io.ALUOp := IDBarrier.ALUopOut EX.io.branchType := IDBarrier.branchTypeOut - EX.io.rs1ValueIn := IDBarrier.r1ValueOut.asSInt() - EX.io.rs2ValueIn := IDBarrier.r2ValueOut.asSInt() + EX.io.rs1ValueIn := forward(IDBarrier.r1ValueOut, IDBarrier.r1AddressOut, true.B, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() + EX.io.rs2ValueIn := forward(IDBarrier.r2ValueOut, IDBarrier.r2AddressOut, true.B, mem = MEMBarrier.forwardMem, wb = MEMBarrier.forwardWb, id = MEMBarrier.forwardId).asSInt() EXBarrier.r2ValueIn := EX.io.rs2ValueOut.asUInt() EXBarrier.ALUResultIn := EX.io.ALUResult.asUInt() diff --git a/src/main/scala/IDBarrier.scala b/src/main/scala/IDBarrier.scala index fe0ee47..26869d4 100644 --- a/src/main/scala/IDBarrier.scala +++ b/src/main/scala/IDBarrier.scala @@ -8,10 +8,12 @@ class IDBarrier extends MultiIOModule { new Bundle { val op1in = Input(SInt(32.W)) val op1out = Output(SInt(32.W)) - val isOp1RValue = Input(Bool()) + val isOp1RValueIn = Input(Bool()) + val isOp1RValueOut = Output(Bool()) val op2in = Input(SInt(32.W)) val op2out = Output(SInt(32.W)) - val isOp2RValue = Input(Bool()) + val isOp2RValueIn = Input(Bool()) + val isOp2RValueOut = Output(Bool()) val r1ValueIn = Input(UInt(32.W)) val r1ValueOut = Output(UInt(32.W)) val r1AddressIn = Input(UInt(5.W)) @@ -36,16 +38,15 @@ class IDBarrier extends MultiIOModule { val memReadOut = Output(Bool()) val memWriteIn = Input(Bool()) val memWriteOut = Output(Bool()) - - val forwardMem = Input(new Forwarding) - val forwardWb = Input(new Forwarding) - val forwardId = Input(new Forwarding) }) val isOp1RValue = RegInit(Bool(), false.B) - isOp1RValue := io.isOp1RValue + isOp1RValue := io.isOp1RValueIn + io.isOp1RValueOut := isOp1RValue + val isOp2RValue = RegInit(Bool(), false.B) - isOp2RValue := io.isOp2RValue + isOp2RValue := io.isOp2RValueIn + io.isOp2RValueOut := isOp2RValue val r2Address = RegInit(UInt(5.W), 0.U) r2Address := io.r2AddressIn @@ -57,59 +58,19 @@ class IDBarrier extends MultiIOModule { val op1 = RegInit(SInt(32.W), 0.S) op1 := io.op1in - io.op1out := Mux( - isOp1RValue && io.forwardMem.valid && r1Address === io.forwardMem.writeAddr, - io.forwardMem.writeData.asSInt(), - Mux( - isOp1RValue && io.forwardWb.valid && r1Address === io.forwardWb.writeAddr, - io.forwardWb.writeData.asSInt(), - Mux( - isOp1RValue && io.forwardId.valid && r1Address === io.forwardId.writeAddr, - io.forwardId.writeData.asSInt(), - op1.asSInt(), - ))) + io.op1out := op1 val op2 = RegInit(SInt(32.W), 0.S) op2 := io.op2in - io.op2out := Mux( - isOp2RValue && io.forwardMem.valid && r2Address === io.forwardMem.writeAddr, - io.forwardMem.writeData.asSInt(), - Mux( - isOp2RValue && io.forwardWb.valid && r2Address === io.forwardWb.writeAddr, - io.forwardWb.writeData.asSInt(), - Mux( - isOp2RValue && io.forwardId.valid && r2Address === io.forwardId.writeAddr, - io.forwardId.writeData.asSInt(), - op2.asSInt(), - ))) + io.op2out := op2 val r1Value = RegInit(UInt(32.W), 0.U) r1Value := io.r1ValueIn - io.r1ValueOut := Mux( - io.forwardMem.valid && r1Address === io.forwardMem.writeAddr, - io.forwardMem.writeData, - Mux( - io.forwardWb.valid && r1Address === io.forwardWb.writeAddr, - io.forwardWb.writeData, - Mux( - io.forwardId.valid && r1Address === io.forwardId.writeAddr, - io.forwardId.writeData, - r1Value, - ))) + io.r1ValueOut := r1Value val r2Value = RegInit(UInt(32.W), 0.U) r2Value := io.r2ValueIn - io.r2ValueOut := Mux( - io.forwardMem.valid && r2Address === io.forwardMem.writeAddr, - io.forwardMem.writeData, - Mux( - io.forwardWb.valid && r2Address === io.forwardWb.writeAddr, - io.forwardWb.writeData, - Mux( - io.forwardId.valid && r2Address === io.forwardId.writeAddr, - io.forwardId.writeData, - r2Value, - ))) + io.r2ValueOut := r2Value val returnAddr = RegInit(UInt(32.W), 0.U) returnAddr := io.returnAddrIn