##############################################################################
#
# Copyright (c) 2002 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
import os
import pickle
import unittest
 
from persistent.dict import PersistentDict
from persistent import UPTODATE
import transaction
 
import ZODB.tests.util
from zodbcode import tests # import this package, to get at __file__ reliably
from zodbcode.module \
     import ManagedRegistry, PersistentModuleImporter, PersistentPackage
 
# Snippets of module source code handed to ManagedRegistry.newModule()
# by several of the TestModule tests below.

# Module "foo": a global ``x`` that its function ``f`` reads when called.
foo_src = """\
import string
x = 1
def f(y):
    return x + y
"""
# Module "quux": ``from foo import x`` copies foo's value into quux's own
# namespace, so later changes to foo.x are not seen by quux.f (see the
# assertions in testModules).
quux_src = """\
from foo import x
def f(y):
    return x + y
"""
# Module "effect": inc() rebinds a module global, i.e. calling it mutates
# the module object itself (see testFunctionSideEffects).
side_effect_src = """\
x = 1
def inc():
    global x
    x += 1
    return x
"""
# A function that calls a builtin (len) from inside a persistent module
# (see testBuiltins).
builtin_src = """\
x = 1, 2, 3
def f():
    return len(x)
"""
# Nested function definitions; the inner functions are only created when
# f() is called, never at module import time.
nested_src = """\
def f(x):
    def g(y):
        def z(z):
            return x + y + z
        return x + y
    return g
"""

# nested_src plus a module-level call that stores a nested function (g) in
# the module namespace; testNested expects newModule() to reject this with
# TypeError.
nested_err_src = nested_src + """\
g = f(3)
"""

# A closure created at module level (inc = f(1)).
# NOTE(review): not referenced by any test visible in this file -- confirm
# whether it is still needed.
closure_src = """\
def f(x):
    def g(y):
        return x + y
    return g

inc = f(1)
"""
 
class TestPersistentModuleImporter(PersistentModuleImporter):
    """Import hook that resolves imports against one fixed module registry.

    The tests call install()/uninstall() and the helpers _import() and
    _get_parent(), none of which are defined here, so they come from the
    PersistentModuleImporter base class.
    """

    def __init__(self, registry):
        # The registry searched by every import made while installed.
        self._registry = registry
        # Load the registry's state now; presumably so lookups performed
        # from inside the import hook never have to unghost it first.
        self._registry._p_activate()

    def __import__(self, name, globals={}, locals={}, fromlist=[], level=-1):
        """Replacement for the builtin __import__ while the hook is installed.

        Modules found in the registry win.  Anything else is delegated to
        ``self._saved_import`` (presumably the __import__ that was active
        when install() was called -- it is not set in this class).

        ``level`` mirrors the optional fifth argument of the Python 2.5+
        builtin.  The interpreter only supplies it when it differs from
        the default -1 (``from __future__ import absolute_import`` or an
        explicit relative import); without this parameter such imports
        raised TypeError while the hook was installed.  The mutable
        defaults match the builtin's signature and are never mutated here,
        only passed through.
        """
        mod = self._import(self._registry, name, self._get_parent(globals),
                           fromlist)
        if mod is not None:
            return mod
        if level == -1:
            # Keep the historical four-argument call for the default case.
            return self._saved_import(name, globals, locals, fromlist)
        return self._saved_import(name, globals, locals, fromlist, level)
 
