#ifndef SAMPLE_PREFETCHER_H
#define SAMPLE_PREFETCHER_H

#include <stdio.h>

#include <map>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <set>
#include <deque>
#include <vector>
#include <map>
#include <string>
#include <fstream>
#include <sstream>
#include <algorithm>
#include "interface.h"

// Signed integer type (long long, at least 64 bits) used to form signed
// differences between unsigned block addresses, e.g. in
// MyTrainingEntry::GetBAddrDist().
#define BLOCKINT signed long long

///////////////////////////////////////////////////////////////////////////////
// Basic Classes and Templates
///////////////////////////////////////////////////////////////////////////////

/// Empty common base class of the list node type (_List_Node derives from it).
class NodeBase
{
};

/// Node of the doubly-linked LRU-like queues (see LRUList).
/// Holds one item plus raw links to its neighbours; the node never owns or
/// frees its neighbours.
template<class ItemType>
class _List_Node :  public NodeBase 
{
 public:
  ItemType _item;      ///< payload
  _List_Node* _next;   ///< following node, or 0 if none
  _List_Node* _prev;   ///< preceding node, or 0 if none

  /// Creates an unlinked node holding a copy of \p item.
  /// Taking the item by const reference avoids an extra copy of ItemType.
  _List_Node(const ItemType& item = ItemType())
    : _item(item),
      _next(0),
      _prev(0)
  {}

  /// Copies only the payload. The new node is deliberately left unlinked so
  /// that a copy never aliases the original's position in a list.
  /// Takes a const reference so const nodes and temporaries can be copied.
  _List_Node(const _List_Node& node)
    : _item(node._item),
      _next(0),
      _prev(0)
  {}

  ~_List_Node() {}
};

/// Template class for any LRU-like queue.
/// The queue is a doubly-linked list whose nodes are all allocated by the
/// constructor and freed by the destructor. The head is the Most Recently
/// Used end and the tail the Least Recently Used end.
/// The template provides the constructor/destructor and the "bring to head" /
/// "bring to tail" reordering methods.
/// Inherit from LRUList<ItemType> to create a new queue class,
/// where ItemType is the class of queue items.
/// Requirements on ItemType:
///   o. It should provide default/copy constructors.
///   o. It should provide a Dump() routine that prints its content to stdout.
///      This is only needed if LRUList::Dump() is used (useful in debugging).
/// Because the list owns its nodes it is non-copyable.
template<class ItemType>
class LRUList{
 public:
  typedef _List_Node<ItemType> Node;
  //protected:
 public:
  Node* mListHead;      // first node (MRU end), or 0 when the list is empty
  Node* mListTail;      // last node (LRU end), or 0 when the list is empty
  UINT32 mListLength;   // number of nodes allocated by the constructor

  /// Brings the node to the head of the queue, making it Most Recently Used.
  /// @param ptr a non-null node that is currently linked into this list.
  void BringsToHead(Node* ptr){
    // A node with no predecessor is already the head: nothing to do.
    if( ptr->_prev != NULL )
      {
        // Unlink: chain the previous and next elements to each other.
        ptr->_prev->_next = ptr->_next;
        if( ptr->_next )
          {
            ptr->_next->_prev = ptr->_prev;
          }
        else
          {
            // The tail is being moved, so its predecessor becomes the tail.
            mListTail = ptr->_prev;
          }

        // Relink in front of the current head.
        ptr->_next        = mListHead;
        ptr->_prev        = NULL;
        mListHead->_prev  = ptr;

        // This is now the current head.
        mListHead         = ptr;
      }
  }

