From 92d0dfd9eb8c875a79165a4e1ad81c9b301db00c Mon Sep 17 00:00:00 2001 From: Sebastian Bugge Date: Fri, 4 Oct 2024 02:16:17 +0200 Subject: [PATCH] Working branching. --- src/main/scala/CPU.scala | 14 ++++++- src/main/scala/Decoder.scala | 64 +++++++++++++++++-------------- src/main/scala/EX.scala | 17 ++++++++ src/main/scala/EXBarrier.scala | 10 +++++ src/main/scala/ID.scala | 4 ++ src/main/scala/IDBarrier.scala | 12 ++++++ src/main/scala/IF.scala | 4 +- src/test/resources/tests/branch.s | 7 ++++ src/test/scala/Manifest.scala | 2 +- 9 files changed, 102 insertions(+), 32 deletions(-) create mode 100644 src/test/resources/tests/branch.s diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index c13b6c7..8fd4b3a 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -63,8 +63,10 @@ class CPU extends MultiIOModule { IDBarrier.op1in := ID.io.op1 IDBarrier.op2in := ID.io.op2 + IDBarrier.r1ValueIn := ID.io.r1Value IDBarrier.r2ValueIn := ID.io.r2Value IDBarrier.ALUopIn := ID.io.ALUOp + IDBarrier.branchTypeIn := ID.io.branchType IDBarrier.writeEnableIn := ID.io.writeEnableOut IDBarrier.writeAddrIn := ID.io.writeAddrOut IDBarrier.memWriteIn := ID.io.memWrite @@ -73,13 +75,17 @@ class CPU extends MultiIOModule { EX.io.op1 := IDBarrier.op1out EX.io.op2 := IDBarrier.op2out EX.io.ALUOp := IDBarrier.ALUopOut + EX.io.branchType := IDBarrier.branchTypeOut + EX.io.rs1ValueIn := IDBarrier.r1ValueOut.asSInt() + EX.io.rs2ValueIn := IDBarrier.r2ValueOut.asSInt() + EXBarrier.r2ValueIn := EX.io.rs2ValueOut.asUInt() + EXBarrier.ALUResultIn := EX.io.ALUResult.asUInt() + EXBarrier.branchIn := EX.io.branch EXBarrier.writeEnableIn := IDBarrier.writeEnableOut EXBarrier.writeAddrIn := IDBarrier.writeAddrOut EXBarrier.memWriteIn := IDBarrier.memWriteOut EXBarrier.memReadIn := IDBarrier.memReadOut - EXBarrier.r2ValueIn := IDBarrier.r2ValueOut - EXBarrier.ALUResultIn := EX.io.ALUResult.asUInt() MEM.io.ALUResult := EXBarrier.ALUResultOut MEM.io.writeMem := EXBarrier.memWriteOut @@ -90,4 +96,8 @@ class CPU extends MultiIOModule { ID.io.writeData := MEM.io.dataOut ID.io.writeEnableIn := EXBarrier.writeEnableOut ID.io.writeAddrIn := EXBarrier.writeAddrOut + + // Branching + IF.io.branch := EXBarrier.branchOut + IF.io.branchAddress := EXBarrier.branchAddress } diff --git a/src/main/scala/Decoder.scala b/src/main/scala/Decoder.scala index 1f8bfab..5cd321d 100644 --- a/src/main/scala/Decoder.scala +++ b/src/main/scala/Decoder.scala @@ -46,36 +46,44 @@ class Decoder() extends Module { */ val opcodeMap: Array[(BitPat, List[UInt])] = Array( - // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp - ADD -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.ADD ), - SUB -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SUB ), - AND -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.AND ), - OR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.OR ), - XOR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.XOR ), - SLT -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLT ), - SLTU -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLTU ), - SRA -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRA ), - SRL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRL ), - SLL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLL ), + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + ADD -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.ADD ), + SUB -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SUB ), + AND -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.AND ), + OR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.OR ), + XOR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.XOR ), + SLT -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLT ), + SLTU -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLTU ), + SRA -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRA ), + SRL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRL ), + SLL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLL ), + + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + ADDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), + ANDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.AND ), + ORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.OR ), + XORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.XOR ), + SLTI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLT ), + SLTIU -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLTU ), + SRAI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRA ), + SRLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRL ), + SLLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLL ), + + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + LUI -> List(Y, N, N, N, N, branchType.DC, Op1Select.DC, imm, ImmFormat.UTYPE, ALUOps.COPY_B), + AUIPC -> List(Y, N, N, N, N, branchType.DC, Op1Select.PC, imm, ImmFormat.UTYPE, ALUOps.ADD ), + + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + LW -> List(Y, Y, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), + SW -> List(N, N, Y, N, N, branchType.DC, rs1, imm, ImmFormat.STYPE, ALUOps.ADD ), // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp - ADDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), - ANDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.AND ), - ORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.OR ), - XORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.XOR ), - SLTI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLT ), - SLTIU -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLTU ), - SRAI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRA ), - SRLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRL ), - SLLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLL ), - - // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp - LUI -> List(Y, N, N, N, N, branchType.DC, Op1Select.DC, imm, ImmFormat.UTYPE, ALUOps.COPY_B), - AUIPC -> List(Y, N, N, N, N, branchType.DC, Op1Select.PC, imm, ImmFormat.UTYPE, ALUOps.ADD ), - - // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp - LW -> List(Y, Y, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), - SW -> List(N, N, Y, N, N, branchType.DC, rs1, imm, ImmFormat.STYPE, ALUOps.ADD ), + BEQ -> List(N, N, N, Y, N, branchType.beq, PC, imm, ImmFormat.BTYPE, ALUOps.ADD ), + BNE -> List(N, N, N, Y, N, branchType.neq, PC, imm, ImmFormat.BTYPE, ALUOps.ADD ), + BLT -> List(N, N, N, Y, N, branchType.lt, PC, imm, ImmFormat.BTYPE, ALUOps.ADD ), + BGE -> List(N, N, N, Y, N, branchType.gte, PC, imm, ImmFormat.BTYPE, ALUOps.ADD ), + BLTU -> List(N, N, N, Y, N, branchType.ltu, PC, imm, ImmFormat.BTYPE, ALUOps.ADD ), + BGEU -> List(N, N, N, Y, N, branchType.gteu,PC, imm, ImmFormat.BTYPE, ALUOps.ADD ), ) diff --git a/src/main/scala/EX.scala b/src/main/scala/EX.scala index 42fb903..539ae2d 100644 --- a/src/main/scala/EX.scala +++ b/src/main/scala/EX.scala @@ -10,6 +10,11 @@ class Execute extends MultiIOModule { new Bundle { val op1 = Input(SInt(32.W)) val op2 = Input(SInt(32.W)) + val rs1ValueIn = Input(SInt(32.W)) + val rs2ValueIn = Input(SInt(32.W)) + val rs2ValueOut = Output(SInt(32.W)) + val branchType = Input(UInt(3.W)) + val branch = Output(Bool()) val ALUOp = Input(UInt(4.W)) val ALUResult = Output(SInt(32.W)) } @@ -30,5 +35,17 @@ class Execute extends MultiIOModule { ALUOps.COPY_B -> io.op2, ) + val BranchALUOpsMap = Array ( + branchType.beq -> (io.rs1ValueIn === io.rs2ValueIn), + branchType.neq -> !(io.rs1ValueIn === io.rs2ValueIn), + branchType.lt -> (io.rs1ValueIn < io.rs2ValueIn), + branchType.gte -> (io.rs1ValueIn >= io.rs2ValueIn), + branchType.ltu -> (io.rs1ValueIn.asUInt() < io.rs2ValueIn.asUInt()), + branchType.gteu -> (io.rs1ValueIn.asUInt() >= io.rs2ValueIn.asUInt()), + branchType.jump -> true.B, + ) + + io.rs2ValueOut := io.rs2ValueIn + io.branch := MuxLookup(io.branchType, false.B, BranchALUOpsMap) io.ALUResult := MuxLookup(io.ALUOp, 0.S(32.W), ALUOpsMap) } \ No newline at end of file diff --git a/src/main/scala/EXBarrier.scala b/src/main/scala/EXBarrier.scala index 3c8fbff..d61babe 100644 --- a/src/main/scala/EXBarrier.scala +++ b/src/main/scala/EXBarrier.scala @@ -8,6 +8,7 @@ class EXBarrier extends MultiIOModule { new Bundle { val ALUResultIn = Input(UInt(32.W)) val ALUResultOut = Output(UInt(32.W)) + val branchAddress = Output(UInt(32.W)) val r2ValueIn = Input(UInt(32.W)) val r2ValueOut = Output(UInt(32.W)) val writeAddrIn = Input(UInt(5.W)) @@ -18,9 +19,14 @@ class EXBarrier extends MultiIOModule { val memReadOut = Output(Bool()) val memWriteIn = Input(Bool()) val memWriteOut = Output(Bool()) + val branchIn = Input(Bool()) + val branchOut = Output(Bool()) }) io.ALUResultOut := io.ALUResultIn + val branchAddress = RegInit(UInt(32.W), 0.U) + branchAddress := io.ALUResultIn + io.branchAddress := branchAddress val r2Value = RegInit(UInt(32.W), 0.U) r2Value := io.r2ValueIn @@ -41,5 +47,9 @@ class EXBarrier extends MultiIOModule { val memWrite = RegInit(Bool(), false.B) memWrite := io.memWriteIn io.memWriteOut := memWrite + + val branch = RegInit(Bool(), false.B) + branch := io.branchIn + io.branchOut := branch } diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index a39367f..cda89dc 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -22,6 +22,7 @@ class InstructionDecode extends MultiIOModule { val pc = Input(UInt(32.W)) val op1 = Output(SInt(32.W)) val op2 = Output(SInt(32.W)) + val r1Value = Output(UInt(32.W)) val r2Value = Output(UInt(32.W)) val ALUOp = Output(UInt(4.W)) val writeAddrIn = Input(UInt(5.W)) @@ -31,6 +32,7 @@ class InstructionDecode extends MultiIOModule { val writeData = Input(UInt(32.W)) val memWrite = Output(Bool()) val memRead = Output(Bool()) + val branchType = Output(UInt(3.W)) } ) @@ -68,9 +70,11 @@ class InstructionDecode extends MultiIOModule { Op2Select.rs2 -> registers.io.readData2.asSInt(), ) io.op2 := MuxLookup(decoder.op2Select, 0.S(32.W), select2Map) + io.r1Value := registers.io.readData1 io.r2Value := registers.io.readData2 io.ALUOp := decoder.ALUop + io.branchType := decoder.branchType io.writeAddrOut := decoder.instruction.registerRd io.writeEnableOut := decoder.controlSignals.regWrite io.memRead := decoder.controlSignals.memRead diff --git a/src/main/scala/IDBarrier.scala b/src/main/scala/IDBarrier.scala index ff71f89..749ae3c 100644 --- a/src/main/scala/IDBarrier.scala +++ b/src/main/scala/IDBarrier.scala @@ -10,10 +10,14 @@ class IDBarrier extends MultiIOModule { val op1out = Output(SInt(32.W)) val op2in = Input(SInt(32.W)) val op2out = Output(SInt(32.W)) + val r1ValueIn = Input(UInt(32.W)) + val r1ValueOut = Output(UInt(32.W)) val r2ValueIn = Input(UInt(32.W)) val r2ValueOut = Output(UInt(32.W)) val ALUopIn = Input(UInt(4.W)) val ALUopOut = Output(UInt(4.W)) + val branchTypeIn = Input(UInt(3.W)) + val branchTypeOut = Output(UInt(3.W)) val writeAddrIn = Input(UInt(5.W)) val writeAddrOut = Output(UInt(5.W)) val writeEnableIn = Input(Bool()) @@ -32,6 +36,10 @@ class IDBarrier extends MultiIOModule { op2 := io.op2in io.op2out := op2 + val r1Value = RegInit(UInt(32.W), 0.U) + r1Value := io.r1ValueIn + io.r1ValueOut := r1Value + val r2Value = RegInit(UInt(32.W), 0.U) r2Value := io.r2ValueIn io.r2ValueOut := r2Value @@ -40,6 +48,10 @@ class IDBarrier extends MultiIOModule { ALUop := io.ALUopIn io.ALUopOut := ALUop + val branchType = RegInit(UInt(5.W), 0.U) + branchType := io.branchTypeIn + io.branchTypeOut := branchType + val writeAddr = RegInit(UInt(5.W), 0.U) writeAddr := io.writeAddrIn io.writeAddrOut := writeAddr diff --git a/src/main/scala/IF.scala b/src/main/scala/IF.scala index 1659efd..f35311e 100644 --- a/src/main/scala/IF.scala +++ b/src/main/scala/IF.scala @@ -25,6 +25,8 @@ class InstructionFetch extends MultiIOModule { new Bundle { val PC = Output(UInt(32.W)) val instruction = Output(new Instruction) + val branch = Input(Bool()) + val branchAddress = Input(UInt(32.W)) }) val IMEM = Module(new IMEM) @@ -45,7 +47,7 @@ class InstructionFetch extends MultiIOModule { */ io.PC := PC IMEM.io.instructionAddress := PC - PC := PC + 4.U + PC := Mux(io.branch, io.branchAddress, PC + 4.U) val instruction = Wire(new Instruction) instruction := IMEM.io.instruction.asTypeOf(new Instruction) diff --git a/src/test/resources/tests/branch.s b/src/test/resources/tests/branch.s new file mode 100644 index 0000000..0761396 --- /dev/null +++ b/src/test/resources/tests/branch.s @@ -0,0 +1,7 @@ +main: + addi x1, x1, 12 + lui x2, 0 +loop: + addi x2, x2, 1 + blt x2, x1, loop + done \ No newline at end of file diff --git a/src/test/scala/Manifest.scala b/src/test/scala/Manifest.scala index 426b464..d38187b 100644 --- a/src/test/scala/Manifest.scala +++ b/src/test/scala/Manifest.scala @@ -19,7 +19,7 @@ import LogParser._ object Manifest { - val singleTest = "constants.s" + val singleTest = "branch.s" val nopPadded = true