"""
Experimental implementation of the Natural Actor-Critic algorithm.
"""
import numpy as np
from .Agent import Agent
from rlpy.Tools import solveLinear, regularize

__credits__ = ["Alborz Geramifard", "Robert H. Klein", "Christoph Dann",
               "William Dabney", "Jonathan P. How"]
__author__ = "Christoph Dann"

class NaturalActorCritic(Agent):

    """
    The step-based Natural Actor-Critic algorithm, as described in
    Algorithm 1 of

    Peters, J. & Schaal, S. Natural Actor-Critic.
    Neurocomputing 71, 1180-1190 (2008).
    """

    # minimum allowed cosine between the current and the last gradient
    min_cos = np.cos(np.pi / 180.)
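
    # The agent accumulates the statistics of the linear system
    #     A = sum_t z_t (phi(s_t) - gamma * phi(s_{t+1}))^T,   b = sum_t z_t r_t,
    # where z_t is the eligibility trace over the combined features. Solving
    # A * [v; w] = b yields the value-function weights v and the natural
    # policy-gradient estimate w (Peters & Schaal, 2008).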

    def __init__(self, domain, policy, representation, forgetting_rate,
                 lambda_, learn_rate):
        """
        @param domain: the problem domain to investigate
        @param policy: parametrized stochastic policy with parameters
            policy.theta
        @param representation: function approximation used to approximate the
            value function
        @param forgetting_rate: specifies the decay of previous statistics
            after a policy update; 1 = forget all, 0 = forget none
        @param lambda_: eligibility-trace parameter lambda
        @param learn_rate: learning rate
        """

        self.samples_count = 0
        self.forgetting_rate = forgetting_rate
        # total feature dimension: value-function features plus the
        # dimension of the policy parameters (compatible features)
        self.n = representation.features_num + len(policy.theta)
        self.representation = representation
        self.lambda_ = lambda_
        self.learn_rate = learn_rate

        # statistics of the linear system and the eligibility trace
        self.b = np.zeros((self.n))
        self.A = np.zeros((self.n, self.n))
        self.buf_ = np.zeros((self.n, self.n))  # preallocated buffer for outer products
        self.z = np.zeros((self.n))  # eligibility trace

        super(NaturalActorCritic, self).__init__(domain, policy,
                                                 representation)

    def learn(self, s, p_actions, a, r, ns, np_actions, na, terminal):
        self.samples_count += 1

        # compute basis functions: value-function features of s, followed by
        # the score function (log-policy gradient) of the chosen action
        phi_s = np.zeros((self.n))
        phi_ns = np.zeros((self.n))
        k = self.representation.features_num
        phi_s[:k] = self.representation.phi(s, False)
        phi_s[k:] = self.policy.dlogpi(s, a)
        phi_ns[:k] = self.representation.phi(ns, terminal)

        # update the eligibility trace and the sufficient statistics
        self.z *= self.lambda_
        self.z += phi_s

        # rank-one update A += z (phi_s - gamma * phi_ns)^T; the outer product
        # is written into the preallocated buffer to avoid a temporary array
        self.A += np.einsum("i,j", self.z,
                            phi_s - self.domain.discount_factor * phi_ns,
                            out=self.buf_)
        self.b += self.z * r
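
        # The einsum call above is equivalent to np.outer; a standalone
        # sanity check (pure numpy, independent of rlpy):
        #
        #     import numpy as np
        #     z = np.array([1., 2.])
        #     d = np.array([3., 4.])
        #     buf = np.empty((2, 2))
        #     np.einsum("i,j", z, d, out=buf)
        #     assert np.allclose(buf, np.outer(z, d))
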
        if terminal:
            self.z[:] = 0.
        self.logger.debug("Statistics updated")

        # solve the (regularized) linear system A * param = b
        A = regularize(self.A)
        param, time = solveLinear(A, self.b)
        # v = param[:k]  # parameters of the value-function representation
        w = param[k:]  # natural gradient estimate

        # update the policy along the natural gradient
        self.policy.theta = self.policy.theta + self.learn_rate * w
        self.last_w = w
        self.logger.debug(
            "Policy updated, norm of gradient {}".format(np.linalg.norm(w)))

        # partially forget the statistics after the policy update
        self.z *= 1. - self.forgetting_rate
        self.A *= 1. - self.forgetting_rate
        self.b *= 1. - self.forgetting_rate

        if terminal:
            self.episodeTerminated()
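
# ---------------------------------------------------------------------------
# A minimal usage sketch (illustration only, not part of the algorithm). It
# assumes rlpy's GridWorld domain, Tabular representation, and GibbsPolicy;
# constructor signatures vary between rlpy versions, so treat the names and
# arguments below as assumptions rather than a tested recipe:
#
#     from rlpy.Domains import GridWorld
#     from rlpy.Representations import Tabular
#     from rlpy.Policies import GibbsPolicy
#
#     domain = GridWorld()
#     representation = Tabular(domain)
#     policy = GibbsPolicy(representation)
#     agent = NaturalActorCritic(domain, policy, representation,
#                                forgetting_rate=0.3, lambda_=0.7,
#                                learn_rate=0.1)
# ---------------------------------------------------------------------------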