#include <stdio.h>
#include <iostream>

#include <map>
#include <string>

#include "interface.h"  // Do NOT edit interface .h
#include "sample_prefetcher.h"

// Trace sink with the same call shape as fprintf(fp, fmt, ...): it prints nothing and
// ignores every argument (the caller still evaluates them).  Each trace category below
// is routed here; redefine one of them as fprintf to turn that category back on.
void nop_fprintf(FILE* /* fp: ignored */, ...) {}
#define debug nop_fprintf
#define MSHRfprintf nop_fprintf
#define EVICTdebug nop_fprintf

// Signed cycle type used by MSHRs: a negative value marks an entry that was
// allocated on behalf of a prefetch rather than a demand miss.
#define SIGNED_COUNTER long long

// Cache block (line) size in bytes.
const int theBlockSize = 64;

// When true, SMS records absolute block offsets and folds the trigger offset into
// the PHT key instead of rotating patterns relative to the trigger block.
const bool NoRotation(false);

// Table of named event counters.  The single global instance, theStats, writes every
// counter to stderr as a "name,value" line when it is destroyed at program exit.
struct Stats
{
	typedef std::map<std::string,long long> tCounters;
	tCounters theCounters;
	~Stats() {
		const tCounters& counters = theCounters;
		for (tCounters::const_iterator it = counters.begin(), last = counters.end(); it != last; ++it)
			std::cerr << it->first << "," << it->second << std::endl;
	}
} theStats;
// C(name): add one to counter "name".  S(name,v): add v to counter "name".
// Both expansions already end in ';' -- call sites rely on that (e.g. `if (x) C(a) else ...`),
// so the trailing semicolon must stay.
#define C(n) do { ++(theStats.theCounters[(#n)]); } while(0);
#define S(n,v) do { (theStats.theCounters[(#n)])+=(v); } while(0);

struct MSHRs {
	unsigned int theSize;
	typedef std::map<CacheAddr_t,SIGNED_COUNTER> tEntries;
	tEntries theEntries;

	MSHRs(unsigned int aSize)
	:	theSize(aSize)
	{ }

	void clean(SIGNED_COUNTER cycle) {
		if (cycle<0) cycle = -cycle;
		tEntries::iterator next;
		for(tEntries::iterator i=theEntries.begin();i!=theEntries.end();i = next) {
			next = i;
			++next;
			if (GetPrefetchBit(0,i->first)>-1) {
				MSHRfprintf(stderr,"removing MSHR entry %llx, %d left (i->second=%lld, cycle=%lld)\n",i->first,theEntries.size()-1,i->second,cycle);
				if (i->second > 0) {
					S(mem_waited,cycle - i->second)
					assert(cycle >= i->second);
				}
				theEntries.erase(i);
			} else {
				MSHRfprintf(stderr,"keeping MSHR entry %llx, %d entries\n",i->first,theEntries.size());
			}
		}
	}

	bool inflight(CacheAddr_t anAddr) {
		anAddr &= ~63ULL;
		bool res = (theEntries.find(anAddr) != theEntries.end());
		MSHRfprintf(stderr,"MSHR in-flight check for %llx: %d\n",anAddr,res);
		return res;
	}

	bool allocate(CacheAddr_t anAddr, SIGNED_COUNTER cycle) {
		clean(cycle);
		anAddr &= ~63ULL;
		if (cycle<0 && GetPrefetchBit(0,anAddr)>-1) C(L1_Prefetch_Not_Needed)
		tEntries::iterator i = theEntries.find(anAddr);
		if (i != theEntries.end()) {
			if (cycle<0) {
				C(L1_Prefetch_Already_InFlight)
			} else if (i->second<0) {
				C(L1_Prefetch_PartialHit);
				S(mem_saved,cycle + i->second)
				i->second = cycle;
			} else {
				C(L1_OoO_Miss_Overlap)
			}
			MSHRfprintf(stderr,"MSHR for %llx exists\n",anAddr);
			return false;
		}
		if (theEntries.size() == theSize) {
			MSHRfprintf(stderr,"MSHRs are full, can't insert %llx (size=%d)\n",anAddr,theSize);
			C(MSHR_full)
			return true;
		}
		theEntries.insert(std::make_pair(anAddr,cycle));
		MSHRfprintf(stderr,"added MSHR entry %llx, %d present\n",anAddr,theEntries.size());
		return true;
	}

};

struct SMS
{
	unsigned theRegionShift,theRegionSize,theRegionMask,theBlocksPerRegion;

	typedef unsigned long long pattern_t;

