• Facebook
  • Twitter
  • Reddit
  • StumbleUpon
  • Digg
  • email

"""
Session Management
(from web.py)
"""
 
import os, time, datetime, random, base64
try:
    import cPickle as pickle
except ImportError:
    import pickle
try:
    import hashlib
    sha1 = hashlib.sha1
except ImportError:
    import sha
    sha1 = sha.new
 
import utils
import webapi as web
 
__all__ = [
    'Session', 'SessionExpired',
    'Store', 'DiskStore', 'DBStore',
]
 
web.config.session_parameters = utils.storage({
    'cookie_name': 'webpy_session_id',
    'cookie_domain': None,
    'timeout': 86400, #24 * 60 * 60, # 24 hours in seconds
    'ignore_expiry': True,
    'ignore_change_ip': True,
    'secret_key': 'fLjUfxqXtfNoIldA0A0J',
    'expired_message': 'Session expired',
})
 
class SessionExpired(web.HTTPError): 
    def __init__(self, message):
        web.HTTPError.__init__(self, '200 OK', {}, data=message)
 
class Session(utils.ThreadedDict):
    """Session management for web.py
    """
 
    def __init__(self, app, store, initializer=None):
        self.__dict__['store'] = store
        self.__dict__['_initializer'] = initializer
        self.__dict__['_last_cleanup_time'] = 0
        self.__dict__['_config'] = utils.storage(web.config.session_parameters)
 
        if app:
            app.add_processor(self._processor)
 
    def _processor(self, handler):
        """Application processor to setup session for every request"""
        self._cleanup()
        self._load()
 
        try:
            return handler()
        finally:
            self._save()
 
    def _load(self):
        """Load the session from the store, by the id from cookie"""
        cookie_name = self._config.cookie_name
        cookie_domain = self._config.cookie_domain
        self.session_id = web.cookies().get(cookie_name)
 
        # protection against session_id tampering
        if self.session_id and not self._valid_session_id(self.session_id):
            self.session_id = None
 
        self._check_expiry()
        if self.session_id:
            d = self.store[self.session_id]
            self.update(d)
            self._validate_ip()
 
        if not self.session_id:
            self.session_id = self._generate_session_id()
 
            if self._initializer:
                if isinstance(self._initializer, dict):
                    self.update(self._initializer)
                elif hasattr(self._initializer, '__call__'):
                    self._initializer()
 
        self.ip = web.ctx.ip
 
    def _check_expiry(self):
        # check for expiry
        if self.session_id and self.session_id not in self.store:
            if self._config.ignore_expiry:
                self.session_id = None
            else:
                return self.expired()
 
    def _validate_ip(self):
        # check for change of IP
        if self.session_id and self.get('ip', None) != web.ctx.ip:
            if not self._config.ignore_change_ip:
               return self.expired() 
 
    def _save(self):
        cookie_name = self._config.cookie_name
        cookie_domain = self._config.cookie_domain
        if not self.get('_killed'):
            web.setcookie(cookie_name, self.session_id, domain=cookie_domain)
            self.store[self.session_id] = dict(self)
        else:
            web.setcookie(cookie_name, self.session_id, expires=-1, domain=cookie_domain)
 
    def _generate_session_id(self):
        """Generate a random id for session"""
 
        while True:
            rand = os.urandom(16)
            now = time.time()
            secret_key = self._config.secret_key
            session_id = sha1("%s%s%s%s" %(rand, now, utils.safestr(web.ctx.ip), secret_key))
            session_id = session_id.hexdigest()
            if session_id not in self.store:
                break
        return session_id
 
    def _valid_session_id(self, session_id):
        rx = utils.re_compile('^[0-9a-fA-F]+$')
        return rx.match(session_id)
 
    def _cleanup(self):
        """Cleanup the stored sessions"""
        current_time = time.time()
        timeout = self._config.timeout
        if current_time - self._last_cleanup_time > timeout:
            self.store.cleanup(timeout)
            self.__dict__['_last_cleanup_time'] = current_time
 
    def expired(self):
        """Called when an expired session is atime"""
        self._killed = True
        self._save()
        raise SessionExpired(self._config.expired_message)
 
    def kill(self):
        """Kill the session, make it no longer available"""
        del self.store[self.session_id]
        self._killed = True
 
