```#!/usr/bin/env python
import pdb
import numpy as np
#from eoldas_model import subset_vector_matrix
from eoldas_State import State
from eoldas_Lib import sortopt,sortlog
from eoldas_ParamStorage import ParamStorage
import resource
from eoldas_Spectral import Spectral
from eoldas_Files import writer
import os
from os.path import dirname

class Operator ( ParamStorage ):
"""
A generic class for the EOLDAS operator
"""
def __init__ ( self,options,general,parameter=None,\
datatype={'x':'x','y':'y'},\
state={'x':None,'y':None},\
sd={'x':None,'y':None},\
names={'x':None,'y':None},\
control={'x':None,'y':None},\
location={'x':None,'y':None},\
name=None,\
env=None,logger=None,\
"""

This sets up the data for an Operator

The main purpose is to load self.x and self.y
with associated metadata in self.x_meta and self.y_meta

The Operator is set up to calculate, among other things, as cost:

J = 0.5 (y - H(x))^T (C_y^-1 + C_H(x)^-1) (y - H(x))

where y might be an observation and H(x) is a modelled version
of that observation.

For example, for a prior Operator, we want:

J = 0.5 (x_prior - x)^T C_prior^-1 (x_prior - x)

so when we are solving for this, we load candidate state estimates
into x, with C_x^-1 = 0 and C_y^-1 = C_prior^-1, and prior terms
into y, with H(x) = I(x), an identity operation.

For some model e.g. x_model = A + B x, we have:

J = 0.5 (x_model - x)^T C_model^-1 (x_model - x)

For a linear model, x_model = A + B x so we have:

J = 0.5 (A + (B -I)x)^T C_model^-1 (A + (B - I)x)

so we load the vector A into y, with C_y^-1 = 0 and H(x) = (B-I)x
with C_x^-1 = C_model^-1

For an observation operator, we want:

J = 0.5 (y_obs - Y(x))^T C_obs^-1 (y_obs - Y(x))

so we load the observations into y,
and set H(x) = Y(x), the observation operator.

Note that we always load the new state estimate into the x term.

To summarise these cases wrt initial setup:

For the prior example:       y = x_prior, x = x_new, H(x) = Ix
For the model example:       y = A,       x = x_new, H(x) = (B-I)x
For the observation example: y = y_obs,   x = x_new, H(x) = Y(x)

The cost function derivative, the Jacobian wrt each of the model
state vector elements is, ignoring derivatives of C_H(x)^-1:

J' = dJ/dx = (y - H(x))^T (C_y^-1) H'(x)

where H'(x) = dH(x)/dx

The second order differential, the Hessian is a square matrix:

J'' = d^2J/(dx_idx_j) = H'(x)^T (C_y^-1) H'(x) +
(y - H(x))^T (C_y^-1) H''(x)

or if we ignore H''(x), simply:

J'' = d^2J/(dx_idx_j) = H'(x)^T (C_y^-1) H'(x)

We assume that a config file or command line parsing
has been done and the required information is set up in
options. This will include

x.names   :   state variable names
x.state   :   state mean
x.sd      :   state sd

and may include:

x.control     :   control variables, e.g.
x.location    :   location variables e.g. [time]
solve       :   Codes associated with solver options for the
model parameters.
An integer array of the same length as
names

as well as any other information that may be required,
generally available through setup.

The information in self.options is used to initialise the
x data from self.options.xstate and the y data from
self.options.ystate

The result of initialising the Operator then is to:
-- start logging
-- load any data into self.x.state and self.y.state
and possibly self.x.sd, self.y.sd
from various formats that state can be.
-- load location and control information
self.x.location
self.x.control (& y data)
self.x_meta.control etc

"""
if name == None:
import time
thistime = str(time.time())
name = type(self).__name__
name =  "%s.%s" % (name,thistime)
self.thisname = name

if type(datatype) != dict:
# maybe its a list
if type(datatype) == list:
datatype = dict(np.repeat(np.array(datatype),2)\
.reshape((len(datatype),2)))
elif type(datatype) == str:
datatype = {datatype:datatype}
else:
return
self.options = options
self.general = general

for (k,v) in this.iteritems():
self.general[k] = sortopt(self.general,k,v)

env     = env     or self.general.env
logfile = logfile or self.general.logfile
logdir  = logdir  or self.general.logdir
bounds = 'False'
self.logger = sortlog(self,logfile,logger,name=self.thisname,\
logdir=logdir,debug=True)
self.isLinear = False
#  set up x & y states
thisname = name
kk = datatype.keys()
kk.sort()
# need to make sure y values are before the x ones
# so sort & reverse
for t in kk[::-1]:
if parameter != None and t in parameter.dict():
t_meta = "%s_meta"%t
name = {t:thisname}
this = {'datatype':t,'name':name,\
'state':parameter[t].state,'sd':parameter[t].sd,\
'control':parameter[t_meta].control,\
'location':parameter[t_meta].location,\
'bounds':parameter[t_meta].bounds,\
'names':parameter[t_meta].state}
else:
name = {t:thisname}
this = {'datatype':t,'name':'Billy no mates %s'%t,\
'state':None,'sd':None,\
'control':None,'location':None,\
'bounds':'False',\
'names':['dummy']}
# tt is what we will store it as e.g. x_state
for (k,v) in this.iteritems():
#kk = "%s%s"%(t,k)
# if k exists in parameter, use that as v
if parameter != None and not t in parameter.dict() and\
k in parameter.options.dict():
v = parameter.options[k]
self.options[t][k] = \
sortopt(self.options[t],k,v)
self.options[t][k] = (t in eval(k) and eval(k)[t]) \
or self.options[t][k]
# need to tmp load it directly to options
if type(self.options[t][k]).__name__ != 'NoneType':
self.options[k] = self.options[t][k]
tt = "%s_state"%t
self.options.name = "%s-%s"%(self.options.name,t)
self.options.bounds = self.options[t].bounds
# Set up State
self[tt] = State(options,logger=logger,\
logdir=logdir,env=env,\
name=self.options.name,grid=self.general.grid)
# a flag you can use to see if 'something' is loaded
has_state = False
if (type(self.options[t]) == ParamStorage and \
'state' in self.options[t].dict().keys()) or\
(type(self.options[t]) == dict and \
state in self.options[t].keys()):
self[tt].state = self.options[t].state
try:
if self[tt].state.size == 1 and \
self[tt].state == None or self[tt].state == np.array(None):
#raise Exception('Failed to recognise state data ... check the log')
has_state = False
else:
has_state = True
except:
if self[tt].state.size > 1:
has_state = True
if has_state == True and self[tt].state.size == 1 and \
self[tt].state == np.array(None):
has_state = False

from_parameter = False
if parameter != None and t == 'x' and self[tt].state.size == \
parameter[tt].state.size:
# must have loaded from parameter
self.x_state.state = self.x.state
from_parameter = True

if from_parameter == False and has_state == False:
# then we'd better set up a grid from the defaults
if not self[tt].options[t].apply_grid:
self[tt].options[t].apply_grid = True
State.regrid(self[tt])
has_state = True

#if self[tt].Data.state == np.array(None):
#    has_state = False
if has_state:
self.options[t].sd = sortopt(self.options[t],'sd',np.zeros_like(self[tt].Data.state))
try:
self[tt].Data.sd = sortopt(self[tt].Data,'sd',\
np.array([float(v) for v in np.array(self.options[t].sd).flatten().tolist()]))
except:
raise Exception('Error interpreting data for %s: problem reading or parsing'%self.thisname)
self[tt].Name.is_grid = False
if has_state and 'apply_grid' in self[tt].options[t].dict() and \
self[tt].options[t].apply_grid:
if 'gridder' in self[tt].Data.dict():
# if gridded, load the grid data
self[tt].Data.state = self[tt].Data.gridder.grid
self[tt].Data.sd = self[tt].Data.gridder.sdgrid
self[tt].Name.is_grid = True
if from_parameter:
self[tt].Name.is_grid = True

if from_parameter or not self[tt].Name.is_grid:
# override with options data if there
self[tt].Data.sd = \
sortopt(self.options[t],'sd',self[tt].Data.sd)
# check the size of sd
sd1 = np.array(self[tt].Data.sd)
r = self[tt].Data.state.astype(float)
# ensure sd is float
self[tt].Data.sd = np.array(self[tt].Data.sd)
shape = self[tt].Data.sd.shape
sd = np.zeros_like(self[tt].Data.sd).astype(float).flatten()
for (i,v) in enumerate(self[tt].Data.sd.flatten()):
sd[i] = float(v)
self[tt].Data.sd = sd.reshape(shape)
sd1 = np.array(self[tt].Data.sd)
if has_state and (not self[tt].Name.is_grid or from_parameter)\
and sd1.size != r.size:
# try to match up sd with data shape & size
if not self[tt].Name.is_grid:
try:
(n_samples,n_states) = r.shape
except:
if r.size == 0:
if 'logger' in self.dict():
self.logger.error("no data points")
import sys
sys.exit(-1)
n_samples = 1
n_states  = r.size
else:
n_states = r.shape[-1]
n_samples = r.size/n_states

if parameter:
n = parameter.options["%ssd"%t].size
else:
n = r[:,0].size
if sd1.size == 1:
self[tt].Data.sd = np.zeros_like(r) + sd1.flatten()[0]
elif n != n_states:
if 'logger' in self.dict():
self.logger.info("Warning: some isses with state" + \
" and sd data. They dont look the same size " +\
"(%d vs %d) ... tiling accordingly"%(n,n_states))
# have to be careful here
#
self[tt].Data.sd = \
np.tile(np.array(sd1),n_samples)\
.reshape((n_samples,sd1.shape[0]))
if from_parameter:
self[tt].Data.sd = self[tt].Data.sd.reshape(r.shape)
# sort inverse correlation matrix if its from sd

if has_state:
self[tt].Data.C1 = self[tt].Data.sd*0.
ww = np.where(self[tt].Data.sd >0)
self[tt].Data.C1[ww] = \
1./(self[tt].Data.sd[ww] * self[tt].Data.sd[ww])

self[t]           = self[tt].Data   #  .to_dict()
self["%s_meta"%t] = self[tt].Name   #  .to_dict()
# If there is a y datatype, load spectral info

kk = datatype.keys()
kk.sort()
for t in kk[::-1]:
if t == 'y':
self.y_meta.spectral = Spectral(self,name=self.thisname)
self.linear = ParamStorage()
self.linear = ParamStorage()
if 'y' in self.dict():
self.linear.H = np.zeros(self.y.state.shape)
else:
self.linear.H = np.zeros(self.x.state.shape)

self.epsilon = sortopt(self.general,'epsilon',1e-15)
# a hook to allow users not to have to write a whole new __init__
self.J_prime_approx = None
self.doPlot = sortopt(self.general,'doplot',False)
self.plotMod = sortopt(self.general,'plotmod',10)
self.plotMovie = sortopt(self.general,'plotmovie',False)
self.showPlot = sortopt(self.general,'showplot',False)
self.plotMovieFrame = 0
# check to see if there is anything worth plotting
# in this operator
# this is determined by whether a filename is
# specified for output
try:
self.plot_filename = \
sortopt(self.y_state.options.y.result,\
'filename',self.thisname) + '.plot'
except:
try:
self.plot_filename = \
sortopt(self.x_state.options.x.result,\
'filename',self.thisname) + '.plot'
except:
try:
self.plot_filename = \
sortopt(self.options.result,\
'filename',self.thisname) + '.plot'
except:
self.doPlot = False
if 'plot_filename' in self.dict():
outdir = os.path.dirname(self.plot_filename)
try:
os.makedirs(outdir)
except:
pass
self.Hcount = 0

return

'''
A method that gets accessed after data
are loaded. You should think of this as a hook
so that you don't have to write a whole new __init__
method.

In this default method, we use it to declare that
the operator is linear (self.isLinear = True)
and also to set up H_prime for an identity operator.

'''
self.isLinear = True
try:
self.linear.H_prime = np.eye(self.y.state.size)
except:
self.linear.H_prime = np.eye(self.x.state.size)

'''
A method that gets accessed before data
are loaded. You should think of this as a hook
so that you don't have to write a whole new __init__
method.

'''
# no approximate method for J_prime is appropriate here
# as J_prime is trivial for an identity operator
self.J_prime_approx = None

def H(self, x):
    """
    Default operator: the identity, H(x) = x.

    The supplied `x` is stored as `self.linear.H` and the stored
    value is then handed back to the caller.
    """
    store = self.linear
    store.H = x
    return store.H

def H_prime(self, x):
    '''
    The default differential operator dH(x)/dx.

    If the operator is flagged linear (self.isLinear) and
    self.linear already holds an 'H_prime' entry, that stored value
    is returned unchanged.

    Otherwise the derivative is approximated by finite differences.
    `x` is viewed as (xloc, xnparam), where xnparam is the last
    dimension of self.x.state. Each parameter column in turn is
    perturbed by self.delta at all locations at once, self.H is
    evaluated, and the change is divided by the step actually taken.
    A forward step that would exceed self.xmax is replaced by a
    backward step; a result below self.xmin is clamped to self.xmin.
    Locations where the effective step is zero are left at 0.

    Parameters:
        x : flat numpy array whose size is a multiple of xnparam

    Returns:
        2-D array of shape (xnparam*xloc, xloc*nb), where nb is the
        last dimension of self.y.state (1 if that cannot be read).
        The result is also stored in self.linear.H_prime, and the
        unperturbed H value is stored in self.Hx.

    Note: this is a full matrix and can be large for large problems.
    '''
    if self.isLinear and 'H_prime' in self.linear.dict():
        return self.linear.H_prime
    xshape = self.x.state.shape
    xnparam = xshape[-1]
    # floor division keeps xloc an int under both Python 2 and
    # Python 3 division rules (it is used as an array dimension)
    xloc = x.size // xnparam
    try:
        nb = self.y.state.shape[-1]
    except Exception:
        # no usable y state: treat the output as single-band
        nb = 1
    # this is to represent the x state
    xx = x.reshape((xloc, xnparam))
    # this is the output: xloc x xnparam x xloc x nb
    out = np.zeros((xloc, xnparam, xloc, nb))
    if self.isLinear:
        self.Hx = self.linear.H
    else:
        self.Hx = self.H(x)
    # we add delta over all locations at the same time here
    for i in range(xnparam):
        x1 = xx.copy()
        delta = self.delta[:, i].copy()
        x1[:, i] = xx[:, i] + delta
        # step backwards where the forward step leaves the bounds
        ww = np.where(x1[:, i] > self.xmax[:, i])[0]
        x1[ww, i] = xx[ww, i] - delta[ww]
        # and clamp anything that is now below the lower bound
        ww = np.where(x1[:, i] < self.xmin[:, i])[0]
        x1[ww, i] = self.xmin[ww, i]
        # the step that was really applied
        delta = (x1 - xx)[:, i]
        www = np.where(delta != 0)[0]
        diff = self.H(x1.flatten()) - self.Hx
        for j in range(nb):
            try:
                column = diff[www, j]
            except IndexError:
                # H returned a flat vector: view it with the x shape
                column = diff.reshape(xshape)[www, j]
            out[www, i, www, j] = column.flatten() / delta[www]
    # checked: this is the correct way to stack these
    self.linear.H_prime = out.reshape([xnparam * xloc, xloc * nb])
    return self.linear.H_prime

'''
Utility to return a full state vector representation
from one that may be only partial

If sum is True, then if multiple items appear per location
they are summed. That is a bit fiddly, but is needed

If sum is False, then only the first instance we come across

which will normally be the hessian
'''
#if self.thisname == 'eoldas.solver.eoldas.solver-obs': # and self.x.state[0].sum() == 1:
if M == True:
n = fullx.flatten().shape[0]
ww = np.where(MM)[0]
count = 0
stmp = np.array(smallx).reshape((self.x.state.shape)*2)
# ww is the list of time slots/params that will need
# step over each
if 'xlookup' in self.dict():
for i in self.xlookup[count]:
jc = 0
jc += 1
if sum:
out[ns,j,:,:] +=  this
else:
out[ns,j,:,:] =  this
else:
out = out.reshape(n,n)
count += 1
return out.reshape(n,n)
out = fullx*0.
if 'xlookup' in self.dict():
stmp = smallx.reshape(self.x.state.shape)
if sum:
# This kept me puzzled for ages ...
# Its when when have multiple terms
# per observation period
# Lewis, UCL 23 Aug 2011
stmpt = stmp.T
# this reduces smallx to the size of the y state
smallx = stmp[self.xunlookup].T
for i in xrange(len(self.xunlookup)):
for j in xrange(1,len(self.xlookup[i])):
smallx[...,i] = smallx[...,i] + \
stmpt[...,self.xlookup[i][j]]
smallx = smallx.T
else:
smallx = stmp[self.xunlookup]
return out
raise Exception(\

'''
Utility to load the state vector required here
(self.input_x) from fullx

'''
from eoldas_Lib import isfloat
if 'xlookup' in self.dict():
self.xlookup = np.atleast_1d(self.xlookup)
for i in xrange(len(self.xlookup)):
self.x.state[self.xlookup[i]] = state[i]
else:
self.x.state = state
return True
if not 'x_meta' in self.dict():
self.x_meta = ParamStorage()
self.x_meta.state = self.x_state._state.name.state
self.x_meta.location = self.x_state._state.name.location
try:
self.x_meta.is_grid = self.options.x.apply_grid
except:
self.x_meta.is_grid = self.options.x.apply_grid = False
if not 'x' in self.dict():
self.x = ParamStorage()
to_names = self.x_meta.state
from_names = fullx.x_meta.state
to_loc = self.x_meta.location
from_loc = fullx.x_meta.location
to_names = np.atleast_1d(to_names)
n_param = len(to_names)
#
# this is used in case a full size array is passed
if not (from_loc == to_loc).all():
# cant load as location bases not the same
self.logger.error("Cannot load state vector as location bases are different")
return False
if self.x_meta.is_grid and fullx.x_meta.is_grid:
# a straight load but may be different parameters
self.x_meta.state = np.atleast_1d(self.x_meta.state)
fullx.x_meta.state = np.atleast_1d(fullx.x_meta.state)
if np.in1d(self.x_meta.state,fullx.x_meta.state).all() and \
len(self.x_meta.state) == len(fullx.x_meta.state):
self.x.state = fullx.x.state
else: # downsizing the gridded dataset
xshape = np.array(fullx.x.state.shape)
xshape[-1] = self.x_meta.state.size
xshape = tuple(xshape)
self.x.state = np.zeros(xshape)
for i in xrange(self.x_meta.state.size):
ww = np.where(self.x_meta.state[i] == fullx.x_meta.state)[0]
self.x.state[...,i] = fullx.x.state[...,ww].reshape(\
self.x.state[...,i].shape)
self.novar = True
self.logger.info('There appears to be no valid data ... ignoring this Operator')
return False

elif fullx.x_meta.is_grid:
# so we are coming from a grid to a non grid
# which is the most likely case
# we only want to load those samples where
# location matches
# now set to_loc to the actual location data
from_loc = np.atleast_1d(from_loc)
nd = len(from_loc)
# we need to get the qlocation info from somewhere
if not 'qlocation' in self.x and \
not 'qlocation' in self.x_state:
try:
for i in ['qlocation','location']:
self.x[i] = self.y[i]
self.x_meta[i] = self.y_meta[i]
except:
self.logger.info('incomplete x state definition .. a prior? copying from parameter')
self.x.qlocation = fullx.x_state.ungridder.qlocation.copy()
self.x.location = fullx.x_state.ungridder.location.copy()
# so probably need to fix y
for i in ['qlocation','location']:
self.x_meta[i] = fullx.x_meta[i]
try:
for i in ['qlocation','location']:
self.y_meta[i] = self.x_meta[i]
self.y[i] = self.x[i]
if self.y.state.size == self.y_meta.state.size:
# then we have only a partial load of y state
# probably because we had no location info before
nloc = self.x.location.shape[0]
self.y.state = np.tile(self.y.state,nloc).reshape(nloc,self.y_meta.state.size)

if self.y.sd.size == self.y_meta.state.size:
# then we have only a partial load of y state
# probably because we had no location info before
nloc = self.x.location.shape[0]
self.y.sd = np.tile(self.y.sd,nloc).reshape(nloc,self.y_meta.state.size)
if self.y.C1.size == self.y_meta.state.size:
# then we have only a partial load of y state
# probably because we had no location info before
nloc = self.x.location.shape[0]
self.y.C1 = np.tile(self.y.C1,nloc).reshape(nloc,self.y_meta.state.size)

except:
pass
#self.x_meta. fullx.x_meta
#raise Exception('Illegal x state definition')

try:
to_loc = self.x.qlocation
#- fullx.x_meta.gridder.grid[0:nd]
except:
to_loc = self.x_state.qlocation
#- fullx.x_meta.gridder.grid[0:nd]
self.x.qlocation = self.x_state.qlocation

try:
except:
# x state is not yet set, so we must work out itsd
# shape by other means
self.x.qlocation = np.atleast_1d(self.x.qlocation)
,self.x_meta.state.size)

self.x.fullstate = self.x.state
to_loc = np.atleast_1d(to_loc)
for i in xrange(len(to_loc)):
for j in xrange(n_param):
n = np.where(to_names[j] == from_names)[0][0]
tup = tuple(np.append(to_loc[i],n))
#tuple(np.concatenate([to_loc[i],[n]]))
self.x.state[i,j] = fullx.x.state[tup]

state = m.reshape(m.size/n_param,n_param)

location = m.reshape(m.size/self.x.location.shape[1],\
self.x.location.shape[1])
qlocation = m.reshape(m.size/self.x.location.shape[1],\
self.x.location.shape[1])

# The length of the x vector is the same as the y
# but y may contain multiple observations
# for a given location
# So, we form a list xlookup that contains the
# indices of the multiple y terms for x
# If there is no y vector then they have one element
# each, but more generally they will have
# multiple y terms
xlookup = [[] for i in xrange((m.size/self.x.location.shape[1]))]

for i in xrange(len(np.atleast_1d(xlookup))):
thisloc = str(qlocation[i,:])
for j in np.where(self.x.qlocation[:,0] \
== qlocation[i,0])[0]:
if str(self.x.qlocation[j]) == thisloc:
xlookup[i].append(j)

self.xlookup = xlookup
# xunlookup is the opposite of xlookup but only
# gives the index of the first element on the list
# for each location
self.xunlookup = np.array([i[0] for i in self.xlookup])
# now expand state to self.x.state
for i in xrange(len(np.atleast_1d(self.xlookup))):
self.xlookup[i] = np.array(self.xlookup[i])
self.x.state[self.xlookup[i]] = state[i]

# make sure we store some bounds info in a convenient manner
# required eg for finite differerence approximations
self.epsilon = sortopt(self.general,'epsilon',1e-15)
try:
bounds = self.x_meta.bounds
except:
try:
bounds = self.options.bounds
except:
raise Exception('Bounds information not defined')
if self.x_meta.is_grid and fullx.x_meta.is_grid:
ww = where_grid_data
self.logger.debug('Coming from a gridded dataset to a gridded dataset')
elif fullx.x_meta.is_grid:
self.logger.debug('Coming from a gridded dataset to a non gridded dataset')
ww = tuple(to_loc[0])
else:
self.logger.debug('Coming from a non gridded dataset to a non gridded dataset')
# it should be an int or float
to_loc = np.atleast_1d(to_loc)
ww = to_loc[0]
try:
ok,dummy = isfloat(ww)
if not ok:
# but it might not be for some strange reason
# shouldnt be here, but give it a go anyway
self.logger.error('Having trouble interpreting dataset')
self.logger.error(' ... have ended up somewhere unexpected, but will keep trying')
ww = tuple([0]*len(to_loc))
except:
ww = tuple(ww*0)
self.logger.error('Issues with this dataset for this operator, but will keep trying')
ww = tuple([0]*len(to_loc))
if 'x_state' in self.dict():
# are there any data here?
if nn == 0:
self.novar = True
self.logger.info('There appears to be no valid data ... ignoring this Operator')
return False

dd = self.epsilon*(bounds[:,1]-\
bounds[:,0])

if self.delta.size == 0:
# no data here ..?
self.novar = True
self.logger.info('There appears to be no valid data ... ignoring this Operator')
self.xranger = self.delta
return False
if self.delta.size == self.x.state.shape[-1]:
ww = tuple(np.array(self.x.state.shape)[:-1])
self.delta = np.tile(self.delta,ww).reshape(self.x.state.shape)
self.xmin = np.tile(self.xmin,ww).reshape(self.x.state.shape)
self.xmax = np.tile(self.xmax,ww).reshape(self.x.state.shape)
try:
self.xranger = self.xmax[0] - self.xmin[0]
except:
raise Exception('An error occurred setting eoldas bounds in the Operator')
return True

def H_prime_prime(self, x):
    '''
    Second derivative of the operator, d^2H(x)/dxi dxj.

    Not yet implemented in general: a value is only returned when
    the operator is flagged linear (self.isLinear) and one has
    already been stored as self.linear.H_prime_prime.

    If there is no 'linear' store at all, an empty ParamStorage is
    created as self.linear.

    Parameters:
        x : state vector (not used by this default implementation)

    Returns:
        self.linear.H_prime_prime if available, otherwise None
    '''
    if 'linear' not in self.dict():
        self.linear = ParamStorage()
        return None
    # same style of guard as H_prime uses for its stored matrix
    if self.isLinear and 'H_prime_prime' in self.linear.dict():
        return self.linear.H_prime_prime
    return None

def getxy(self, diag=False):
    '''
    Return the x and y states with their uncertainty terms.

    Parameters:
        diag : if False (default) Cx1 and Cy1 are returned
               flattened. If True they are passed through np.diag,
               the non-zero entries of Cy1 are converted with
               1/sqrt (inverse variance to sd), and either term is
               tiled yshape[0] times if it is smaller than x.

    Returns:
        (x, Cx1, xshape, y, Cy1, yshape) where
        x, y    : self.x.state / self.y.state as flat float arrays
                  (y is zeros like x if there is no y)
        Cx1     : taken from self.x.sd (so an sd, despite the name)
        Cy1     : taken from self.y.C1, or Cx1 if there is no y
        xshape, yshape : the original state shapes

    If Cx1 and x differ in size and self.xlookup exists, Cx1 is
    expanded to the x shape: row i of Cx1 is copied to every row
    index listed in self.xlookup[i]. If that fails, Cx1 is left as
    it was at the point of failure.
    '''
    xshape = self.x.state.shape
    x = self.x.state.astype(float).flatten()
    if 'x' in self.dict():
        # we mean sd here even though its called Cx1
        Cx1 = self.x.sd.astype(float)
    else:
        Cx1 = 0. * x
    if np.asarray(Cx1).size != x.size and 'xlookup' in self.dict():
        # then may have to unpack to a bigger version of Cx1
        try:
            Cx1 = Cx1.reshape(len(self.xlookup), xshape[-1])
            # explicit float dtype: an integer-typed state must not
            # truncate the sd values copied in below
            expanded = np.zeros(xshape, dtype=float)
            for i, rows in enumerate(self.xlookup):
                expanded[rows, :] = Cx1[i, :]
            Cx1 = expanded
        except (ValueError, IndexError, TypeError):
            # best effort only: keep whatever Cx1 we have
            pass
    if 'y' in self.dict():
        y = self.y.state.astype(float).flatten()
        yshape = self.y.state.shape
        Cy1 = self.y.C1.astype(float)
    else:
        y = 0. * x
        Cy1 = Cx1
        yshape = xshape
    if diag:
        # if diag flag is set, return sd, not C^-1
        try:
            Cx1 = np.diag(Cx1)
        except ValueError:
            pass
        try:
            # np.diag can return a read-only view; copy so that the
            # in-place conversion below really happens
            Cy1 = np.diag(Cy1).copy()
            ww = np.where(Cy1)
            Cy1[ww] = 1. / np.sqrt(Cy1[ww])
        except ValueError:
            pass
        if Cx1.size < x.size:
            Cx1 = np.tile(Cx1, yshape[0])
        if Cy1.size < x.size:
            Cy1 = np.tile(Cy1, yshape[0])
    else:
        Cx1 = Cx1.flatten()
        Cy1 = Cy1.flatten()
    return x, Cx1, xshape, y, Cy1, yshape

memory = lambda self : self.logger.info("Memory: %.3f GB"%\
float(resource.getpagesize()))/(1024**3)))

def write(self, filename, dataset, fmt='pickle'):
    '''
    A write method for outputting state variables

    e.g. self.write('xxx.dat','x',fmt='PARAMETERS')

    Will also attempt basic plotting of state and observations
    (self.plot(ploterr=True)); a plotting failure is logged and
    does not prevent the write.

    Parameters:
        filename : passed through to the writer
        dataset  : state prefix; the state written is
                   self['<dataset>_state']
        fmt      : output format, passed through to the writer

    Returns:
        None. If '<dataset>_state' is not in this operator an error
        is logged and nothing is written.
    '''
    try:
        self.plot(ploterr=True)
    except Exception as err:
        # plotting is a best-effort extra: never mind, but say so
        self.logger.info("... plotting failed: %s" % str(err))
    state = '%s_state' % dataset
    if state not in self.dict():
        self.logger.error(
            'Cannot write state %s ... not in operator' % dataset)
        return
    writer(self[state], filename, dataset, fmt=fmt)

def cost(self):
    '''
    Return the cost J and its derivative J'.

    If the Operator is a super operator it may contain other
    operators in self.operators, in which case cost() returns the
    sum of J and J_prime over all sub operators. Note that in the
    case of a super operator, we do not count self cost.

    If not, then just calculate the cost in this operator.

    For each sub operator a plot is attempted every op.plotMod
    calls when op.doPlot is set, and op.Hcount is incremented; the
    same is then done for the super operator with plot(noy=True).
    Plot failures are logged and otherwise ignored.

    Returns:
        (J, J_prime) with J_prime flattened
    '''
    if 'operators' not in self.dict():
        J, J_prime = self.J_prime()
        self.x_state.logger.info("J = %f" % J)
        return J, J_prime.flatten()

    for i, op in enumerate(self.operators):
        # NOTE(review): the comment that used to sit here said the
        # *full* state in self.x.state is loaded into the form this
        # operator needs, but no such call is made in this method
        # -- confirm that the caller has done it.
        Jtmp, J_prime_tmp = op.J_prime()

        op.x_state.logger.info("     J   = %f" % Jtmp)
        if 'debug' in self.dict() and self.debug > 1:
            op.x_state.logger.info("     J'  = %s" % str(J_prime_tmp))
            op.x_state.logger.info("     x   = %s" % str(self.x.state))
        if op.doPlot and op.Hcount % op.plotMod == 0:
            try:
                op.plot()
            except Exception:
                # never mind but inform the user
                op.logger.info("... didn't make it")
        op.Hcount += 1

        if i == 0:
            J = Jtmp
            J_prime = J_prime_tmp
        else:
            # build new objects rather than using += so that the
            # array returned by the first operator is not modified
            J = J + Jtmp
            J_prime = J_prime + J_prime_tmp
    if self.doPlot and self.Hcount % self.plotMod == 0:
        try:
            self.plot(noy=True)
        except Exception:
            # never mind but inform the user
            self.logger.info("... didn't make it all")
    self.Hcount += 1

    return J, J_prime.flatten()

def invtrans(self, x, lim=None):
    '''
    Apply the inverse transform to the state vector
    if it is defined

    Parameters:
        x   : numpy array; the last axis indexes the state names in
              self.x_meta.state. The input is not modified.
        lim : optional pair (lower, upper) of arrays with the same
              number of elements as x. If given, x is clipped to
              these limits before any transform is applied.
              None or an empty sequence means no clipping.

    Returns:
        array with the shape of x. If self.invtransform is missing
        or None this is just the (clipped) copy of x. Otherwise,
        for every entry of self.invtransform that is not itself one
        of the names in self.x_meta.state, the name
        self.x_meta.state[i] in that entry is replaced by the
        current values and the resulting expression is evaluated to
        give the output for index i.
    '''
    # lots of reasons not to do this ...
    xshape = x.shape
    x = x.flatten()
    if lim is not None and len(lim) > 0:
        ll = lim[0].flatten()
        lu = lim[1].flatten()
        ww = np.where(x > lu)
        x[ww] = lu[ww]
        ww = np.where(x < ll)
        x[ww] = ll[ww]
    try:
        invtransform = self.invtransform
    except (AttributeError, KeyError):
        invtransform = None
    # identity test: '== None' on an array is an elementwise compare
    if invtransform is None:
        return x.reshape(xshape)
    xout = x.copy().reshape(xshape)
    # np.isin where available, np.in1d on older numpy: the two agree
    # for the 1-D name lists used here
    member = getattr(np, 'isin', None) or getattr(np, 'in1d')
    todo = member(invtransform, self.x_meta.state)
    for i in np.where(~todo)[0]:
        this = invtransform[i].replace(self.x_meta.state[i], 'xout[...,i]')
        # NOTE(review): eval of a transform expression string; the
        # names 'xout' and 'i' must stay as they are because the
        # expression refers to them. Only safe if the transform
        # strings come from a trusted source -- confirm.
        xout[..., i] = eval(this)
    return xout

def plot(self,noy=False,ploterr=False):
'''
Form a plot of the  x & y state
'''
import pylab
if not 'y_state' in self.dict() and noy == False:
self.logger.error('Failed attempt to plot y data: non existent y state')
return
x1,Cx1,xshape,y,Cy1,yshape = self.getxy(diag=True)
# try various places to bet bounds info
have_y = noy
if ploterr:
self.logger.info("plotting uncertainty info")
try:
x_min = self.xmin[0]
x_max = self.xmax[0]
except:
x_max = np.zeros(xshape[-1])
x_min = np.zeros(xshape[-1])
try:
for i in xrange(xshape[-1]):
x_max[i] = self.x_meta.bounds[i,1]
x_min[i] = self.x_meta.bounds[i,0]
except:
for i in xrange(xshape[-1]):
x_max[i] = max(x[...,i])
x_min[i] = max(x[...,i])

lu = np.tile(x_max,xshape[:-1])
ll = np.tile(x_min,xshape[:-1])
lim = (ll,lu)
# are any transforms required ?
x = self.invtrans(x1.reshape(xshape),lim=lim).flatten()

if ploterr:
try:
Cx1a = np.array(np.matrix(Cx1).diagonal()).flatten().reshape(xshape)
except:
Cx1a = Cx1.reshape(xshape)
try:
Cy1a = np.array(np.matrix(Cy1).diagonal()).flatten().reshape(yshape)
except:
try:
Cy1a = Cy1.reshape(yshape)
except:
Cy1a = np.array(list(Cy1)*np.array(yshape)[...,-1])
try:
dx1 = self.invtrans((x1+Cx1a.reshape(x1.shape)*1.96).reshape(xshape),lim=lim)
dx2 = self.invtrans((x1-Cx1a.reshape(x1.shape)*1.96).reshape(xshape),lim=lim)
except:
dx1 = dx2 = x1*0.
# dx1 should be upper bound of x
tmp = dx1.copy()
ww = np.where(dx2 > dx1)
tmp[ww] = dx2[ww]
dx2[ww] = dx1[ww]
dx1 = tmp
ddx1 = dx1 - x.reshape(dx1.shape)
ddx2 = x.reshape(dx1.shape) - dx2
have_y = True
try:
# we dont have y error yet
dy1 = self.y.state + self.y.sd*1.96 #(y + Cy1a.reshape(y.shape)*1.96).reshape(yshape)*0
dy2 = self.y.state - self.y.sd*1.96  #(y - Cy1a.reshape(y.shape)*1.96).reshape(yshape)*0
ddy2 = ddy1 = self.y.sd*1.96
try:
fddy2 = fddy1 = self.fwdSd*1.96
except:
fddy2 = fddy1 = 0
except:
have_y = False
dy1 = 0.
dy2 = 0.
ddy2 = ddy1 = 0.
fddy2 = fddy1 = 0
self.logger.info('starting plots ...')
pylab.rcParams.update(\
{'axes.labelsize':5,\
'text.fontsize': 6,
'xtick.labelsize': 6,\
'ytick.labelsize': 6})
fig = pylab.figure(1)
axisNum = 0
pylab.clf()
x = x.reshape(xshape)
y = y.reshape(yshape)
label = self.x_meta.location[0]
isgrid = True
try:
(locations,qlocations,state,sd) = \
self.x_state.ungrid(x,Cy1)
location = locations
except:
isgrid = False
try:
location = self.x.location[:,0]
except:
try:
if self.y.location[:,0].size == x[:,0].size:
location = self.y.location[:,0]
else:
label = label + ' index'
location = \
np.arange(x[:,0].size)\
*self.x_meta.qlocation[0][2]\
+ self.x_meta.qlocation[0][0]
except:
label = self.x_meta.location[0] + ' index'
location = np.arange(x[:,0].size)\
*self.x_meta.qlocation[0][2]\
+ self.x_meta.qlocation[0][0]

x_min = self.invtrans(x_min)
x_max = self.invtrans(x_max)
tmp = x_max.copy()
ww = np.where(x_min > x_max)
tmp[ww] = x_min[ww]
x_min[ww] = x_max[ww]
x_max = tmp

nn = xshape[-1]
for i in xrange(xshape[-1]):
axisNum += 1
ax = pylab.subplot(nn, 1, axisNum)
maxer = x_max[i]
if x_max[i] == x_min[i]:
maxer = x_max[i] + 0.001
#ax.set_ylim(x_min[i],maxer)
#try:
#	ax.set_xlim(np.min(location),np.max(location))
#except:
#pass
ax2 = ax.twinx()
yprops = dict(rotation=90,\
horizontalalignment='right',\
verticalalignment='center')
ax2.set_ylabel(self.x_meta.state[i],**yprops)
# some transform bug in plotting error bars
# leave for now
if not have_y  and (True or isgrid or True): # or x.shape[0] > 100:
location = location.flatten()
try:
if ploterr:
ax.fill_between(location,y1=dx2[:,i],y2=dx1[:,i],facecolor='0.8')
ax.plot(location,x[:,i],'r')
except:
xx = np.arange(tuple(np.array(xshape)[:-1])).flatten()
if ploterr:
ax.fill_between(xx,y1=dx2[:,i],y2=dx1[:,i],facecolor='0.8')
ax.plot(xx,x[:,i],'r')
else:
try:
if ploterr:
ax.errorbar(location, x[:,i], yerr=[ddx2[:,i],ddx1[:,i]], fmt='ro')
ax.plot(location,x[:,i],'ro')
except:
xx = np.arange(tuple(np.array(xshape)[:-1])).flatten()
if ploterr:
ax.errorbar(xx,x[:,i], yerr=[dx2[:,i],dx1[:,i]], fmt='ro')
ax.plot(xx,x[:,i],'ro')
#ax2 = ax.twinx()
# set nice ticks
ax.yaxis.set_major_locator(pylab.MaxNLocator(3))
#ax2.yaxis.set_major_locator(pylab.MaxNLocator(2))
ax2.set_yticks(ax.get_yticks())
ax2.set_yticklabels([])
pylab.xticks()
ax.xaxis.set_major_locator(pylab.MaxNLocator(9))
if i == xshape[-1]-1:
ax.set_xlabel(label)
else:
ax.set_xticklabels([])
ax2.set_yticks(ax.get_yticks())
ax2.set_yticklabels([])
#for lab in ax.get_xticklabels():
#    lab.set_rotation(0)
#ax.set_ylim(ax2.get_ylim())
if self.showPlot:
pylab.show()
else:
if self.plotMovie and 'plot_filename' in self.dict():
fig.savefig(self.plot_filename + '.%08d.x.png'%self.plotMovieFrame)
fig.savefig(self.plot_filename + '.x.png')

if noy ==  True:
if not self.doPlot:
if self.plotMovie:
self.logger.info("... written plots to " + \
self.plot_filename + ".%08d.x.png"%self.plotMovieFrame)
self.plotMovieFrame += 1
self.logger.info("... written plots to " + \
self.plot_filename + ".x.png")
return
# plot the other way around?
axisNum = 0
nn1 = int(np.sqrt(yshape[1]))
nn2 = yshape[1]/nn1 + 1
fig = pylab.figure(3)
pylab.clf()

# fix issues encounterd with 1D arrays
Hx = np.matrix(self.Hx.reshape(self.y.state.shape))
hshape = np.array(Hx.shape)
if hshape[1] == yshape[0] and hshape[0] == yshape[1]:
Hx = Hx.T
Hx = np.array(Hx)
# y plotting
# This will fail if there is no Hx, so wrap try around
ymax1 = np.max(y)
ymax2 = np.max(Hx)
try:
except:
# a spectral plot
#if ploterr:
for i in xrange(yshape[1]):
axisNum += 1
ymax1 = np.max(y)
ymax2 = np.max(Hx)
ax = pylab.subplot(nn1,nn2, axisNum)
#ax.set_ylim(0.,np.max([ymax1,ymax2]))
#try:
#    ax.set_xlim(np.min(location),np.max(location))
#except:
#    pass

ax2 = ax.twinx()
yprops = dict(rotation=90,\
horizontalalignment='right',\
verticalalignment='center')
ax2.set_ylabel(self.y_meta.state[i],**yprops)

# some transform bug in plotting error bars
# leave for now
if isgrid or y.shape[0] > 100:
try:
if False and ploterr:
ax.fill_between(location,y1=dy2[:,i],y2=dy1[:,i],facecolor='0.8')
ax.plot(location,Hx[:,i],'r')
ax.plot(location,y[:,i],'g,')
except:
try:
xx = location
except:
xx = np.arange(np.array(yshape)[:-1]).flatten()
if False and ploterr:
try:
ax.fill_between(xx,y1=dy2[:,i],y2=dy1[:,i],facecolor='0.8')
except:
ax.fill_between(xx,y1=dy2,y2=dy1,facecolor='0.8')
ax.plot(xx,Hx[:,i],'r')
ax.plot(xx,Hx[:,i],'ro')
if ploterr:
ax.errorbar(xx,y[:,i],yerr=ddy2[:,i], fmt='go')
ax.plot(xx,y[:,i],'go')
else:
try:
if ploterr:
ax.errorbar(location, y[:,i], yerr=[ddy2[:,i],ddy1[:,i]], fmt='go')
try:
if (Hx == x):
if ploterr:
ax.errorbar(location, Hx[:,i], yerr=[ddx2[:,i],ddx1[:,i]], fmt='ro')
else:
ax.plot(location,Hx[:,i],'ro')
else:
ax.plot(location,Hx[:,i],'ro')
ax.plot(location,y[:,i],'go')

except:
if (Hx == x).all():
if ploterr:
ax.errorbar(location, Hx[:,i], yerr=[ddx2[:,i],ddx1[:,i]], fmt='ro')
else:
ax.plot(location,Hx[:,i],'ro')
else:
ax.plot(location,Hx[:,i],'ro')
ax.plot(location,y[:,i],'go')
except:
try:
xx = location
except:
xx = np.arange(np.array(yshape)[:-1]).flatten()

if ploterr:
ax.errorbar(xx,y[:,i], yerr=[ddy2[:,i],ddy1[:,i]], fmt='ro')
ax.plot(xx,Hx[:,i],'r')
ax.plot(xx,y[:,i],'g.')
#ax2 = ax.twinx()
# set nice ticks
#ax2.set_yticks(ax.get_yticks())
ax.yaxis.set_major_locator(pylab.MaxNLocator(3))
#ax2.yaxis.set_major_locator(pylab.MaxNLocator(2))
ax2.set_yticklabels([])
pylab.xticks()
ax.xaxis.set_major_locator(pylab.MaxNLocator(9))
if i == xshape[-1]-1:
ax.set_xlabel(label)
else:
ax.set_xticklabels([])
#ax.set_ylim(ax2.get_ylim())
for lab in ax.get_xticklabels():
lab.set_rotation(0)
if self.showPlot:
pylab.show()
else:
fig.savefig(self.plot_filename + '.y2.png')
if self.plotMovie:
fig.savefig(self.plot_filename + '.%08d.y2.png'%self.plotMovieFrame)

nn1 = int(np.sqrt(yshape[0]))
nn2 = yshape[0]/nn1 + 1
fig = pylab.figure(2)
pylab.clf()
axisNum = 0
ymax1 = np.max(y)
ymax2 = np.max(Hx)

for i in xrange(yshape[0]):
axisNum += 1
ax = pylab.subplot(nn1,nn2, axisNum)
ax.set_ylim(0.,np.max([ymax1,ymax2]))
try:
wl = self.y_meta.spectral.nlw[self.y_meta.spectral.median_bands]
except:
wl = np.arange(yshape[1])
pylab.plot(wl,y[i])
if ploterr:
try:
ax.errorbar(wl,y[i], yerr=[ddy2[i],ddy1[i]], fmt='ro')
except:
pass
pylab.plot(wl,Hx[i],'g^')
if ploterr:
try:
ax.errorbar(wl,Hx[i], yerr=[fddy2[i],fddy1[i]], fmt='g^')
except:
pass
#try:
##    if (Hx == x).all():
#        if ploterr:
#            ax.errorbar(wl, Hx[i], yerr=[ddx2[:,i],ddx1[:,i]], fmt='ro')
pylab.yticks()
ax.yaxis.set_major_locator(pylab.MaxNLocator(3))
if not (axisNum-1)%nn2 == 0:
ax.set_yticklabels([])
pylab.xticks()
ax.xaxis.set_major_locator(pylab.MaxNLocator(3))
if (axisNum-1)/nn2 == nn1-1:
for label in ax.get_xticklabels():
label.set_rotation(90)
else:
ax.set_xticklabels([])

if self.showPlot:
pylab.show()
else:
fig.savefig(self.plot_filename + '.y.png')
if self.plotMovie:
fig.savefig(self.plot_filename + '.%08d.y.png'%self.plotMovieFrame)

if self.plotMovie:
self.logger.info("... written plots to " + \
self.plot_filename + ".%08d.y.png"%self.plotMovieFrame + " and " +\
self.plot_filename + ".%08d.y2.png"%self.plotMovieFrame + " and " +\
self.plot_filename + ".%08d.x.png"%self.plotMovieFrame)
self.logger.info("... written plots to " + \
self.plot_filename + ".y.png" + " and " +\
self.plot_filename + ".y2.png" + " and " +\
self.plot_filename + ".x.png")

if self.plotMovie:
self.plotMovieFrame += 1

def hessian(self):
'''
Return the cost, its gradient and its Hessian for this Operator.

If the Operator is a super operator it may contain other
operators in self.operators, in which case J, J_prime and
J_prime_prime are accumulated over all sub operators.

If not, then just calculate the J_prime_prime in this operator.
Note that in the case of a super operator, we do not
count self cost.

Returns the tuple (J, J_prime.flatten(), J_prime_prime).

'''
self.x_state.logger.info('Calculating Hessian')
# No 'operators' entry: treat this object as a single operator.
if not 'operators' in self.dict():
op = self
J,J_prime_tmp,J_prime_prime_tmp = \
self.J_prime_prime ()

# NOTE(review): the next line reads as the tail of a call whose opening
# is not present here -- presumably a statement is missing; confirm
# against the upstream file before relying on this branch.
self.x.state,sum=True,M=True)
# NOTE(review): this branch assigns J, not Jtmp, before logging Jtmp,
# and no return is visible before the loop over self.operators -- verify.
op.x_state.logger.info("J = %f"%Jtmp)
#op.x_state.logger.info("J' = %s"%str(\
#        J_prime[tuple(np.array(J_prime_tmp.shape[:-1])*0)]))
# Accumulate J, J' and J'' over every sub operator.
for i,op in enumerate(self.operators):
# load from self.x.state, the *full* representation
# into that required by this operator
Jtmp,J_prime_tmp,J_prime_prime_tmp = \
op.J_prime_prime ()
# NOTE(review): same dangling call tail as above -- presumably the start
# of this statement is missing; confirm.
self.x.state,sum=True,M=True)
#J_prime_prime_tmp = J_prime_prime_tmp
op.x_state.logger.info("     J   = %f"%Jtmp)
# The first operator initialises the running totals; later ones add to them.
if i == 0:
J = Jtmp
J_prime = J_prime_tmp
J_prime_prime = J_prime_prime_tmp
else:
J += Jtmp
J_prime += J_prime_tmp
J_prime_prime += J_prime_prime_tmp
# fwd uncertainty
# = H_prime.T * J_prime_prime * H_prime
#self.memory()
return J,J_prime.flatten(),J_prime_prime

def fwdError(self):
'''
Calculate the fwd modelling uncertainty for all operators

This assumes that self.Ihessian has been calculated

For each operator then, op.fwdUncert = H_prime.T self.Ihessian H_prime

Returns the number of operators for which the calculation completed
(0 if 'Ihessian' is not in self.dict()). Any exception raised while
processing an operator is swallowed and that operator is not counted.

'''
# Guard: nothing can be propagated without the inverse Hessian.
if not 'Ihessian' in self.dict():
self.x_state.logger.error('Cannot call fwdError if Ihessian has not been calculated')
return 0
self.x_state.logger.info('Calculating Fwd Model error')
# n counts the operators successfully processed
n = 0
# No 'operators' entry: compute the uncertainty for this operator itself.
if not 'operators' in self.dict():
op = self
try:
Ihessian = self.Ihessian
x,Cx1,xshape,y,Cy1,yshape = self.getxy()
J,J_prime0 = self.J_prime()
H_prime = self.H_prime(x)
op.fwdUncert = H_prime.T * Ihessian * H_prime
n = n + 1
except:
pass
# NOTE(review): no return is visible after the branch above, so control
# reaches this loop over self.operators in both cases -- confirm.
for i,op in enumerate(self.operators):
# load from self.x.state, the *full* representation
# into that required by this operator
try:
x,Cx1,xshape,y,Cy1,yshape = op.getxy()
# NOTE(review): within this method Ihessian is only assigned in the
# branch above (and by the expansion below); if it is unbound here the
# resulting error is swallowed by the bare except -- presumably an
# assignment from self.Ihessian is missing; confirm.
# Reshape to a square matrix of side sqrt(size).
nn = np.sqrt(Ihessian.size).astype(int)
Ihessian = Ihessian.reshape((nn,nn))
J,J_prime0 = op.J_prime()
H_prime = op.H_prime(x)
# but Ihessian only has one entry for each (time/space) sample
# and here we need a version of this that is expanded
op.xlookup = np.atleast_1d(op.xlookup)
nsamples = len(op.xlookup)
# total number of entries held across all xlookup elements
ntotal = np.array([len(op.xlookup[i]) for i in xrange(nsamples)]).sum()
if  nsamples != ntotal:
# need to expand
nn = H_prime.shape[0]
Ihessian2 = np.zeros((nn,nn))
count = 0
nparams = xshape[-1]
# Copy the j-th nparams x nparams diagonal block of Ihessian once for
# each entry of xlookup[j]; off-diagonal blocks are left at zero.
for j in xrange(len(op.xlookup)):
for k in xrange(len(op.xlookup[j])):
Ihessian2[count*nparams:(count+1)*nparams,count*nparams:(count+1)*nparams] = Ihessian[j*nparams:(j+1)*nparams,j*nparams:(j+1)*nparams]
count = count + 1
Ihessian = Ihessian2
op.fwdUncert = np.matrix(H_prime).T * np.matrix(Ihessian) *  np.matrix(H_prime)
# standard deviation: sqrt of the diagonal, reshaped to yshape
op.fwdSd = np.array(np.sqrt(op.fwdUncert.diagonal())).reshape(yshape)
n = n + 1
except:
pass
return n

def J ( self ):
    """
    Evaluate this operator's contribution to the cost function.

    With the residual d = y - H(x), the value computed is the sum of
    the strictly positive elements of 0.5 * d * (C_y^-1 * d).

    Side effects: the flattened H(x) is stored in self.Hx and the
    residual d in self.linear.d1.

    Returns a single value, J.
    """
    state, _cx1, _xshape, obs, inv_cov, _yshape = self.getxy()
    # NB, Hx should be same shape as y
    # or y is a constant
    self.Hx = self.H(state).flatten()
    residual = np.array(obs - self.Hx)
    # keep an independent copy of the residual on self.linear
    self.linear.d1 = residual.copy()
    weighted = inv_cov * residual
    terms = (0.5 * residual * weighted).flatten()
    # only strictly positive terms contribute to the total
    return terms[terms > 0].sum()

def J_prime ( self ):
    """
    The operator contribution to the gradient of the cost function:

    J' = dJ/dx = -(y - H(x))^T (C_y^-1) H'(x)

    self.J() is evaluated first; it stores the residual y - H(x) in
    self.linear.d1, which is then combined with C_y^-1 and
    self.H_prime(x).

    If the product cannot be formed directly, H'(x) is retried as a
    transposed matrix. If the result has self.x.state.size ** 2
    elements it is recomputed as a matrix product, falling back to its
    diagonal if it still has that size or the product fails.

    Returns the tuple (J, Jprime).
    """
    J = self.J()
    #if self.J_prime_approx != None:
    #    return self.J_prime_approx()
    x, Cx1, xshape, y, Cy1, yshape = self.getxy()
    H_prime_x = self.H_prime(x)
    d1 = self.linear.d1
    try:
        Jprime = -np.array(d1 * Cy1 * H_prime_x)
    except Exception:
        # shapes did not combine: retry with H' as a transposed matrix
        Jprime = -np.array(d1 * Cy1 * np.matrix(H_prime_x).T)[0]
    # Jprime should be the same shape as self.x.state
    squared_size = self.x.state.size ** 2
    if Jprime.size == squared_size:
        try:
            Jprime = -np.matrix(d1) * np.matrix(Cy1 * H_prime_x)
            if Jprime.size == squared_size:
                Jprime = Jprime.diagonal()
        except Exception:
            Jprime = Jprime.diagonal()
    return J, Jprime

#return J,

def J_prime_prime ( self ):
    """
    The operator contribution to the second derivative of the cost:

    J'' = d^2J/(dx_i dx_j)

    For this base class that is simply np.diag of C_y^-1.

    NB: if you inherit this class you *must* define a new
    J_prime_prime if you want to calculate uncertainties,
    or set J_prime_prime to return J_prime_prime_approx.

    Returns the tuple (J, J_prime, J_prime_prime).
    """
    self.x_state.logger.info(' ... Calculating Hessian')
    # the operator here is I()
    # so we return C-1
    inverse_covariance = self.getxy()[4]
    cost, gradient = self.J_prime()
    second_derivative = np.diag(inverse_covariance)
    return cost, gradient, second_derivative

def J_prime_prime_approx ( self ):
    """
    The operator contribution to the second derivative of the cost:

    J'' = d^2J/(dx_i dx_j)

    estimated numerically by forward differences of self.J_prime().

    Each element of the flattened state whose self.delta entry is
    non-zero is perturbed in turn by that delta; the step is redirected
    when it would leave [self.xmin, self.xmax]. Column i of the result
    is (J'(x + step_i) - J'(x)) / step_i. Columns for elements with a
    zero delta are left at zero.

    The state and its shape are taken from self.getxy(), as in the
    other derivative methods of this class. self.x.state is
    overwritten during the calculation and reset from the base state
    before returning, even if self.J_prime() raises.

    This is costly and you can normally find a better way to do it.

    Returns the tuple (J, J_prime, J_prime_prime).
    """
    x, Cx1, xshape, y, Cy1, yshape = self.getxy()
    # run the base J_prime
    J, J_prime_0 = self.J_prime()
    J_prime_0 = J_prime_0.copy()
    # float copy so that adding delta is never truncated
    x0 = np.array(x, dtype=float).flatten()
    out = np.zeros((x0.size, x0.size))
    xmax = self.xmax.flatten()
    xmin = self.xmin.flatten()
    delta = self.delta.flatten()
    x1 = x0.copy()
    xnparam = int(np.array(xshape).prod())
    try:
        for i in range(xnparam):
            self.x_state.logger.info('...%d/%d'%(i+1,xnparam))
            if delta[i] == 0:
                # no step defined for this element
                continue
            x1[i] += delta[i]
            # keep the perturbed value inside the bounds
            if x1[i] > xmax[i]:
                x1[i] = xmax[i] - delta[i]
            if x1[i] < xmin[i]:
                x1[i] = xmin[i] + delta[i]
            # actual step taken (may differ from delta after clamping)
            d = x1[i] - x0[i]
            self.x.state = x1.reshape(xshape).copy()
            J1, J_prime_1 = self.J_prime()
            out[:, i] = ((J_prime_1 - J_prime_0) / d).flatten()
            x1[i] = x0[i]
    finally:
        self.x.state = x0.reshape(xshape).copy()
    return J, J_prime_0, out

def JJ(self,x):
    '''
    A call to self.J that also loads x

    x is reshaped to the shape of the current self.x.state and stored
    there before self.J() is evaluated, so the cost can be used as a
    function of a flat state vector.

    Required by self.J_prime_approx_3()

    Returns the value of self.J() for the loaded state.

    '''
    self.x.state = x.reshape(self.x.state.shape)
    # J() takes no argument besides self in this class
    return self.J()

def J_prime_approx_1(self):
    '''
    A discrete approximation to J_prime

    The method assumes that J_prime is independent for
    each sample in the last column of the state vector.

    This will be appropriate e.g. for observation operators
    where the derivative of one observation does not depend on
    any other samples. This means that we can take finite
    difference steps only over this last dimension and not over
    all samples.

    self.x.state, self.delta, self.xmax and self.xmin are indexed as
    [sample, parameter]. For each parameter i with self.xranger[i] > 0
    all samples are stepped together by self.delta[:, i]; a step that
    would exceed xmax is reversed and a value below xmin is clamped to
    xmin. The result is written to self.linear.J_prime[..., i] and is
    zero wherever the effective step is zero.

    self.x.state is reset from the initial state before returning, even
    if self.J() raises.

    See also:
    J_prime_approx_3
    J_prime_approx_2 (not yet implemented)

    Returns the tuple (J at the initial state, self.linear.J_prime).

    '''
    # can probably reuse this code to write a better version of H_prime
    # J at x
    J0 = self.J()
    # make a copy of x.state
    xstate = self.x.state.copy()
    # now loop over each parameter
    # this greatly speeds up the calculation of an approximate J_prime
    # as we know that each observation is independent (i.e the J_prime
    # of each observation depends only on the state vector elements for
    # that observation)
    # NOTE(review): (J1 - J0) is divided by each sample's own step, so this
    # presumably expects J() to give per-sample values -- TODO confirm.
    try:
        for i in np.where(self.xranger > 0)[0]:
            step = self.delta[:, i]
            upper = self.xmax[:, i]
            lower = self.xmin[:, i]
            x0 = xstate[:, i]
            x1 = x0 + step
            # compare each sample with its own upper bound
            over = x1 > upper
            x1[over] = x0[over] - step[over]
            under = x1 < lower
            x1[under] = lower[under]
            # effective step after the bounds have been applied
            step = x1 - x0
            moved = step != 0
            trial = xstate.copy()
            trial[moved, i] = x1[moved]
            self.x.state = trial
            J1 = self.J()
            dJ = np.zeros_like(step)
            dJ[moved] = (J1 - J0) / step[moved]
            self.linear.J_prime[..., i] = dJ
    finally:
        self.x.state = xstate.copy()
    return J0, self.linear.J_prime

def J_prime_approx_3(self):
    '''
    A discrete approximation to J_prime

    This method is the 'backup' and default approximation method
    for J_prime as it treats all samples independently
    and so has to go over x.size finite steps for J.

    Generally, J_prime_approx_2 or J_prime_approx_1 will
    be significantly faster than this, but there may be occasions
    when this method is appropriate.

    The method makes use of and requires DerApproximator
    so will only work if this python library is available;
    there is no fallback when it is not installed.

    The derivative of self.JJ is evaluated by get_d1 at the flattened
    self.x.state; the result is stored in self.J_prime_approx and
    returned.

    Raises ImportError if DerApproximator cannot be imported.

    '''
    try:
        from DerApproximator import get_d1
    except ImportError:
        raise ImportError(
            "Cannot import DerApproximator for derivative approx"
            + " in J_prime_approx_3 ... check it is installed or"
            + " avoid calling this method")
    self.J_prime_approx = get_d1(self.JJ, np.array(self.x.state).flatten())
    return self.J_prime_approx

def tester(plot=True):
'''
Derivative test for individual J_primes

For every operator of the solver root, J_prime is compared with an
approximate J_prime and the range of both, the mean difference, the
mean absolute difference and the RMSE are logged. If plot is true
the two are also plotted against each other with pylab; plotting
errors are silently ignored.

'''
from eoldas_Solver import eoldas_Solver
from eoldas_ConfFile import ConfFile

logdir = 'logs'
logfile = 'log.dat'
thisname = 'eoldas_Test1'
conffile = ['semid_default.conf'] # ,'Obs1.conf']
# NOTE(review): neither of the next two calls is closed before the
# following statement starts -- presumably their final argument lines
# are missing; confirm against the upstream file.
confs = ConfFile(conffile,logger=None,\
logdir=logdir,\
logfile=logfile,\
solver = eoldas_Solver(confs,thisname=thisname,\
logger=confs.logger,\
print "See logfile logs/log.dat for results of test"

# Reference log output, printed so it can be compared with the log file.
Expectation = '''
- eoldas_Test1-obs-x - INFO - operator eoldas_Test1-obs-x
- eoldas_Test1-obs-x - INFO - Calculating J_prime
- eoldas_Test1-obs-x - INFO - Calculating approximate J_prime
- eoldas_Test1-obs-x - INFO - J_prime        Range: [-78928.2236:83512.5508]
- eoldas_Test1-obs-x - INFO - J_prime_approx Range: [-78928.2216:83512.5483]
- eoldas_Test1-obs-x - INFO - Mean Diff -0.000050
- eoldas_Test1-obs-x - INFO - Mean Abs Diff 0.001528
- eoldas_Test1-obs-x - INFO - RMSE 0.001904
- eoldas_Test1-modelt-x - INFO - operator eoldas_Test1-modelt-x
- eoldas_Test1-modelt-x - INFO - Calculating J_prime
- eoldas_Test1-modelt-x - INFO - Calculating approximate J_prime
- eoldas_Test1-modelt-x - INFO - J_prime        Range: [-1.621250:1.696041]
- eoldas_Test1-modelt-x - INFO - J_prime_approx Range: [-1.621250:1.696041]
- eoldas_Test1-modelt-x - INFO - Mean Diff 0.000000
- eoldas_Test1-modelt-x - INFO - Mean Abs Diff 0.000000
- eoldas_Test1-modelt-x - INFO - RMSE 0.000000
'''
print 'Expectation:'
print Expectation
solver.confs.infos = np.atleast_1d(solver.confs.infos)
for i in xrange(len(solver.confs.infos)):
solver.prep(i)
# randomise, so we get a good signal to look at
# Make the xstate random here, just as a good test

#solver.root.x.state = np.random.rand(solver.root.x.state.size).\
#    reshape(solver.root.x.state.shape)
#solver.root.x.state = (np.arange(solver.root.x.state.size)\
#        / float(solver.root.x.state.size)).\
#        reshape(solver.root.x.state.shape)

#solver.root.x.state[...,0] = 1
#solver.root.x.state[...,0]*10 + \
# np.random.rand(solver.root.x.state[...,0].size)
# Accumulators shaped like the full state, zero-initialised.
Jall = 0
J_primeall = solver.root.x.state * 0.
J_prime_approxall = solver.root.x.state * 0.
for i,op in enumerate(solver.root.operators):

op.logger.info('operator %s'%op.options.thisname)
op.logger.info('Calculating J_prime')
# NOTE(review): Operator.J_prime in this file takes no argument besides
# self; the {} passed here may be a leftover -- confirm.
J,J_prime = op.J_prime({})

# unload into the *full* representation
# NOTE(review): J_prime_tmp is not assigned anywhere in this function
# before it is used here -- presumably a statement is missing; confirm.
J_primeall += J_prime_tmp
J_prime = J_prime.flatten()

op.logger.info('Calculating approximate J_prime')
# NOTE(review): J_prime_approx_slow is not defined in the part of this
# file reviewed here -- verify that it exists.
J_prime_approx = op.J_prime_approx_slow()
J_prime_approxall += J_prime_tmp
J_prime_approx = J_prime_approx.flatten()

# Difference statistics between the two gradients.
d = np.atleast_1d(J_prime - J_prime_approx)
n = float(len(d))
op.logger.info('J_prime        Range: [%.6f:%.6f]'\
%(np.min(J_prime),np.max(J_prime)))
op.logger.info('J_prime_approx Range: [%.6f:%.6f]'\
%(np.min(J_prime_approx),np.max(J_prime_approx)))
op.logger.info('Mean Diff %f'%((d).sum()/n))
op.logger.info('Mean Abs Diff %f'%(np.abs(d).sum()/n))
op.logger.info('RMSE %f'%np.sqrt((d*d).sum()/n))

if plot:
try:
import pylab
# note: these local names shadow the builtins max and min
max = np.max([np.max(J_prime),np.max(J_prime_approx)])
min = np.min([np.min(J_prime),np.min(J_prime_approx)])
pylab.plot(min,max,'b-')
pylab.plot(J_prime,J_prime_approx,'o')
pylab.show()
except:
pass
if plot:
# ideally all points should overlie each other
# on a line
J_prime = J_primeall.flatten()
J_prime_approx = J_prime_approxall.flatten()

# NOTE(review): xopt is not assigned anywhere in this function --
# presumably a statement is missing; confirm.
J = solver.cost(xopt)
# NOTE(review): the results of cost_df / approx_cost_df are stored in the
# *tmp names only, so J_prime2 and J_prime_approx2 stay all-zero here.
J_prime2 = solver.root.x.state*0.
J_prime2tmp = solver.cost_df(xopt)
J_prime2 =J_prime2.flatten()

J_prime_approx2 = solver.root.x.state*0.
J_prime_approx2tmp = solver.approx_cost_df(xopt)
J_prime_approx2 =J_prime_approx2.flatten()

try:
import pylab
max = np.max([np.max(J_prime),np.max(J_prime_approx)])
min = np.min([np.min(J_prime),np.min(J_prime_approx)])
pylab.plot(min,max,'b-')
pylab.plot(J_prime,J_prime_approx,'o')
pylab.plot(J_prime,J_prime2,'v')
# for some reason J_prime_approx2 is different
pylab.plot(J_prime,J_prime_approx2,'^')
pylab.show()
except:
pass

def demonstration(plot=False):
    '''
    Run the J_prime derivative test, forwarding plot to tester()
    (plotting is off by default here).
    '''
    tester(plot=plot)

# Run the demonstration only when this file is executed as a script.
if '__main__' == __name__:
    demonstration()

```