  /// Brings the node to the tail of the queue, making it Least Recently Used.
  /// @param ptr a non-null node that is currently linked into this list.
  void BringsToTail(Node* ptr){
    // A node with no successor is already the tail: nothing to do.
    if( ptr->_next != NULL )
      {
        // Unlink: chain the previous and next elements to each other.
        ptr->_next->_prev = ptr->_prev;
        if( ptr->_prev )
          {
            ptr->_prev->_next = ptr->_next;
          }
        else
          {
            // The head is being moved, so its successor becomes the head.
            mListHead = ptr->_next;
          }

        // Relink behind the current tail.
        ptr->_prev        = mListTail;
        ptr->_next        = NULL;
        mListTail->_next  = ptr;

        // This is now the current tail.
        mListTail         = ptr;
      }
  }

 public:
  /// Creates the list with \p listLength default-constructed nodes.
  LRUList(UINT32 listLength):
    mListHead(0),
    mListTail(0),
    mListLength(listLength)
  {
    Node* prevNode = 0;
    for (UINT32 i=0; i<listLength; i++){
      Node* node = new Node();
      if (mListHead == 0) mListHead = node;
      node->_prev = prevNode;
      if (prevNode) prevNode->_next = node;
      prevNode = node;
    }
    mListTail = prevNode;
  }

  /// Frees every node of the list.
  ~LRUList()
  {
    while (mListHead)
      {
        Node* node = mListHead->_next;
        delete mListHead;
        mListHead = node;
      }
  }

  /// Utility routine for debugging in gdb: dumps the list, head first.
  /// Stops early (instead of dereferencing 0) if the chain is shorter than
  /// mListLength.
  void Dump()
  {
    Node* node=mListHead;
    for (UINT32 i=0; i<mListLength && node; i++, node=node->_next)
      {
        std::cout<<i<<": ";
        node->_item.Dump();
        std::cout<<std::endl;
      }
  }

 private:
  // Non-copyable: the list owns its nodes (deleted in ~LRUList), so a
  // member-wise copy would share them and delete them twice.
  // Declared but intentionally not defined (pre-C++11 idiom).
  LRUList(const LRUList&);
  LRUList& operator=(const LRUList&);
};

/// Project-wide constants shared by the prefetcher classes.
class Common {
 public:
  // Logging levels 0..3. PrefetcherBase::Log() prints a message when the
  // configured level is >= the message's level, so a larger value is more verbose.
  static const UINT32 LogLevel0 = 0;   // quietest
  static const UINT32 LogLevel1 = 1;
  static const UINT32 LogLevel2 = 2;
  static const UINT32 LogLevel3 = 3;   // most verbose

  // Neighbourhood limit used by stream training. The unit is presumably cache
  // blocks (cf. the "within 16 blocks" notes on MyTrainingEntry); TODO confirm.
  static const UINT32 STREAM_MAX_NEIGHBOR_DISTANCE = 16;
};

/// Stateless address helper routines.
class AddrOperation {
 public:
  /// Returns \p addr with its in-block offset bits cleared.
  /// Assumes \p blocksize is a power of two.
  inline static ADDRINT GetCacheBlockAddr(ADDRINT addr, ADDRINT blocksize) {
    const ADDRINT offsetMask = blocksize - 1;
    return addr & ~offsetMask;
  }

  /// Defined outside this header; presumably returns log2(blocksize).
  static INT32 GetLog2(ADDRINT blocksize);
};

// Base class for prefetcher
class PrefetcherBase {
 private:
  COUNTER prefnum;
  COUNTER prefhit;
  COUNTER _interval;

 public:
  std::string prefname;
  ADDRINT _blocksize;
  INT32 _blockbitnum;
  ADDRINT _cachesize;
  ADDRINT _setassoc;

 public:
  std::map<std::string, int> params;
  UINT32 _loglevel;
  std::string _mystat;

  PrefetcherBase() {
    prefname = "PrefetcherBase";
    prefnum = prefhit = 0;

    _blocksize = 64;
    _blockbitnum = AddrOperation::GetLog2(_blocksize);
    _loglevel = 0;
    _mystat = "mystat.csv";

    params["cacheblocksize"] = (int)_blocksize;
    params["cacheblockbits"] = _blockbitnum;
    params["loglevel"] = _loglevel;

    switch(_loglevel) {
    case 0: 
      _interval=100000; 
      break;
    case 1:
      _interval=50000;
      break;
    case 2:
      _interval=20000;
      break;
    case 3:
      _interval=10000;
      break;
    default:
      _interval=10000;
      break;
    }
  }
  virtual ~PrefetcherBase() {}

