diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index de251b5..c13b6c7 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -63,9 +63,12 @@ class CPU extends MultiIOModule { IDBarrier.op1in := ID.io.op1 IDBarrier.op2in := ID.io.op2 + IDBarrier.r2ValueIn := ID.io.r2Value IDBarrier.ALUopIn := ID.io.ALUOp IDBarrier.writeEnableIn := ID.io.writeEnableOut IDBarrier.writeAddrIn := ID.io.writeAddrOut + IDBarrier.memWriteIn := ID.io.memWrite + IDBarrier.memReadIn := ID.io.memRead EX.io.op1 := IDBarrier.op1out EX.io.op2 := IDBarrier.op2out @@ -73,9 +76,18 @@ class CPU extends MultiIOModule { EXBarrier.writeEnableIn := IDBarrier.writeEnableOut EXBarrier.writeAddrIn := IDBarrier.writeAddrOut - EXBarrier.writeDataIn := EX.io.ALUResult.asUInt() + EXBarrier.memWriteIn := IDBarrier.memWriteOut + EXBarrier.memReadIn := IDBarrier.memReadOut + EXBarrier.r2ValueIn := IDBarrier.r2ValueOut + EXBarrier.ALUResultIn := EX.io.ALUResult.asUInt() - ID.io.writeData := EXBarrier.writeDataOut + MEM.io.ALUResult := EXBarrier.ALUResultOut + MEM.io.writeMem := EXBarrier.memWriteOut + MEM.io.readMem := EXBarrier.memReadOut + MEM.io.writeData := EXBarrier.r2ValueOut + + // Write back + ID.io.writeData := MEM.io.dataOut ID.io.writeEnableIn := EXBarrier.writeEnableOut ID.io.writeAddrIn := EXBarrier.writeAddrOut } diff --git a/src/main/scala/Decoder.scala b/src/main/scala/Decoder.scala index 6f4e376..9cc836b 100644 --- a/src/main/scala/Decoder.scala +++ b/src/main/scala/Decoder.scala @@ -46,27 +46,36 @@ class Decoder() extends Module { */ val opcodeMap: Array[(BitPat, List[UInt])] = Array( - // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp - ADD -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.ADD ), - SUB -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SUB ), - AND -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.AND ), - OR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.OR ), - XOR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.XOR ), - SLT -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLT ), - SLTU -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLTU), - SRA -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRA ), - SRL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRL ), - SLL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLL ), + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + ADD -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.ADD ), + SUB -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SUB ), + AND -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.AND ), + OR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.OR ), + XOR -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.XOR ), + SLT -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLT ), + SLTU -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLTU), + SRA -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRA ), + SRL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRL ), + SLL -> List(Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLL ), - ADDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), - ANDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.AND ), - ORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.OR ), - XORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.XOR ), - SLTI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLT ), - SLTIU -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLTU), - SRAI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRA ), - SRLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRL ), - SLLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLL ), + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + ADDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), + ANDI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.AND ), + ORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.OR ), + XORI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.XOR ), + SLTI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLT ), + SLTIU -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLTU), + SRAI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRA ), + SRLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SRL ), + SLLI -> List(Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.SLL ), + + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + LUI -> List(Y, N, N, N, N, branchType.DC, Op1Select.Zero, imm, ImmFormat.UTYPE, ALUOps.ADD ), + AUIPC -> List(Y, N, N, N, N, branchType.DC, Op1Select.PC, imm, ImmFormat.UTYPE, ALUOps.ADD ), + + // signal regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp + LW -> List(Y, Y, N, N, N, branchType.DC, rs1, imm, ImmFormat.ITYPE, ALUOps.ADD ), + SW -> List(N, N, Y, N, N, branchType.DC, rs1, imm, ImmFormat.STYPE, ALUOps.ADD ), ) diff --git a/src/main/scala/EXBarrier.scala b/src/main/scala/EXBarrier.scala index 6ccffd1..3c8fbff 100644 --- a/src/main/scala/EXBarrier.scala +++ b/src/main/scala/EXBarrier.scala @@ -6,17 +6,25 @@ import chisel3.experimental.MultiIOModule class EXBarrier extends MultiIOModule { val io = IO( new Bundle { - val writeDataIn = Input(UInt(32.W)) - val writeDataOut = Output(UInt(32.W)) + val ALUResultIn = Input(UInt(32.W)) + val ALUResultOut = Output(UInt(32.W)) + val r2ValueIn = Input(UInt(32.W)) + val r2ValueOut = Output(UInt(32.W)) val writeAddrIn = Input(UInt(5.W)) val writeAddrOut = Output(UInt(5.W)) val writeEnableIn = Input(Bool()) val writeEnableOut = Output(Bool()) + val memReadIn = Input(Bool()) + val memReadOut = Output(Bool()) + val memWriteIn = Input(Bool()) + val memWriteOut = Output(Bool()) }) - val writeData = RegInit(UInt(32.W), 0.U) - writeData := io.writeDataIn - io.writeDataOut := writeData + io.ALUResultOut := io.ALUResultIn + + val r2Value = RegInit(UInt(32.W), 0.U) + r2Value := io.r2ValueIn + io.r2ValueOut := r2Value val writeAddr = RegInit(UInt(5.W), 0.U) writeAddr := io.writeAddrIn @@ -25,5 +33,13 @@ class EXBarrier extends MultiIOModule { val writeEnable = RegInit(Bool(), false.B) writeEnable := io.writeEnableIn io.writeEnableOut := writeEnable + + val memRead = RegInit(Bool(), false.B) + memRead := io.memReadIn + io.memReadOut := memRead + + val memWrite = RegInit(Bool(), false.B) + memWrite := io.memWriteIn + io.memWriteOut := memWrite } diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index 768b89f..66ca58f 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -22,12 +22,15 @@ class InstructionDecode extends MultiIOModule { val pc = Input(UInt(32.W)) val op1 = Output(SInt(32.W)) val op2 = Output(SInt(32.W)) + val r2Value = Output(UInt(32.W)) val ALUOp = Output(UInt(4.W)) val writeAddrIn = Input(UInt(5.W)) val writeAddrOut = Output(UInt(5.W)) val writeEnableIn = Input(Bool()) val writeEnableOut = Output(Bool()) val writeData = Input(UInt(32.W)) + val memWrite = Output(Bool()) + val memRead = Output(Bool()) } ) @@ -47,8 +50,9 @@ class InstructionDecode extends MultiIOModule { decoder.instruction := io.instruction val select1Map = Array( - Op1Select.rs1 -> registers.io.readData1.asSInt(), - Op1Select.PC -> io.pc.asSInt(), + Op1Select.rs1 -> registers.io.readData1.asSInt(), + Op1Select.PC -> io.pc.asSInt(), + Op1Select.Zero -> 0.S ) io.op1 := MuxLookup(decoder.op1Select, 0.S(32.W), select1Map) @@ -65,8 +69,11 @@ class InstructionDecode extends MultiIOModule { Op2Select.rs2 -> registers.io.readData2.asSInt(), ) io.op2 := MuxLookup(decoder.op2Select, 0.S(32.W), select2Map) + io.r2Value := registers.io.readData2 io.ALUOp := decoder.ALUop io.writeAddrOut := decoder.instruction.registerRd io.writeEnableOut := decoder.controlSignals.regWrite + io.memRead := decoder.controlSignals.memRead + io.memWrite := decoder.controlSignals.memWrite } diff --git a/src/main/scala/IDBarrier.scala b/src/main/scala/IDBarrier.scala index 1d2d0f7..ff71f89 100644 --- a/src/main/scala/IDBarrier.scala +++ b/src/main/scala/IDBarrier.scala @@ -10,12 +10,18 @@ class IDBarrier extends MultiIOModule { val op1out = Output(SInt(32.W)) val op2in = Input(SInt(32.W)) val op2out = Output(SInt(32.W)) + val r2ValueIn = Input(UInt(32.W)) + val r2ValueOut = Output(UInt(32.W)) val ALUopIn = Input(UInt(4.W)) val ALUopOut = Output(UInt(4.W)) val writeAddrIn = Input(UInt(5.W)) val writeAddrOut = Output(UInt(5.W)) val writeEnableIn = Input(Bool()) val writeEnableOut = Output(Bool()) + val memReadIn = Input(Bool()) + val memReadOut = Output(Bool()) + val memWriteIn = Input(Bool()) + val memWriteOut = Output(Bool()) }) val op1 = RegInit(SInt(32.W), 0.S) @@ -26,6 +32,10 @@ class IDBarrier extends MultiIOModule { op2 := io.op2in io.op2out := op2 + val r2Value = RegInit(UInt(32.W), 0.U) + r2Value := io.r2ValueIn + io.r2ValueOut := r2Value + val ALUop = RegInit(UInt(4.W), 0.U) ALUop := io.ALUopIn io.ALUopOut := ALUop @@ -37,4 +47,12 @@ class IDBarrier extends MultiIOModule { val writeEnable = RegInit(Bool(), false.B) writeEnable := io.writeEnableIn io.writeEnableOut := writeEnable + + val memRead = RegInit(Bool(), false.B) + memRead := io.memReadIn + io.memReadOut := memRead + + val memWrite = RegInit(Bool(), false.B) + memWrite := io.memWriteIn + io.memWriteOut := memWrite } diff --git a/src/main/scala/MEM.scala b/src/main/scala/MEM.scala index 0038a6d..31f94d2 100644 --- a/src/main/scala/MEM.scala +++ b/src/main/scala/MEM.scala @@ -18,6 +18,11 @@ class MemoryFetch() extends MultiIOModule { val io = IO( new Bundle { + val ALUResult = Input(UInt(32.W)) // We get ALUResult one cycle early + val writeData = Input(UInt(32.W)) + val readMem = Input(Bool()) + val writeMem = Input(Bool()) + val dataOut = Output(UInt(32.W)) }) @@ -35,7 +40,15 @@ class MemoryFetch() extends MultiIOModule { /** * Your code here. */ - DMEM.io.dataIn := 0.U - DMEM.io.dataAddress := 0.U - DMEM.io.writeEnable := false.B + DMEM.io.dataIn := io.writeData + DMEM.io.dataAddress := io.ALUResult + DMEM.io.writeEnable := io.writeMem + + // ALUResult is one cycle early! + val ALUResult = RegInit(UInt(32.W), 0.U) + ALUResult := io.ALUResult + + when(io.readMem) { + io.dataOut := DMEM.io.dataOut + }.otherwise(io.dataOut := ALUResult) } diff --git a/src/main/scala/ToplevelSignals.scala b/src/main/scala/ToplevelSignals.scala index c8efa2e..69f6627 100644 --- a/src/main/scala/ToplevelSignals.scala +++ b/src/main/scala/ToplevelSignals.scala @@ -80,9 +80,10 @@ object branchType { using them altogether. */ object Op1Select { - val rs1 = 0.asUInt(1.W) - val PC = 1.asUInt(1.W) - val DC = 0.asUInt(1.W) + val rs1 = 0.asUInt(2.W) + val PC = 1.asUInt(2.W) + val Zero = 1.asUInt(2.W) + val DC = 3.asUInt(2.W) } object Op2Select { diff --git a/src/test/scala/Manifest.scala b/src/test/scala/Manifest.scala index 6246818..ca77290 100644 --- a/src/test/scala/Manifest.scala +++ b/src/test/scala/Manifest.scala @@ -19,7 +19,7 @@ import LogParser._ object Manifest { - val singleTest = "addi.s" + val singleTest = "load2.s" val nopPadded = true