class TestBase(unittest.TestCase):
    """Shared fixture for persistent-module tests.

    Each test gets a fresh database whose root holds a ManagedRegistry.
    An import hook routes ``import`` statements through that registry for
    the duration of the test.
    """

    def setUp(self):
        self.db = ZODB.tests.util.DB()
        self.root = self.db.open().root()
        self.registry = ManagedRegistry()
        # From here on, ``import`` consults self.registry first.
        self.importer = TestPersistentModuleImporter(self.registry)
        self.importer.install()
        self.root["registry"] = self.registry
        transaction.commit()
        # Path of the on-disk sample module some tests load as source.
        _dir, _file = os.path.split(tests.__file__)
        self._pmtest = os.path.join(_dir, "_pmtest.py")

    def tearDown(self):
        # Uninstall the import hook even if aborting or closing the
        # database raises.  A leaked hook would keep redirecting
        # ``import`` for every test that runs afterwards.
        try:
            # just in case a test left changes pending
            transaction.abort()
            self.db.close()
        finally:
            self.importer.uninstall()

    def sameModules(self, registry):
        """Assert that *registry* lists the same module names as ours."""
        m1 = self.registry.modules()
        m1.sort()
        m2 = registry.modules()
        m2.sort()
        self.assertEqual(m1, m2)

    def useNewConnection(self):
        """Load every module through a second database connection.

        This checks that modules can be recreated from the database.  The
        second connection must see the same module names, and each module
        (plus any object in its namespace that has ``_p_activate``) must
        activate.  Each module must reach the UPTODATE state.
        """
        cn = self.db.open()
        try:
            reg = cn.root()["registry"]
            self.sameModules(reg)
            for name in reg.modules():
                mod = reg.findModule(name)
                mod._p_activate()
                self.assertEqual(mod._p_state, UPTODATE)
                for obj in mod.__dict__.values():
                    if hasattr(obj, "_p_activate"):
                        obj._p_activate()
        finally:
            # XXX somehow objects are getting registered here, but not
            # modified.  need to figure out what is going wrong, but for
            # now just abort the transaction.  (The intended check was
            # ``assert not cn._registered``.)
            # Abort before closing, and do both even when an assertion
            # above failed, so the connection is always released.
            transaction.abort()
            cn.close()
 