  std::string GetPrefName() { return prefname; }
  ADDRINT GetBlockSize() { return _blocksize; }
  void IncPrefnum() { prefnum++; }
  void IncPrefhit() { prefhit++; }
  COUNTER GetPrefnum() { return prefnum; }
  COUNTER GetPrefhit() { return prefhit; }

  inline ADDRINT GetBlockFromAddr(ADDRINT addr) { 
    return (addr >> _blockbitnum);
  }
  inline ADDRINT GetAddrFromBlock(ADDRINT baddr) { 
    return (baddr << _blockbitnum);
  }

  inline ADDRINT GetBlockAddr(ADDRINT addr) {
    return AddrOperation::GetCacheBlockAddr(addr, _blocksize);
  }

  //each prefetcher will implement this function
  virtual void OnMiss(std::stringstream& ss,
		      PrefetchData_t& pdata, 
		      std::vector<CacheAddr_t>& requests,
		      COUNTER cycle = 0,
		      bool isfiltered = false) {}
  virtual void OnHit(std::stringstream& ss,
		     PrefetchData_t& pdata, 
		     std::vector<CacheAddr_t>& requests,
		     COUNTER cycle = 0) {}
  virtual void OnPrefetchHit(std::stringstream& ss,
			     PrefetchData_t& pdata, 
			     std::vector<CacheAddr_t>& requests,
			     COUNTER cycle = 0) {}
  virtual void GetSpecificStatistics() {}

  void ShowPrefNum();
  void ShowStatistics(COUNTER cycle);
  void Log(const UINT32 loglevel, std::string& str) {
    if(_loglevel >= loglevel)
      std::cout << str << std::endl;
  }
};

///////////////////////////////////////////////////////////////////////////////
// This is a simplified MSHR.
// For the purpose of filtering consecutive same-block misses
// Total 16 MSHR entries. The entries are invalidated either by LRU kickout, 
// or when finding out the block has arrived into cache
///////////////////////////////////////////////////////////////////////////////
class MSHREntry {
 public:
  ADDRINT addr;//cache block addr
};

/// Simplified MSHR: a FIFO-like structure built on LRUList<MSHREntry>.
/// `mshrlist` keys list nodes by address; the lookup/maintenance logic lives
/// in the out-of-line method definitions.
class MSHR
    : public LRUList<MSHREntry>
{
 private:
  ADDRINT blocksize;                         // block size passed to the constructor
  std::map<ADDRINT, Node*> mshrlist;         // address -> list node
  std::map<ADDRINT, Node*>::iterator iter;   // iterator member used by the out-of-line methods

 public:
  /// @param blocksize0 cache block size
  /// @param tableSize  number of entries (nodes preallocated by LRUList)
  MSHR(ADDRINT blocksize0, UINT32 tableSize)
    : LRUList<MSHREntry>(tableSize), blocksize(blocksize0)
  {
  }

  ~MSHR() {}

  // --- Defined outside this header ---
  Node *AddEntry(ADDRINT addr);
  void DelEntry(ADDRINT addr);
  Node* Find( ADDRINT addr );
  void AddToMSHRlist( ADDRINT addr, Node *tmp );
  void DelFrmMSHRlist(ADDRINT addr);
  void ShowMSHR(std::stringstream& ss);
};

///////////////////////////////////////////////////////////////////////////////
// Basic Class for prefetchers with MSHR support
///////////////////////////////////////////////////////////////////////////////
class MSHRPrefetcher: virtual public PrefetcherBase {
 private:
  bool MSHREnabled;
  MSHR *mshr;