class Store:
    """Base class for session stores"""
 
    def __contains__(self, key):
        raise NotImplementedError
 
    def __getitem__(self, key):
        raise NotImplementedError
 
    def __setitem__(self, key, value):
        raise NotImplementedError
 
    def cleanup(self, timeout):
        """removes all the expired sessions"""
        raise NotImplementedError
 
    def encode(self, session_dict):
        """encodes session dict as a string"""
        pickled = pickle.dumps(session_dict)
        return base64.encodestring(pickled)
 
    def decode(self, session_data):
        """decodes the data to get back the session dict """
        pickled = base64.decodestring(session_data)
        return pickle.loads(pickled)
 
class DiskStore(Store):
    """
    Store for saving a session on disk.
 
        >>> import tempfile
        >>> root = tempfile.mkdtemp()
        >>> s = DiskStore(root)
        >>> s['a'] = 'foo'
        >>> s['a']
        'foo'
        >>> time.sleep(0.01)
        >>> s.cleanup(0.01)
        >>> s['a']
        Traceback (most recent call last):
            ...
        KeyError: 'a'
    """
    def __init__(self, root):
        # if the storage root doesn't exists, create it.
        if not os.path.exists(root):
            os.mkdir(root)
        self.root = root
 
    def _get_path(self, key):
        if os.path.sep in key: 
            raise ValueError, "Bad key: %s" % repr(key)
        return os.path.join(self.root, key)
 
    def __contains__(self, key):
        path = self._get_path(key)
        return os.path.exists(path)
 
    def __getitem__(self, key):
        path = self._get_path(key)
        if os.path.exists(path): 
            pickled = open(path).read()
            return self.decode(pickled)
        else:
            raise KeyError, key
 
    def __setitem__(self, key, value):
        path = self._get_path(key)
        pickled = self.encode(value)    
        try:
            f = open(path, 'w')
            try:
                f.write(pickled)
            finally: 
                f.close()
        except IOError:
            pass
 
    def __delitem__(self, key):
        path = self._get_path(key)
        if os.path.exists(path):
            os.remove(path)
 
    def cleanup(self, timeout):
        now = time.time()
        for f in os.listdir(self.root):
            path = self._get_path(f)
            atime = os.stat(path).st_atime
            if now - atime > timeout :
                os.remove(path)
 
class DBStore(Store):
    """Store for saving a session in database
    Needs a table with the following columns:
 
        session_id CHAR(128) UNIQUE NOT NULL,
        atime DATETIME NOT NULL default current_timestamp,
        data TEXT
    """
    def __init__(self, db, table_name):
        self.db = db
        self.table = table_name
 
    def __contains__(self, key):
        data = self.db.select(self.table, where="session_id=$key", vars=locals())
        return bool(list(data)) 
 
    def __getitem__(self, key):
        now = datetime.datetime.now()
        try:
            s = self.db.select(self.table, where="session_id=$key", vars=locals())[0]
            self.db.update(self.table, where="session_id=$key", atime=now, vars=locals())
        except IndexError:
            raise KeyError
        else:
            return self.decode(s.data)
 
    def __setitem__(self, key, value):
        pickled = self.encode(value)
        now = datetime.datetime.now()
        if key in self:
            self.db.update(self.table, where="session_id=$key", data=pickled, vars=locals())
        else:
            self.db.insert(self.table, False, session_id=key, data=pickled )
 
    def __delitem__(self, key):
        self.db.delete(self.table, where="session_id=$key", vars=locals())
 
    def cleanup(self, timeout):
        timeout = datetime.timedelta(timeout/(24.0*60*60)) #timedelta takes numdays as arg
        last_allowed_time = datetime.datetime.now() - timeout
        self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals())
 
class ShelfStore:
    """Store for saving session using `shelve` module.
 
        import shelve
        store = ShelfStore(shelve.open('session.shelf'))
 
    XXX: is shelve thread-safe?
    """
    def __init__(self, shelf):
        self.shelf = shelf
 
    def __contains__(self, key):
        return key in self.shelf
 
    def __getitem__(self, key):
        atime, v = self.shelf[key]
        self[key] = v # update atime
        return v
 
    def __setitem__(self, key, value):
        self.shelf[key] = time.time(), value
 
    def __delitem__(self, key):
        try:
            del self.shelf[key]
        except KeyError:
            pass
 
    def cleanup(self, timeout):
        now = time.time()
        for k in self.shelf.keys():
            atime, v = self.shelf[k]
            if now - atime > timeout :
                del self[k]
 
if __name__ == '__main__' :
    import doctest
    doctest.testmod()