aboutsummaryrefslogblamecommitdiff
path: root/src/models/lt_bus.h
blob: 9ae66b0eb3046909ae5c6ab9ae372d503e64495e (plain) (tree)









































































































































































































































































































                                                                                
#ifndef SYSC_PLAYGROUND_LT_BUS
#define SYSC_PLAYGROUND_LT_BUS

#include <tlm_core/tlm_2/tlm_generic_payload/tlm_generic_payload.h>
#include <tlm_core/tlm_2/tlm_sockets/tlm_sockets.h>

#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

/// Scope guard that pushes a module onto the current SystemC simcontext's
/// object-hierarchy stack on construction and pops it on destruction.
///
/// Used so that sc_objects (here: TLM sockets) created while the guard is
/// alive are created in the context of `mod`. The guard is non-copyable: a
/// copy would pop the hierarchy a second time.
struct scoped_push_hierarchy {
  [[nodiscard]] explicit scoped_push_hierarchy(sc_core::sc_module& mod)
      : m_mod(mod), m_simctx(sc_core::sc_get_curr_simcontext()) {
    assert(m_simctx);
    m_simctx->hierarchy_push(&m_mod);
  }

  scoped_push_hierarchy(const scoped_push_hierarchy&) = delete;
  scoped_push_hierarchy& operator=(const scoped_push_hierarchy&) = delete;

  ~scoped_push_hierarchy() {
    [[maybe_unused]] const auto* top = m_simctx->hierarchy_pop();
    // The popped entry must be the module pushed by the constructor;
    // anything else means pushes/pops were not properly nested.
    assert(top == &m_mod);
  }

 private:
  // Declared in the same order as the constructor's init list (-Wreorder).
  sc_core::sc_module& m_mod;
  sc_core::sc_simcontext* m_simctx{nullptr};
};

/// Closed 64-bit address interval [start, end]; both bounds are inclusive.
///
/// A single byte at address A is therefore range{A, A}.
struct range {
  constexpr explicit range(std::uint64_t start, std::uint64_t end)
      : start{start}, end{end} {
    // Inclusive bounds: start == end is a valid one-byte range.
    assert(start <= end);
  }

  /// True if this range and `rhs` share at least one address.
  constexpr bool overlaps(range rhs) const {
    return start <= rhs.end && rhs.start <= end;
  }

  /// True if every address of `rhs` lies inside this range.
  constexpr bool contains(range rhs) const {
    return start <= rhs.start && rhs.end <= end;
  }

  std::uint64_t start;  // First address (inclusive).
  std::uint64_t end;    // Last address (inclusive).
};

/// TLM target socket that forwards each forward-path call to a member function
/// of `Module`, passing a fixed integer tag (`id`) as the first argument.
///
/// This lets one module own several target sockets and still tell on which
/// socket a call arrived. The socket binds itself as its own forward
/// interface. Only b_transport, get_direct_mem_ptr and transport_dbg are
/// supported; nb_transport_fw prints an error and aborts.
template <typename Module>
class tlm_target_socket_tagged : public tlm::tlm_target_socket<>,
                                 public tlm::tlm_fw_transport_if<> {
  using cb_b_transport = void (Module::*)(std::size_t,
                                          tlm::tlm_generic_payload&,
                                          sc_core::sc_time&);

  using cb_get_direct_mem_ptr = bool (Module::*)(std::size_t,
                                                 tlm::tlm_generic_payload&,
                                                 tlm::tlm_dmi&);
  using cb_transport_dbg = unsigned int (Module::*)(std::size_t,
                                                    tlm::tlm_generic_payload&);

 public:
  /// @param name socket name, forwarded to tlm_target_socket.
  /// @param id   tag passed as the first argument of every callback.
  /// @param mod  object the callbacks are invoked on; it is dereferenced on
  ///             every transport call, so it must be valid whenever one arrives.
  /// @param b    callback for b_transport.
  /// @param m    callback for get_direct_mem_ptr.
  /// @param d    callback for transport_dbg.
  explicit tlm_target_socket_tagged(const char* name,
                                    std::size_t id,
                                    Module* mod,
                                    cb_b_transport b,
                                    cb_get_direct_mem_ptr m,
                                    cb_transport_dbg d)
      : tlm_target_socket<>{name},
        m_id{id},
        m_mod{mod},
        m_b_transport{b},
        m_get_direct_mem_ptr{m},
        m_transport_dbg{d} {
    // The socket itself implements the forward interface it exports.
    bind(*static_cast<tlm::tlm_fw_transport_if<>*>(this));
  }

 private:
  // -- TLM_FW_TRANSPORT_IF ----------------------------------------------------

  virtual void b_transport(tlm::tlm_generic_payload& tx,
                           sc_core::sc_time& t) override {
    (m_mod->*m_b_transport)(m_id, tx, t);
  }

  virtual bool get_direct_mem_ptr(tlm::tlm_generic_payload& tx,
                                  tlm::tlm_dmi& dmi_data) override {
    return (m_mod->*m_get_direct_mem_ptr)(m_id, tx, dmi_data);
  }

  virtual unsigned int transport_dbg(tlm::tlm_generic_payload& tx) override {
    return (m_mod->*m_transport_dbg)(m_id, tx);
  }

  // Non-blocking (AT) transport is not supported by this socket.
  virtual tlm::tlm_sync_enum nb_transport_fw(tlm::tlm_generic_payload& /*tx*/,
                                             tlm::tlm_phase& /*phase*/,
                                             sc_core::sc_time& /*t*/) override {
    std::fprintf(stderr,
                 "tlm_target_socket_tagged: nb_transport_fw not supported\n");
    std::abort();
  }

  // -- MEMBER -----------------------------------------------------------------

  // Declared in the same order as the constructor's init list (-Wreorder).
  std::size_t m_id{0};
  Module* m_mod{nullptr};
  cb_b_transport m_b_transport{nullptr};
  cb_get_direct_mem_ptr m_get_direct_mem_ptr{nullptr};
  cb_transport_dbg m_transport_dbg{nullptr};
};