 public:
  MSHRPrefetcher()
    {
      prefname = "MSHRPrefetcher";
    
      ADDRINT blocksize;
      UINT32 mshrsize;
    
      MSHREnabled = true;
      blocksize = 64;
      mshrsize = 16;

      params["mshr_enable"] = (int)MSHREnabled;
      params["mshr_size"] = (int)mshrsize;

      if(MSHREnabled)
	mshr = new MSHR(blocksize, mshrsize);
      else
	mshr = NULL;
    }

  virtual ~MSHRPrefetcher() { 
    if(mshr)
      delete mshr;
  }

  MSHR *GetMSHR() { return mshr; }
  bool IsMSHREnabled() { return MSHREnabled; }
  MSHR::Node *AddMSHREntry(ADDRINT addr) {
    if(mshr)
      return mshr->AddEntry(addr);
    else
      return NULL;
  }

  void DelMSHREntry(ADDRINT addr) {
    if(mshr)
      mshr->DelEntry(addr);
  }

  bool LookupMSHR(ADDRINT addr) {
    if(mshr == NULL)
      return false;

    if(mshr->Find(addr) != NULL) {
      return true;
    }
    else {
      AddMSHREntry(addr);
      return false;
    }
  }
 
  void ShowMSHR(std::stringstream& ss) {
    if(mshr)
      mshr->ShowMSHR(ss);
  }
};

///////////////////////////////////////////////////////////////////////////////
// My Stream Prefetching
///////////////////////////////////////////////////////////////////////////////



///////////////////////////////////////////////////////////////////////////////
// Classes for training table: MyTrainingEntry, MyTrainingTable
///////////////////////////////////////////////////////////////////////////////

class MyTrainingEntry{
  public:
    /* Storage budget: each entry is accounted as 39 bits
       (26 + 5 + 5 + 1 + 1 + 1, see the per-field notes). */

    // Block address of the first miss that started training this stream (26 bits).
    ADDRINT start_baddr;

    // Second miss (within 16 blocks of start_baddr), stored as a signed block
    // distance from start_baddr (5 bits).
    INT32 second_baddr_dist;

    // Third miss (within 16 blocks of start_baddr), stored as a signed block
    // distance from start_baddr (5 bits).
    INT32 end_baddr_dist;

    bool ascending;    // stream direction, true = ascending (1 bit)
    bool trained;      // training done; entry should move to the trained table (1 bit)
    bool noise_flag;   // possible noise was observed (1 bit)

    /// Builds an entry with every field zero / false.
    MyTrainingEntry()
      : start_baddr(0),
        second_baddr_dist(0),
        end_baddr_dist(0),
        ascending(false),
        trained(false),
        noise_flag(false)
    {
    }

    /// Field-by-field copy.
    MyTrainingEntry(const MyTrainingEntry& other)
      : start_baddr(other.start_baddr),
        second_baddr_dist(other.second_baddr_dist),
        end_baddr_dist(other.end_baddr_dist),
        ascending(other.ascending),
        trained(other.trained),
        noise_flag(other.noise_flag)
    {
    }

    ~MyTrainingEntry()
    {
    }

    /// Resets every field to its default-constructed value.
    void Clear()
    {
      *this = MyTrainingEntry();
    }

    /// Block address of the second miss (start plus signed distance, modulo ADDRINT).
    ADDRINT GetSecondBAddr() {
      return start_baddr + static_cast<ADDRINT>(second_baddr_dist);
    }

    /// Block address of the third miss (start plus signed distance, modulo ADDRINT).
    ADDRINT GetEndBAddr() {
      return start_baddr + static_cast<ADDRINT>(end_baddr_dist);
    }

    /// Signed block distance from start_baddr to \p baddr, truncated to INT32.
    INT32 GetBAddrDist(ADDRINT baddr) {
      const BLOCKINT delta = static_cast<BLOCKINT>(baddr) - static_cast<BLOCKINT>(start_baddr);
      return static_cast<INT32>(delta);
    }
};