class TestModule(TestBase):
 
    def testModule(self):
        self.registry.newModule("pmtest", open(self._pmtest).read())
        transaction.commit()
        self.assert_(self.registry.findModule("pmtest"))
        import pmtest
        pmtest._p_deactivate()
        self.assertEqual(pmtest.a, 1)
        pmtest.f(4)
        self.useNewConnection()
 
    def testUpdateFunction(self):
        self.registry.newModule("pmtest", "def f(x): return x")
        transaction.commit()
        import pmtest
        self.assertEqual(pmtest.f(3), 3)
        copy = pmtest.f
        self.registry.updateModule("pmtest", "def f(x): return x + 1")
        transaction.commit()
        pmtest._p_deactivate()
        self.assertEqual(pmtest.f(3), 4)
        self.assertEqual(copy(3), 4)
        self.useNewConnection()
 
    def testUpdateClass(self):
        self.registry.newModule("pmtest", src)
        transaction.commit()
        import pmtest
        inst = pmtest.Foo()
        v0 = inst.x
        v1 = inst.m()
        v2 = inst.n()
        self.assertEqual(v1 - 1, v2)
        self.assertEqual(v0 + 1, v1)
        self.registry.updateModule("pmtest", src2)
        transaction.commit()
        self.assertRaises(AttributeError, getattr, inst, "n")
        self.useNewConnection()
 
    def testModules(self):
        self.registry.newModule("foo", foo_src)
        # quux has a copy of foo.x
        self.registry.newModule("quux", quux_src)
        # bar has a reference to foo
        self.registry.newModule("bar", "import foo")
        # baz has reference to f and copy of x,
        # remember the the global x in f is looked up in foo
        self.registry.newModule("baz", "from foo import *")
        import foo, bar, baz, quux
        self.assert_(foo._p_oid is None)
        transaction.commit()
        self.assert_(foo._p_oid)
        self.assert_(bar._p_oid)
        self.assert_(baz._p_oid)
        self.assert_(quux._p_oid)
        self.assertEqual(foo.f(4), 5)
        self.assertEqual(bar.foo.f(4), 5)
        self.assertEqual(baz.f(4), 5)
        self.assertEqual(quux.f(4), 5)
        self.assert_(foo.f is bar.foo.f)
        self.assert_(foo.f is baz.f)
        foo.x = 42
        self.assertEqual(quux.f(4), 5)
        transaction.commit()
        self.assertEqual(quux.f(4), 5)
        foo._p_deactivate()
        # foo is deactivated, which means its dict is empty when f()
        # is activated, how do we guarantee that foo is also
        # activated?
        self.assertEqual(baz.f(4), 46)
        self.assertEqual(bar.foo.f(4), 46)
        self.assertEqual(foo.f(4), 46)
        self.useNewConnection()
 
    def testFunctionAttrs(self):
        self.registry.newModule("foo", foo_src)
        import foo
        A = foo.f.attr = "attr"
        self.assertEqual(foo.f.attr, A)
        transaction.commit()
        self.assertEqual(foo.f.attr, A)
        foo.f._p_deactivate()
        self.assertEqual(foo.f.attr, A)
        del foo.f.attr
        self.assertRaises(AttributeError, getattr, foo.f, "attr")
        foo.f.func_code
        self.useNewConnection()
 
    def testFunctionSideEffects(self):
        self.registry.newModule("effect", side_effect_src)
        import effect
        effect.inc()
        transaction.commit()
        effect.inc()
        self.assert_(effect._p_changed)
        self.useNewConnection()
 
    def testBuiltins(self):
        self.registry.newModule("test", builtin_src)
        transaction.commit()
        import test
        self.assertEqual(test.f(), len(test.x))
        test._p_deactivate()
        self.assertEqual(test.f(), len(test.x))
        self.useNewConnection()
 
    def testNested(self):
        self.assertRaises(TypeError,
                          self.registry.newModule, "nested", nested_err_src)
        self.registry.newModule("nested", nested_src)
        transaction.commit()
        import nested
        g = nested.f(3)
        self.assertEqual(g(4), 7)
 
    def testLambda(self):
        # test a lambda that contains another lambda as a default
        self.registry.newModule("test",
                                "f = lambda x, y = lambda: 1: x + y()")
        transaction.commit()
        import test
        self.assertEqual(test.f(1), 2)
        self.useNewConnection()
 
    def testClass(self):
        self.registry.newModule("foo", src)
        transaction.commit()
        import foo
        obj = foo.Foo()
        obj.m()
        self.root["m"] = obj
        transaction.commit()
        foo._p_deactivate()
        o = foo.Foo()
        i = o.m()
        j = o.m()
        self.assertEqual(i + 1, j)
        self.useNewConnection()
 
    def testPackage(self):
        self.registry.newModule("A.B.C", "def f(x): return x")
        transaction.commit()
 
        import A.B.C
        self.assert_(isinstance(A, PersistentPackage))
        self.assertEqual(A.B.C.f("A"), "A")
 
        self.assertRaises(ValueError, self.registry.newModule,
                          "A.B", "def f(x): return x + 1")
 
        self.registry.newModule("A.B.D", "def f(x): return x")
        transaction.commit()
 
        from A.B import D
        self.assert_(hasattr(A.B.D, "f"))
        self.useNewConnection()
 
    def testPackageInit(self):
        self.registry.newModule("A.B.C", "def f(x): return x")
        transaction.commit()
 
        import A.B.C
 
        self.registry.newModule("A.B.__init__", "x = 2")
        transaction.commit()
 
        import A.B
        self.assert_(hasattr(A.B, "C"))
        self.assertEqual(A.B.x, 2)
 
        self.assertRaises(ValueError, self.registry.newModule,
                          "A.__init__.D", "x = 2")
        self.useNewConnection()
 
    def testPackageRelativeImport(self):
        self.registry.newModule("A.B.C", "def f(x): return x")
        transaction.commit()
 
        self.registry.newModule("A.Q", "from B.C import f")
        transaction.commit()
 
        import A.Q
        self.assertEqual(A.B.C.f, A.Q.f)
 
        self.registry.updateModule("A.Q", "import B.C")
        transaction.commit()
 
        self.assertEqual(A.B.C.f, A.Q.B.C.f)
 
        try:
            import A.B.Q
        except ImportError:
            pass
        self.useNewConnection()
 
    def testImportAll(self):
        self.registry.newModule("A.B.C",
                                """__all__ = ["a", "b"]; a, b, c = 1, 2, 3""")
        transaction.commit()
 
        d = {}
        exec "from A.B.C import *" in d
        self.assertEqual(d['a'], 1)
        self.assertEqual(d['b'], 2)
        self.assertRaises(KeyError, d.__getitem__, "c")
 
        self.registry.newModule("A.B.D", "from C import *")
        transaction.commit()
 
        import A.B.D
        self.assert_(hasattr(A.B.D, "a"))
        self.assert_(hasattr(A.B.D, "b"))
        self.assert_(not hasattr(A.B.D, "c"))
 
        self.registry.newModule("A.__init__", """__all__ = ["B", "F"]""")
        transaction.commit()
 
        self.registry.newModule("A.F", "spam = 1")
        transaction.commit()
 
        import A
        self.assertEqual(A.F.spam, 1)
        self.useNewConnection()
 