/// Loosely-timed address-decoding bus.
///
/// Bus initiators are attached with attach_initiator(); bus targets are
/// attached with attach_target() together with an inclusive address window
/// [start, end]. Forward calls (b_transport, get_direct_mem_ptr,
/// transport_dbg) are decoded against the registered windows and forwarded to
/// the matching target with the address rebased to the window
/// (target address = bus address - window start). Non-blocking transport is
/// not supported and aborts.
class lt_bus : public sc_core::sc_module, public tlm::tlm_bw_transport_if<> {
  using target_socket = tlm_target_socket_tagged<lt_bus>;
  using target_socket_ptr = std::unique_ptr<target_socket>;

  using initiator_socket = tlm::tlm_initiator_socket<>;
  using initiator_socket_ptr = std::unique_ptr<initiator_socket>;

 public:
  explicit lt_bus(sc_core::sc_module_name nm)
      : sc_core::sc_module(std::move(nm)) {}

  // -- ATTACH BUS INITIATOR ---------------------------------------------------

  /// Creates a bus-side target socket named "init<N>" and binds `init` to it.
  void attach_initiator(tlm::tlm_base_initiator_socket_b<>& init) {
    const std::size_t id = m_initiators.size();
    const std::string name = "init" + std::to_string(id);
    {
      // Add current module on top of module stack for tlm sockets.
      scoped_push_hierarchy g(*this);

      // Add new target socket to connect BUS INITIATOR.
      m_initiators.push_back(std::make_unique<target_socket>(
          name.c_str(), id, this, &lt_bus::b_transport,
          &lt_bus::get_direct_mem_ptr, &lt_bus::transport_dbg));
    }

    // Bind sockets.
    auto& target = m_initiators.back();
    target->bind(init);
  }

  // -- ATTACH BUS TARGET ------------------------------------------------------

  /// Maps the inclusive bus window [start, end] to `target` through a new
  /// bus-side initiator socket named "target<N>". Aborts if the window
  /// overlaps an already registered window.
  void attach_target(tlm::tlm_base_target_socket_b<>& target,
                     std::uint64_t start,
                     std::uint64_t end) {
    const range addr{start, end};

    // Check if new range overlaps with any registered memory map range.
    for (const auto& map : m_mappings) {
      if (map.addr.overlaps(addr)) {
        // PRIx64 keeps the format portable to platforms where uint64_t is
        // `unsigned long long` (plain %lx would be undefined behaviour there).
        std::fprintf(stderr,
                     "lt_bus: memory map conflict detected\n"
                     "old: %08" PRIx64 " - %08" PRIx64 "\n"
                     "new: %08" PRIx64 " - %08" PRIx64 "\n",
                     map.addr.start, map.addr.end, start, end);
        std::abort();
      }
    }

    const std::size_t id = m_targets.size();
    const std::string name = "target" + std::to_string(id);
    {
      // Add current module on top of module stack for tlm sockets.
      scoped_push_hierarchy g(*this);

      // Add new initiator socket to connect BUS TARGET.
      m_targets.push_back(std::make_unique<initiator_socket>(name.c_str()));
    }

    // Bind sockets.
    auto& init = m_targets.back();
    init->bind(*this);
    init->bind(target);

    // Insert new mapping, id is equal to idx into socket vector.
    m_mappings.push_back({addr, id});
  }

 private:
  // -- TLM_BW_TRANSPORT_IF ----------------------------------------------------

  virtual tlm::tlm_sync_enum nb_transport_bw(tlm::tlm_generic_payload&,
                                             tlm::tlm_phase&,
                                             sc_core::sc_time&) override {
    std::fprintf(stderr, "lt_bus: nb_transport_bw not supported\n");
    std::abort();
  }