/// LRU-ordered table of MyTrainingEntry used to train new streams.
/// The access/renew/show logic is defined outside this header.
class MyTrainingTable : public LRUList<MyTrainingEntry>
{
  private:
    ADDRINT _blocksize;
    INT32 _blockbitnum;
    bool _enable_noise_removal;
    UINT32 _stream_prefetch_distance;

  public:
    MyTrainingTable(UINT32 tableSize, ADDRINT blocksize, INT32 blockbitnum, bool enable_noise_removal,
                    UINT32 stream_prefetch_distance)
      : LRUList<MyTrainingEntry>(tableSize),
        _blocksize(blocksize),
        _blockbitnum(blockbitnum),
        _enable_noise_removal(enable_noise_removal),
        _stream_prefetch_distance(stream_prefetch_distance)
    {
    }

    ~MyTrainingTable() {}

    // Looks up the training table and tries to train a stream; returns the node.
    Node* AccessEntry(UINT32 threadId, ADDRINT addr, COUNTER cycle);

    // Reuses the entry by shifting out the start address: the second and end
    // addresses become the new start and second addresses respectively.
    void RenewEntry(MyTrainingEntry& entry, COUNTER cycle);

    // Dumps one entry / the whole training table.
    void ShowEntry(MyTrainingEntry& entry);
    void ShowTable();
};





///////////////////////////////////////////////////////////////////////////////
// Classes for stream table: MyStreamEntry, MyStreamPrefetchTable
///////////////////////////////////////////////////////////////////////////////
class MyStreamEntry{
  public:
    /* Storage budget: each entry is accounted as 97 bits
       (26 + 9 + 32 + 16 + 9 + 5 x 1, see the per-field notes). */

    // Block address of the first miss that ever triggered the stream (26 bits).
    ADDRINT original_baddr;

    // Start of the monitored region, stored as a block-address distance from
    // region_end_addr (9 bits). With a prefetch distance of 64 and constant-stride
    // extension, the stride is at most 8 blocks (half the training window),
    // so the region size fits in 6 + 3 bits.
    ADDRINT region_start_baddr_dist;

    // End of the region: the last prefetched address (32 bits).
    ADDRINT region_end_addr;

    // Last access address (hit or miss), stored as a byte-address distance from
    // region_end_addr (16 bits). It cannot be much farther than prefetch
    // distance (64) + prefetch degree (4) from the region end: 7 bits cover the
    // prefetch distance, 6 bits the block offset, and 3 more bits a
    // multi-block stride.
    ADDRINT last_access_addr_dist;

    // Stride of the stream in bytes (9 bits). The stride is at most 8 blocks,
    // i.e. 8 * 64 bytes, hence 9 bits.
    ADDRINT stride;

    bool ascending;        // stream direction, true = ascending (1 bit)
    bool stride_confirm;   // set to confirm the stride; otherwise stride = 1 block (1 bit)

    // Constant-stride detection: a detection has begun (1 bit).
    bool csd1;

    // Constant-stride detection: the second access arrived and produced a
    // distance to the previous access, which is stored in `stride` (1 bit).
    bool csd2;

    // The stream is a repeated stream (1 bit).
    bool isrepeat;

    /// Builds an entry with every field zero / false.
    MyStreamEntry()
      : original_baddr(0),
        region_start_baddr_dist(0),
        region_end_addr(0),
        last_access_addr_dist(0),
        stride(0),
        ascending(false),
        stride_confirm(false),
        csd1(false),
        csd2(false),
        isrepeat(false)
    {
    }

    /// Field-by-field copy.
    MyStreamEntry(const MyStreamEntry& other)
      : original_baddr(other.original_baddr),
        region_start_baddr_dist(other.region_start_baddr_dist),
        region_end_addr(other.region_end_addr),
        last_access_addr_dist(other.last_access_addr_dist),
        stride(other.stride),
        ascending(other.ascending),
        stride_confirm(other.stride_confirm),
        csd1(other.csd1),
        csd2(other.csd2),
        isrepeat(other.isrepeat)
    {
    }