class TestModuleReload(unittest.TestCase):
    """Test reloading of modules"""

    def setUp(self):
        self.db = ZODB.tests.util.DB()
        self.open()
        # Path of the on-disk sample module loaded as source below.
        _dir, _file = os.path.split(tests.__file__)
        self._pmtest = os.path.join(_dir, "_pmtest.py")

    def tearDown(self):
        # Close the database even if closing the connection fails.
        try:
            transaction.abort()
            self.close()
        finally:
            self.db.close()

    def open(self):
        """Open a new connection on the existing database.

        The registry is fetched from the connection's root and created
        there on first use.  A fresh import hook is then installed for
        that registry.
        """
        self.root = self.db.open().root()
        self.registry = self.root.get("registry")
        if self.registry is None:
            self.root["registry"] = self.registry = ManagedRegistry()
        self.importer = TestPersistentModuleImporter(self.registry)
        self.importer.install()
        transaction.commit()

    def close(self):
        """Uninstall the import hook and close the current connection."""
        try:
            self.importer.uninstall()
        finally:
            self.root._p_jar.close()

    def _readPmtest(self):
        """Return the source text of the on-disk _pmtest.py sample."""
        # try/finally rather than ``with`` so the file is always closed
        # while staying compatible with older Python 2 releases.
        f = open(self._pmtest)
        try:
            return f.read()
        finally:
            f.close()

    def testModuleReload(self):
        """A module can be imported again after reopening the connection."""
        self.registry.newModule("pmtest", self._readPmtest())
        transaction.commit()
        import pmtest
        pmtest._p_deactivate()
        self.assertEqual(pmtest.a, 1)
        pmtest.f(4)
        self.close()
        pmtest._p_deactivate()
        self.open()
        del pmtest
        import pmtest

    def testClassReload(self):
        """A module with stored class instances can be re-imported."""
        self.registry.newModule("foo", src)
        transaction.commit()
        import foo
        obj = foo.Foo()
        obj.m()
        self.root["d"] = d = PersistentDict()
        d["m"] = obj
        transaction.commit()
        self.close()
        foo._p_deactivate()
        self.open()
        del foo
        import foo

    def testModulePicklability(self):
        """A (presumably regular, on-disk) test module round-trips via pickle."""
        from zodbcode.tests import test_module
        s = pickle.dumps(test_module)
        m = pickle.loads(s)
        self.assertEqual(m, test_module)
 
def test_suite():
    """Return one suite holding every test* method of this module's cases."""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in (TestModule, TestModuleReload):
        suite.addTest(loader.loadTestsFromTestCase(case_class))
    return suite
 
# Class source used by testUpdateClass, testClass and testClassReload.
# These constants are defined after the test classes; the names are only
# looked up when the tests run, so the ordering is fine.
src = """\
class Foo(object):
    def __init__(self):
        self.x = id(self)
    def m(self):
        self.x += 1
        return self.x
    def n(self):
        self.x -= 1
        return self.x
"""

# Replacement source for testUpdateClass: same class name, but without
# n(), so existing instances must lose that method after the update.
src2 = """\
class Foo(object):
    def __init__(self):
        self.x = 0
    def m(self):
        self.x += 10
        return self.x
"""