  /// Propagates a DMI invalidation from a bus target to all bus initiators.
  ///
  /// The backward interface carries no socket id, so the invalidating target
  /// (and with it the window base needed to translate [start, end] into bus
  /// addresses) is unknown here. Instead the whole bus address space is
  /// invalidated on every bus initiator, which over-invalidates but never
  /// leaves a stale DMI pointer behind.
  virtual void invalidate_direct_mem_ptr(sc_dt::uint64 /*start*/,
                                         sc_dt::uint64 /*end*/) override {
    for (const auto& sock : m_initiators) {
      (*sock)->invalidate_direct_mem_ptr(0, ~sc_dt::uint64{0});
    }
  }

  // -- TLM_FW_TRANSPORT_IF ----------------------------------------------------

  /// Decodes and forwards a blocking transaction; responds with
  /// TLM_ADDRESS_ERROR_RESPONSE if no window fully contains the access.
  void b_transport(std::size_t,
                   tlm::tlm_generic_payload& tx,
                   sc_core::sc_time& t) {
    const std::uint64_t start = tx.get_address();

    if (const auto res = decode(start, tx.get_data_length())) {
      assert(res.base <= start);

      tx.set_address(start - res.base);
      (*res.sock)->b_transport(tx, t);
      tx.set_address(start);
    } else {
      tx.set_response_status(tlm::TLM_ADDRESS_ERROR_RESPONSE);
    }
  }

  /// Decodes and forwards a DMI request, translating the returned DMI region
  /// back into bus addresses. Returns false if the address does not decode.
  bool get_direct_mem_ptr(std::size_t,
                          tlm::tlm_generic_payload& tx,
                          tlm::tlm_dmi& dmi) {
    const std::uint64_t start = tx.get_address();

    const auto res = decode(start, tx.get_data_length());
    if (!res) {
      return false;
    }
    assert(res.base <= start);

    tx.set_address(start - res.base);
    const bool ret = (*res.sock)->get_direct_mem_ptr(tx, dmi);
    tx.set_address(start);

    // The target describes the DMI region in its own (window-relative)
    // address space. Clip it to the mapped window so no initiator can reach
    // past the window via DMI, then translate it back into bus addresses.
    const std::uint64_t window_last = res.last - res.base;
    if (dmi.get_end_address() > window_last) {
      dmi.set_end_address(window_last);
    }
    dmi.set_start_address(dmi.get_start_address() + res.base);
    dmi.set_end_address(dmi.get_end_address() + res.base);
    return ret;
  }

  /// Decodes and forwards a debug transaction; returns 0 bytes transferred if
  /// the access does not decode.
  unsigned int transport_dbg(std::size_t, tlm::tlm_generic_payload& tx) {
    const std::uint64_t start = tx.get_address();

    const auto res = decode(start, tx.get_data_length());
    if (!res) {
      return 0;
    }
    assert(res.base <= start);

    tx.set_address(start - res.base);
    const unsigned int ret = (*res.sock)->transport_dbg(tx);
    tx.set_address(start);
    return ret;
  }

  // -- DECODE BUS TARGET ------------------------------------------------------

  struct decode_result {
    initiator_socket* sock{nullptr};  // Socket of the matched bus target.
    std::uint64_t base{0};  // First bus address of the matched window.
    std::uint64_t last{0};  // Last (inclusive) bus address of the window.

    constexpr explicit operator bool() const {
      return sock != nullptr;
    }
  };

  /// Finds the window that fully contains the `len`-byte access at `start`.
  ///
  /// A zero-length access is decoded as a single byte at `start`; an access
  /// that would wrap past the top of the 64-bit address space never decodes.
  /// Returns an empty (false) result if no window matches.
  decode_result decode(std::uint64_t start, std::uint64_t len) const {
    const std::uint64_t span = (len == 0) ? 0 : len - 1;
    if (span > UINT64_MAX - start) {
      return {};
    }
    const std::uint64_t end = start + span;

    for (const auto& map : m_mappings) {
      if (map.addr.start <= start && end <= map.addr.end) {
        return {m_targets[map.idx].get(), map.addr.start, map.addr.end};
      }
    }
    return {};
  }

  // -- SC_MODULE CALLBACKS ----------------------------------------------------

  virtual void start_of_simulation() override {
    // Sort memory mappings by start address.
    std::sort(m_mappings.begin(), m_mappings.end(),
              [](const mapping& lhs, const mapping& rhs) {
                return lhs.addr.start < rhs.addr.start;
              });

    // Dump memory map.
    // for (const auto& map : m_mappings) {
    //   std::printf("%08" PRIx64 " - %08" PRIx64 " :[%2zu] %s\n",
    //               map.addr.start, map.addr.end, map.idx,
    //               m_targets[map.idx].get()->name());
    // }
  }

  // -- LOCAL CLASSES ----------------------------------------------------------

  struct mapping {
    range addr;       // Inclusive bus window.
    std::size_t idx;  // Index into m_targets.
  };

  // -- MEMBER -----------------------------------------------------------------

  // TARGET sockets to bind BUS INITIATORS against.
  std::vector<target_socket_ptr> m_initiators;
  // INITIATOR sockets to bind BUS TARGET against.
  std::vector<initiator_socket_ptr> m_targets;
  // Address range mappings to BUS TARGETs (m_targets).
  std::vector<mapping> m_mappings;
};

#endif