    ~MyStreamEntry()
    {
    }

    /// Resets every field to its default-constructed value.
    void Clear() {
      *this = MyStreamEntry();
    }
};

/// LRU-ordered table of trained streams (MyStreamEntry). Inline helpers below
/// translate between the distance-encoded fields of an entry and absolute
/// (block) addresses, taking the stream direction into account.
class MyStreamPrefetchTable : public LRUList<MyStreamEntry>
{
  private:
    ADDRINT _blocksize;
    INT32 _blockbitnum;
    UINT32 _pref_degree;
    UINT32 _stream_prefetch_distance;
    UINT32 _loglevel;

  public:
    MyStreamPrefetchTable(UINT32 tableSize, ADDRINT blocksize, INT32 blockbitnum,
                          UINT32 prefdegree, UINT32 stream_prefetch_distance, UINT32 loglevel)
      : LRUList<MyStreamEntry>(tableSize),
        _blocksize(blocksize),
        _blockbitnum(blockbitnum),
        _pref_degree(prefdegree),
        _stream_prefetch_distance(stream_prefetch_distance),
        _loglevel(loglevel)
    {
    }

    ~MyStreamPrefetchTable() {}

    // --- Defined outside this header ---
    void AddEntry(MyStreamEntry& newentry, COUNTER cycle);
    void ForwardTrackingRegion(MyStreamEntry* entry, UINT32 prefdistance, UINT32 prefdegree);
    MyStreamEntry* AccessTrainedEntry(ADDRINT baddr, bool& islatepref, COUNTER cycle);
    Node * FindPrefetchStream(ADDRINT baddr);
    MyStreamEntry * LookForRepeatStream(std::vector<CacheAddr_t>& requests, MyStreamEntry& newentry);

    void ShowMyStreamEntry(MyStreamEntry& entry);
    void ShowAll();

    /// Region start block address: behind the region end for ascending
    /// streams, ahead of it for descending ones.
    ADDRINT GetRegionStartBAddrFromBDist(MyStreamEntry &entry) {
      const ADDRINT endBAddr = GetBAddrFromAddr(entry.region_end_addr);
      return entry.ascending ? endBAddr - entry.region_start_baddr_dist
                             : endBAddr + entry.region_start_baddr_dist;
    }

    /// Stores \p baddr as the region start, encoded as a distance from the region end.
    void SetRegionStartBAddrToBDist(MyStreamEntry &entry, ADDRINT baddr) {
      const ADDRINT endBAddr = GetBAddrFromAddr(entry.region_end_addr);
      entry.region_start_baddr_dist = entry.ascending ? endBAddr - baddr
                                                      : baddr - endBAddr;
    }

    /// Last access byte address, decoded from its distance to the region end.
    ADDRINT GetLastAccessAddressFromDist(MyStreamEntry &entry) {
      const ADDRINT endAddr = entry.region_end_addr;
      return entry.ascending ? endAddr - entry.last_access_addr_dist
                             : endAddr + entry.last_access_addr_dist;
    }

    /// Stores \p addr as the last access, encoded as a distance from the region end.
    void SetLastAccessAddressToDist(MyStreamEntry &entry, ADDRINT addr) {
      const ADDRINT endAddr = entry.region_end_addr;
      entry.last_access_addr_dist = entry.ascending ? endAddr - addr
                                                    : addr - endAddr;
    }

    /// Byte address <-> block address conversions.
    ADDRINT GetBAddrFromAddr(ADDRINT addr) { return addr >> _blockbitnum; }
    ADDRINT GetAddrFromBAddr(ADDRINT baddr) { return baddr << _blockbitnum; }
};

/// Stream prefetcher built on MSHRPrefetcher.
/// Owns a training table and a stream table. Both are allocated in the
/// constructor and freed in the destructor, so the class is non-copyable.
/// The prefetch callbacks are defined outside this header.
class MyStream : public MSHRPrefetcher {
 private:
  UINT32 _pref_degree;
  UINT32 _stream_prefetch_distance;
  UINT32 _stream_table_size;
  UINT32 _stream_training_table_size;
  MyStreamPrefetchTable *_stream_table;   // owned
  MyTrainingTable *_training_table;       // owned
  bool _enable_stream_merge;
  bool _enable_constant_stride;
  bool _enable_noise_removal;
  bool _enable_stream_repeat;

  // Non-copyable: a member-wise copy would share the owned tables and delete
  // them twice. Declared but intentionally not defined (pre-C++11 idiom).
  MyStream(const MyStream&);
  MyStream& operator=(const MyStream&);

 public:
  MyStream()
    : _pref_degree(4),
      _stream_prefetch_distance(64),
      _stream_table_size(128),
      _stream_training_table_size(256),
      _stream_table(NULL),
      _training_table(NULL),
      // NOTE(review): this flag was previously left uninitialized (any read
      // was undefined); false is assumed -- confirm against the .cpp.
      _enable_stream_merge(false),
      _enable_constant_stride(true),
      _enable_noise_removal(true),
      _enable_stream_repeat(true)
  {
    prefname="MyStream";

    // Both tables are configured from the members above plus the base-class
    // block size and log level.
    _stream_table = new MyStreamPrefetchTable(_stream_table_size, _blocksize, _blockbitnum,
                                              _pref_degree, _stream_prefetch_distance, _loglevel);
    _training_table = new MyTrainingTable(_stream_training_table_size, _blocksize, _blockbitnum,
                                          _enable_noise_removal, _stream_prefetch_distance);

    params["prefetch_degree"] = _pref_degree;
    params["stream_prefetch_distance"] = _stream_prefetch_distance;
    params["stream_table_size"] = _stream_table_size;
    params["stream_training_table_size"] = _stream_training_table_size;
    params["enable_constant_stride"] = _enable_constant_stride;
    params["enable_noise_removal"] = _enable_noise_removal;
    params["enable_stream_repeat"] = _enable_stream_repeat;
  }

  ~MyStream() {
    delete _stream_table;
    delete _training_table;
  }

  // Local byte address <-> block address helpers.
  ADDRINT GetBAddrFromAddr(ADDRINT addr) { return addr >> _blockbitnum; }
  ADDRINT GetAddrFromBAddr(ADDRINT baddr) { return baddr << _blockbitnum; }

  // --- Defined outside this header ---
  void MoveFromTrainingToStream(MyTrainingEntry& tent, MyStreamEntry& sent);
  void GetPrefetchCandidates(std::stringstream& ss, MyStreamEntry *entry, std::vector<CacheAddr_t>& requests, COUNTER cycle);
  void ConstantStrideRevise(MyStreamEntry* entry, ADDRINT addr);
  void ConstantStrideDetect(MyStreamEntry* entry, ADDRINT addr);
  void ClearConstantStride(MyStreamEntry* entry);
  void MyIssuePrefetches(std::stringstream& ss, MyStreamEntry *entry, COUNTER cycle);

  // Overrides of the PrefetcherBase virtual callbacks.
  void OnMiss(std::stringstream& ss, PrefetchData_t& pdata, std::vector<CacheAddr_t>& requests, COUNTER cycle = 0, bool isfiltered = false);
  void OnHit(std::stringstream& ss, PrefetchData_t& pdata, std::vector<CacheAddr_t>& requests, COUNTER cycle = 0);
  void OnPrefetchHit(std::stringstream& ss, PrefetchData_t& pdata, std::vector<CacheAddr_t>& requests, COUNTER cycle = 0);
  void GetSpecificStatistics();
};

#endif