	struct AGTent {
		CacheAddr_t pc;
		int offset;
		pattern_t pattern;
		AGTent(CacheAddr_t aPC = 0ULL,int aOffset = 0,pattern_t aPattern = 0ULL)
		:	pc(aPC)
		,	offset(aOffset)
		,	pattern(aPattern)
		{ }
		CacheAddr_t EVICT_DETECTOR;
	};
	struct PHTent {
		pattern_t pattern;
		PHTent(pattern_t aPattern = 0ULL):pattern(aPattern) { }
		CacheAddr_t EVICT_DETECTOR;
	};
	typedef Container<CacheAddr_t,AGTent> AGT; AGT theAGT;
	typedef Container<CacheAddr_t,PHTent> PHT; PHT thePHT;

	SMS()
		: theRegionShift(13) // shift: 12 = 4KB region, 10 = 1KB region
		, theRegionSize(1<<theRegionShift)
		, theRegionMask(theRegionSize-1)
		, theBlocksPerRegion(theRegionSize/theBlockSize)
		, theAGT(8,16,theRegionShift,27-theRegionShift)
		, thePHT(8,16,NoRotation?0:2,14)
	{
	}

	pattern_t rotate(int aBitIndex,int anOffset) {
		debug(stderr,"rotate(%d,%d [%d]) : ",aBitIndex,anOffset,theBlocksPerRegion);
		pattern_t res = 1ULL<<((aBitIndex + anOffset) % theBlocksPerRegion);
		debug(stderr,"%llx > %llx\n",1ULL<<aBitIndex,res);
		assert(!(res & (res-1)));
		return res;
	}

	bool replace( SIGNED_COUNTER cycle, CacheAddr_t addr ) {
		CacheAddr_t region_tag = addr & ~theRegionMask;
		int region_offset = (addr & theRegionMask)>>6;
		AGT::Item agt_evicted;
		//DBG_(Dev, ( << std::dec << theId << "-evict: group=" << std::hex << region_tag << "  offset=" << std::dec << region_offset ) );
		bool erased_something(false);

		AGT::Iter agt_ent = theAGT.find(region_tag);
		if (agt_ent != theAGT.end()) {
			pattern_t new_bit = rotate(region_offset,agt_ent->second.offset);
			if(NoRotation) new_bit = 1ULL << region_offset;
			if (agt_ent->second.pattern & new_bit) {
				agt_evicted = *agt_ent;
				//DBG_(Dev, ( << std::dec << theId << "-end: group=" << std::hex << region_tag << "  key=" << agt_evicted.second.pc << "  " << agt_evicted.second.pattern ) );
				theAGT.erase(region_tag);
				erased_something = true;
			}
		}

		if (agt_evicted.second.pattern) {
			if (thePHT.erase(agt_evicted.second.pc)) C(PHT_erased_previous); // if replacing or if it's a singleton
			if ((agt_evicted.second.pattern-1)&agt_evicted.second.pattern) {// not singleton
				debug(stderr,"learned pattern (block eviction) %llx into PHT, pc=%llx\n",agt_evicted.second.pattern,agt_evicted.second.pc);
				thePHT.insert(agt_evicted.second.pc,PHTent(agt_evicted.second.pattern));
			}
		}

		return erased_something;
	}

	void checkEvictions( SIGNED_COUNTER cycle, MSHRs* mshrs ) {
		for(int n=0;n<theAGT.theHeight;++n) {
invalidated_list:
			AGT::ListType& aList(theAGT.theItems[n]);
			int x=0;
			for(AGT::Iter i=aList.begin();i!=aList.end();i++) {
	 //fprintf(stderr,"bla %d n=%d x=%d\n",__LINE__,n,x);
	 //fprintf(stderr,"bla %llx\n",i->second.pattern);
				int offset = theBlocksPerRegion-1;
				for(pattern_t pattern = i->second.pattern;pattern;--offset) {
					pattern_t mask = (1ULL<<offset);
					if (!(pattern & mask)) continue;
					EVICTdebug(stderr,"  attempt offset=%d mask=%llx on region_offset=%d with pattern %llx\n",offset,mask,i->second.offset,pattern);
					pattern &= ~mask;
					CacheAddr_t prediction = ((-i->second.offset+offset)*theBlockSize);
					if(NoRotation) prediction = (offset*theBlockSize);
					EVICTdebug(stderr,"  prediction = %llx\n",prediction);
					prediction &= theRegionMask;
					CacheAddr_t anAddress = i->second.EVICT_DETECTOR+prediction;
					EVICTdebug(stderr,"  prediction&= %llx\n",prediction);
					EVICTdebug(stderr,"  prediction!= %llx\n",i->first + prediction);
					if (GetPrefetchBit(0,anAddress)==-1) {
						if (mshrs->inflight(anAddress)) {
							MSHRfprintf(stderr,"block still in flight (n=%d, x=%d), region=%llx offset=%d addr=%llx\n",n,x,i->first,i->second.offset,anAddress);
						} else {
							MSHRfprintf(stderr,"detected evicted block (n=%d, x=%d), region=%llx offset=%d, evicted addr=%llx\n",n,x,i->first,i->second.offset,anAddress);
							if (replace(cycle, anAddress)) goto invalidated_list;
						}
					}
				}
				++x;
			}
		}
	}

	void IssuePrefetches( SIGNED_COUNTER cycle, PrefetchData_t *Data, MSHRs* mshrs ) {
		CacheAddr_t pc(Data->LastRequestAddr);
		assert(pc);
		CacheAddr_t region_tag = Data->DataAddr & ~theRegionMask;
		int region_offset = (Data->DataAddr & theRegionMask)>>6;
		CacheAddr_t key = pc;
		if(NoRotation) key = (pc << (theRegionShift-6)) | region_offset;
		bool miss = (! Data->hit);

		//DBG_(Dev, ( << std::dec << theId << "-access: group=" << std::hex << region_tag << "  key=" << key << "  offset=" << std::dec << region_offset ) );
		debug(stderr,"region=%llx offset=%d bit=%llx pc=%llx\n",region_tag,region_offset,1ULL<<region_offset,pc);
		AGT::Iter agt_ent = theAGT.find(region_tag);
		AGT::Item agt_evicted;
		bool new_gen = false;
		if (agt_ent == theAGT.end()) {
			C(AGT_evict_replacement)
			pattern_t new_bit = 1ULL<<(theBlocksPerRegion-1);
			if(!NoRotation) {
				agt_evicted = theAGT.insert(region_tag,AGTent(key,theBlocksPerRegion-region_offset-1,new_bit));
			} else {
				new_bit = 1ULL << region_offset;
				agt_evicted = theAGT.insert(region_tag,AGTent(key,0,new_bit));
			}
			new_gen = true;
			debug(stderr,"new pattern (from scratch) new_bit=%llx offset=%d->%d\n",new_bit,region_offset,theBlocksPerRegion-region_offset);
		} else {
			pattern_t new_bit = rotate(region_offset,agt_ent->second.offset);
			if(NoRotation) new_bit = 1ULL << region_offset;
			if ((agt_ent->second.pattern & new_bit) && miss) {
				C(AGT_samebit_replacement)
				// FIXME: is same bit repeating the common case or not?
				debug(stderr,"collided on pattern %llx (new_bit=%llx)\n",agt_ent->second.pattern,new_bit);
				/* same-bit ends gen logic
				agt_evicted = *agt_ent;
				if(!NoRotation) {
					agt_ent->second = AGTent(key,theBlocksPerRegion-region_offset-1);
				} else {
					agt_ent->second = AGTent(key,0);
				}
				new_gen = true;
				*/
			} else {
				C(AGT_addbit)
			}
			agt_ent->second.pattern |= new_bit;
			debug(stderr,"update pattern:%llx new_bit=%llx offset=%d\n",agt_ent->second.pattern,new_bit,agt_ent->second.offset);
		}

		if (agt_evicted.second.pattern) {
			if (thePHT.erase(agt_evicted.second.pc)) C(PHT_erased_previous); // if replacing or if it's a singleton
			if ((agt_evicted.second.pattern-1)&agt_evicted.second.pattern) {
				// not singleton
				debug(stderr,"learned pattern (AGT eviction) %llx into PHT, pc=%llx\n",agt_evicted.second.pattern,agt_evicted.second.pc);
				thePHT.insert(agt_evicted.second.pc,PHTent(agt_evicted.second.pattern));
			}
		}

		if (new_gen) {
			PHT::Iter pht_ent = thePHT.find(key);
			if (pht_ent != thePHT.end()) {
				//DBG_(Dev, ( << std::dec << theId << "-predict: group=" << std::hex << region_tag << "  key=" << key << "  " << pht_ent->second.pattern ) );
				C(L1_Found_Pattern)
				debug(stderr,"prediction pattern for pc=%llx is %llx and region_offset %d\n",key,pht_ent->second.pattern,region_offset);
				int offset = theBlocksPerRegion-2; // extra -1 to avoid prefetch of trigger
				if(NoRotation) offset += 1;
				//assert((pht_ent->second.pattern-1)&pht_ent->second.pattern);
				for(pattern_t pattern = pht_ent->second.pattern;pattern;--offset) {
					//debug(stderr,"  prediction at offset %d, pattern is %llx and region_offset %d\n",offset,pht_ent->second.pattern,region_offset);
					pattern_t mask = (1ULL<<offset);
					debug(stderr,"  attempt offset=%d mask=%llx on region_offset=%d with pattern %llx\n",offset,mask,region_offset,pattern);
					if (!(pattern & mask)) continue;
					pattern &= ~mask;
					if (NoRotation && (offset == region_offset)) continue;
					debug(stderr,"  prediction at offset %d, pattern is %llx and region_offset %d\n",offset,pht_ent->second.pattern,region_offset);
					CacheAddr_t prediction = ((region_offset+offset+1)*theBlockSize);
					if(NoRotation) prediction = (offset*theBlockSize);
					debug(stderr,"  prediction = %llx\n",prediction);
					prediction &= theRegionMask;
					debug(stderr,"  prediction&= %llx\n",prediction);
					debug(stderr,"  prediction!= %llx\n",region_tag + prediction);
					IssueL2Prefetch(cycle,region_tag + prediction);
					//mshrs->allocate(region_tag + prediction, -cycle); only for stats, don't waste MSHRs
					C(L1_Prefetches_Issued)
				}
			}
		}

	}

};

// Prefetcher state.  Created by InitPrefetchers(); never freed in this file.
SMS *sms = 0;
MSHRs *mshrs = 0;

// Build the SMS engine and the L1 MSHR model consumed by IssuePrefetches().
void InitPrefetchers() // DO NOT CHANGE THE PROTOTYPE
{
	const unsigned int kL1MSHRCount = 16;
	sms = new SMS();
	mshrs = new MSHRs(kL1MSHRCount);
}

// Per-cycle hook.  Ends SMS generations whose blocks left the L1, prints a progress line
// every 100000 cycles, then handles every L1 and L2 request slot whose LastRequestCycle
// is this cycle.
// - L1 slot: classify it (hit / prefetch hit / MSHR hit / miss), let SMS observe it,
//   then clear its L1 prefetch bit.
// - L2 slot: update the hit/miss counters and clear its L2 prefetch bit when it missed or
//   hit a prefetched block.
// Cycle 0 is skipped entirely.
void IssuePrefetches( COUNTER cycle, PrefetchData_t *L1Data, PrefetchData_t * L2Data )
{
	if (cycle == 0) return;
	sms->checkEvictions( cycle, mshrs );

	if (cycle % 100000 == 0) {
		fprintf(stdout,"cycle %lld (%.1f%% at IPC=1.0)\n",cycle,cycle/1000000.);
		fflush(stdout);
	}

	const int kRequestSlots = 4;

	for (int slot = 0; slot < kRequestSlots; ++slot) {
		PrefetchData_t& req = L1Data[slot];
		if (req.LastRequestCycle != cycle) continue;

		const int pbit = GetPrefetchBit(0,req.DataAddr);
		if (req.hit && pbit != 1) {
			C(L1_Total_Hits)
			debug(stderr,"%lld: %llx hit on %llx (pbit=%d)\n",cycle,req.LastRequestAddr,req.DataAddr,pbit);
		} else if (req.hit) {
			C(L1_Total_Prefetch_Hits)
			assert(pbit==1);
			debug(stderr,"%lld: %llx prefetch hit on %llx (pbit=%d)\n",cycle,req.LastRequestAddr,req.DataAddr,pbit);
		} else {
			if (mshrs->allocate(req.DataAddr,cycle)) {
				C(L1_Total_Misses)
			} else {
				// merged with an outstanding request: SMS sees it as a hit
				C(L1_Total_MSHR_Hits)
				assert(!req.hit);
				req.hit = 1;
			}
			debug(stderr,"%lld: %llx missed on %llx (pbit=%d)\n",cycle,req.LastRequestAddr,req.DataAddr,pbit);
		}
		sms->IssuePrefetches(cycle, &req, mshrs);
		UnSetPrefetchBit(0,req.DataAddr);
	}

	for (int slot = 0; slot < kRequestSlots; ++slot) {
		PrefetchData_t& req = L2Data[slot];
		if (req.LastRequestCycle != cycle) continue;

		if (!req.hit) {
			C(L2_Total_Misses)
			UnSetPrefetchBit(1,req.DataAddr);
		} else {
			C(L2_Total_Hits)
			if (GetPrefetchBit(1,req.DataAddr)==1) {
				C(L2_Total_Prefetch_Hits)
				UnSetPrefetchBit(1,req.DataAddr);
			}
		}
	}

}
