Commit 100aa5a4 authored by Florent Kermarrec's avatar Florent Kermarrec

soc/cores/spi/SPIMaster: rewrite/simplify.

- Make sure MOSI is latched on start, MISO is stable during Xfer (last value).
- Allow clk_divider down to 2.
- improve test errors reporting with hex() on AssertEqual.
parent 63c19ff4
...@@ -41,94 +41,96 @@ class SPIMaster(Module, AutoCSR): ...@@ -41,94 +41,96 @@ class SPIMaster(Module, AutoCSR):
# # # # # #
done = Signal() clk_enable = Signal()
bits = Signal(8) cs_enable = Signal()
xfer = Signal() count = Signal(max=data_width)
shift = Signal() mosi_latch = Signal()
miso_latch = Signal()
# Clock generation ------------------------------------------------------------------------- # Clock generation -------------------------------------------------------------------------
clk_divider = Signal(16) clk_divider = Signal(16)
clk_rise = Signal() clk_rise = Signal()
clk_fall = Signal() clk_fall = Signal()
self.comb += clk_rise.eq(clk_divider == (self.clk_divider[1:] - 1))
self.comb += clk_fall.eq(clk_divider == (self.clk_divider - 1))
self.sync += [ self.sync += [
If(clk_rise, pads.clk.eq(xfer)), clk_divider.eq(clk_divider + 1),
If(clk_fall, pads.clk.eq(0)), If(clk_rise,
If(clk_fall, pads.clk.eq(clk_enable),
clk_divider.eq(0) ).Elif(clk_fall,
).Else( clk_divider.eq(0),
clk_divider.eq(clk_divider + 1) pads.clk.eq(0),
) )
] ]
self.comb += clk_rise.eq(clk_divider == (self.clk_divider[1:] - 1))
self.comb += clk_fall.eq(clk_divider == (self.clk_divider - 1))
# Control FSM ------------------------------------------------------------------------------ # Control FSM ------------------------------------------------------------------------------
self.submodules.fsm = fsm = FSM(reset_state="IDLE") self.submodules.fsm = fsm = FSM(reset_state="IDLE")
fsm.act("IDLE", fsm.act("IDLE",
done.eq(1), self.done.eq(1),
If(self.start, If(self.start,
NextValue(bits, 0), self.done.eq(0),
NextState("WAIT-CLK-FALL") mosi_latch.eq(1),
NextState("START")
) )
) )
fsm.act("WAIT-CLK-FALL", fsm.act("START",
NextValue(count, 0),
If(clk_fall, If(clk_fall,
NextState("XFER") cs_enable.eq(1),
NextState("RUN")
) )
) )
fsm.act("XFER", fsm.act("RUN",
If(bits == self.length, clk_enable.eq(1),
NextState("END") cs_enable.eq(1),
).Elif(clk_fall, If(clk_fall,
NextValue(bits, bits + 1) NextValue(count, count + 1),
), If(count == (self.length - 1),
xfer.eq(1), NextState("STOP")
shift.eq(1) )
)
) )
fsm.act("END", fsm.act("STOP",
cs_enable.eq(1),
If(clk_rise, If(clk_rise,
miso_latch.eq(1),
self.irq.eq(1),
NextState("IDLE") NextState("IDLE")
), )
shift.eq(1),
self.irq.eq(1)
) )
self.sync += self.done.eq(done & ~self.start)
# Chip Select generation ------------------------------------------------------------------- # Chip Select generation -------------------------------------------------------------------
if hasattr(pads, "cs_n"): if hasattr(pads, "cs_n"):
for i in range(len(pads.cs_n)): for i in range(len(pads.cs_n)):
self.comb += pads.cs_n[i].eq(~self.cs[i] | ~xfer) self.sync += pads.cs_n[i].eq(~self.cs[i] | ~cs_enable)
# Master Out Slave In (MOSI) generation (generated on spi_clk falling edge) ---------------- # Master Out Slave In (MOSI) generation (generated on spi_clk falling edge) ----------------
mosi_data = Array(self.mosi[i] for i in range(data_width)) mosi_data = Signal(data_width)
mosi_bit = Signal(max=data_width) mosi_array = Array(mosi_data[i] for i in range(data_width))
mosi_sel = Signal(max=data_width)
self.sync += [ self.sync += [
If(self.start, If(mosi_latch,
mosi_bit.eq(self.length - 1 if mode == "aligned" else data_width - 1), mosi_data.eq(self.mosi),
).Elif(clk_rise & shift, mosi_sel.eq((self.length-1) if mode == "aligned" else (data_width-1)),
mosi_bit.eq(mosi_bit - 1) ).Elif(clk_fall,
If(cs_enable, pads.mosi.eq(mosi_array[mosi_sel])),
mosi_sel.eq(mosi_sel - 1)
), ),
If(clk_fall,
pads.mosi.eq(mosi_data[mosi_bit])
)
] ]
# Master In Slave Out (MISO) capture (captured on spi_clk rising edge) -------------------- # Master In Slave Out (MISO) capture (captured on spi_clk rising edge) --------------------
miso = Signal() miso = Signal()
miso_data = Signal(data_width) miso_data = Signal(data_width)
self.sync += [ self.sync += [
If(clk_rise & shift, If(clk_rise,
If(self.loopback, If(self.loopback,
miso.eq(pads.mosi) miso_data.eq(Cat(pads.mosi, miso_data))
).Else( ).Else(
miso.eq(pads.miso) miso_data.eq(Cat(pads.miso, miso_data))
) )
), )
If(clk_fall & shift,
miso_data.eq(Cat(miso, miso_data))
),
If(done, self.miso.eq(miso_data)),
] ]
self.sync += If(miso_latch, self.miso.eq(miso_data))
def add_csr(self, with_cs=True, with_loopback=True): def add_csr(self, with_cs=True, with_loopback=True):
self._control = CSRStorage(fields=[ self._control = CSRStorage(fields=[
......
...@@ -16,6 +16,7 @@ class TestSPI(unittest.TestCase): ...@@ -16,6 +16,7 @@ class TestSPI(unittest.TestCase):
def test_spi_master_xfer_loopback_32b_32b(self): def test_spi_master_xfer_loopback_32b_32b(self):
def generator(dut): def generator(dut):
yield dut.loopback.eq(1) yield dut.loopback.eq(1)
yield dut.clk_divider.eq(2)
yield dut.mosi.eq(0xdeadbeef) yield dut.mosi.eq(0xdeadbeef)
yield dut.length.eq(32) yield dut.length.eq(32)
yield dut.start.eq(1) yield dut.start.eq(1)
...@@ -24,7 +25,8 @@ class TestSPI(unittest.TestCase): ...@@ -24,7 +25,8 @@ class TestSPI(unittest.TestCase):
yield yield
while (yield dut.done) == 0: while (yield dut.done) == 0:
yield yield
self.assertEqual((yield dut.miso), 0xdeadbeef) yield
self.assertEqual(hex((yield dut.miso)), hex(0xdeadbeef))
dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False) dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False)
run_simulation(dut, generator(dut)) run_simulation(dut, generator(dut))
...@@ -40,7 +42,8 @@ class TestSPI(unittest.TestCase): ...@@ -40,7 +42,8 @@ class TestSPI(unittest.TestCase):
yield yield
while (yield dut.done) == 0: while (yield dut.done) == 0:
yield yield
self.assertEqual((yield dut.miso), 0xbeef) yield
self.assertEqual(hex((yield dut.miso)), hex(0xbeef))
dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False, mode="aligned") dut = SPIMaster(pads=None, data_width=32, sys_clk_freq=100e6, spi_clk_freq=5e6, with_csr=False, mode="aligned")
run_simulation(dut, generator(dut)) run_simulation(dut, generator(dut))
...@@ -59,6 +62,8 @@ class TestSPI(unittest.TestCase): ...@@ -59,6 +62,8 @@ class TestSPI(unittest.TestCase):
self.submodules.slave = SPISlave(pads, data_width=32) self.submodules.slave = SPISlave(pads, data_width=32)
def master_generator(dut): def master_generator(dut):
for i in range(8):
yield
yield dut.master.mosi.eq(0xdeadbeef) yield dut.master.mosi.eq(0xdeadbeef)
yield dut.master.length.eq(32) yield dut.master.length.eq(32)
yield dut.master.start.eq(1) yield dut.master.start.eq(1)
...@@ -67,15 +72,19 @@ class TestSPI(unittest.TestCase): ...@@ -67,15 +72,19 @@ class TestSPI(unittest.TestCase):
yield yield
while (yield dut.master.done) == 0: while (yield dut.master.done) == 0:
yield yield
self.assertEqual((yield dut.master.miso), 0x12345678) yield
self.assertEqual(hex((yield dut.master.miso)), hex(0x12345678))
def slave_generator(dut): def slave_generator(dut):
for i in range(8):
yield
yield dut.slave.miso.eq(0x12345678) yield dut.slave.miso.eq(0x12345678)
while (yield dut.slave.start) == 0: while (yield dut.slave.start) == 0:
yield yield
while (yield dut.slave.done) == 0: while (yield dut.slave.done) == 0:
yield yield
self.assertEqual((yield dut.slave.mosi), 0xdeadbeef) yield
self.assertEqual(hex((yield dut.slave.mosi)), hex(0xdeadbeef))
self.assertEqual((yield dut.slave.length), 32) self.assertEqual((yield dut.slave.length), 32)
dut = DUT() dut = DUT()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment