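"""
Multi-step (k-step) Sarsa with tile coding on the Mountain Car task.

NOTE: the learning loop currently performs Sarsa(lambda) updates with
eligibility traces; the step count ``k`` is defined but not yet used.
"""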

import mountaincar
from tiles import tiles, CollisionTable
from pylab import *

numTilings = 16          # number of overlapping tilings passed to tiles()
alpha = 0.5/numTilings   # step size, scaled down because each state activates one tile per tiling
gamma = 1                # no discounting
epsilon = 0              # exploration rate used by egreedy (0 => purely greedy)
k = 5                    # NOTE(review): not read anywhere in the visible code
# NOTE(review): the names below were used by the script but never defined;
# the values are assumptions -- confirm against the intended experiment.
lmbda = 0.9              # trace-decay rate applied to the eligibility traces
numTiles = 2048          # tile-table size; also the per-action stride into w
                         # (presumably >= 16 tilings * ~9*9 tiles each)
n = 3*numTiles           # one block of numTiles weights for each of the 3 actions

def mcarTiles(iht, s, a=None, readonly=False):
    """Return the active tile indices for the mountain-car state ``s``.

    Position and velocity are rescaled so their full ranges, [-1.2, 0.5] and
    [-0.07, 0.07], each span 8 units before tiling.

    Args:
        iht: tile table / memory object forwarded to ``tiles``.
        s: (position, velocity) pair.
        a: optional integer (e.g. an action) hashed in together with the
            state.  When None, no extra integers are passed, so the features
            are action-independent.
        readonly: forwarded unchanged to ``tiles``.

    Returns:
        Whatever ``tiles`` returns for the scaled state (the tile indices).
    """
    x, xdot = s
    scaled = (8*x/(0.5+1.2), 8*xdot/(0.07+0.07))
    # The original wrote ``(,a)``, which is a SyntaxError; the intent was a
    # one-element tuple.
    ints = () if a is None else (a,)
    # NOTE(review): argument order (table, numTilings, floats, ints, readonly)
    # -- confirm it matches the local ``tiles`` module.
    return tiles(iht, numTilings, scaled, ints, readonly)

def egreedy(F):
    """Choose an action for the active tiles ``F`` epsilon-greedily.

    With probability ``epsilon``, return a uniformly random action from
    {0, 1, 2}.  Otherwise return the index of the largest entry of
    ``Qs(F)``; ``argmax`` breaks ties toward the lowest index.
    """
    explore = random() < epsilon
    if not explore:
        return argmax(Qs(F))
    return randint(3)

def Q(F, a):
    """Return the linear action-value estimate for action ``a``.

    The value is the sum of the weights of the active tiles ``F``, taken
    from action ``a``'s slice of the global weight vector ``w`` (indices
    ``a*numTiles + f``).
    """
    offset = a*numTiles  # start of action a's block of weights; loop-invariant
    total = 0            # not named ``sum``, to avoid shadowing the builtin
    for f in F:
        total += w[offset + f]
    return total

def Qs(F):
    """Return [Q(F, 0), Q(F, 1), Q(F, 2)], the value of each action for tiles ``F``."""
    return [Q(F, action) for action in (0, 1, 2)]

numRuns = 10
numEpisodes = 500

runSum = 0.0
for run in range(numRuns):
    # Fresh, slightly negative random weights and a fresh tile table per run.
    w = -0.01*rand(n)
    # NOTE(review): the original called an undefined ``IHT``.  ``CollisionTable``
    # is the table class this file imports, constructed here with the same
    # (size, 'safe') arguments -- confirm against the local ``tiles`` module.
    iht = CollisionTable(numTiles, 'safe')
    returnSum = 0.0
    for episodeNum in range(numEpisodes):
        # Allocate new traces every episode.  The old ``e = zerovec`` aliased
        # one shared vector, so traces set in place leaked into later episodes.
        e = zeros(n)
        G = 0
        S = mountaincar.init()
        F = mcarTiles(iht, S)
        A = egreedy(F)
        for step in xrange(100000):         # hard cap on episode length
            # Replacing traces: set the active tiles of the taken action to 1.
            for f in F:
                e[A*numTiles + f] = 1
            R, Sprime = mountaincar.sample(S, A)
            G = G + R
            delta = R - Q(F, A)
            if Sprime is None:
                # Terminal transition: no bootstrapped successor value.
                w += alpha*delta*e
                break
            F = mcarTiles(iht, Sprime)
            Aprime = egreedy(F)
            delta += gamma * Q(F, Aprime)
            w += alpha*delta*e
            e = gamma*lmbda*e
            S, A = Sprime, Aprime
        # print("Episode: %s Return: %s" % (episodeNum, G))
        returnSum = returnSum + G
    print("Average return: %s" % (returnSum/numEpisodes))
    runSum += returnSum
    # NOTE(review): the original printed an undefined ``ctable``; presumably
    # the tile table's usage summary was intended.
    print(iht)
print("Overall average return: %s" % (runSum/numRuns/numEpisodes))

def writeF(filename):
    """Write the negated greedy action value over a 50x50 state grid.

    Row ``i`` uses position ``-1.2 + i*1.7/50`` and column ``j`` uses
    velocity ``-0.07 + j*0.14/50``.  Each entry is ``-max_a Q(s, a)``,
    written with ``repr`` and followed by a space, one grid row per line.

    Reads the module-level tile table ``iht`` and weights ``w`` left by
    the training loop.
    """
    steps = 50
    with open(filename, 'w') as fout:  # closed even if evaluation raises
        for i in range(steps):
            for j in range(steps):
                S = (-1.2+i*1.7/steps, -0.07+j*0.14/steps)
                # Bug fix: the original stored these features in ``Fs`` but
                # then evaluated the stale global ``F`` left over from
                # training, so every cell got the same value.
                F = mcarTiles(iht, S)
                height = -max(Qs(F))
                fout.write(repr(height) + ' ')
            fout.write('\n')

# Dump the value surface learned in the final run (w and iht from the last run).
writeF(filename='value')
