import numpy as np
 
from pyoperators import (
    flags, Operator, AdditionOperator, BlockColumnOperator,
    BlockDiagonalOperator, BlockRowOperator,  CompositionOperator,
    DiagonalOperator, HomothetyOperator, IdentityOperator,
    MultiplicationOperator, I, asoperator)
from pyoperators.core import BlockOperator
from pyoperators.utils import merge_none
from pyoperators.utils.testing import (
    assert_eq, assert_is_instance, assert_raises, assert_is_type)
from .common import Stretch
 
 
def test_partition1():
    # A block-diagonal operator partitioned along axis 0 must have the same
    # dense form as the reference diagonal operator, whether partitionin is
    # left to None or given explicitly as (1, 2, 3).
    o1 = HomothetyOperator(1, shapein=1)
    o2 = HomothetyOperator(2, shapein=2)
    o3 = HomothetyOperator(3, shapein=3)
    expected = DiagonalOperator([1, 2, 2, 3, 3, 3]).todense()

    def check(operands, partition):
        block = BlockDiagonalOperator(operands, partitionin=partition,
                                      axisin=0)
        # str(block) is passed as the failure message
        assert_eq(block.todense(6), expected, str(block))

    # (operands, partitionin) pairs
    cases = (
        ((o1, o2, o3), None),
        ((I, o2, o3), (1, 2, 3)),
        ((o1, 2*I, o3), (1, 2, 3)),
        ((o1, o2, 3*I), (1, 2, 3)),
    )
    for operands, partition in cases:
        yield check, operands, partition
 
 
def test_partition2():
    # Applying Stretch block by block, with the input partitioned along one
    # axis, must give the same result as applying Stretch to the whole array.
    # In some cases in this test, partitionout cannot be inferred from
    # partitionin, because the former depends on the input rank.
    data = np.arange(3*4*5*6).reshape(3, 4, 5, 6)

    def check(partition_axis, partition, stretch_axis):
        block = BlockDiagonalOperator(
            3*[Stretch(stretch_axis)], partitionin=partition,
            axisin=partition_axis)
        assert_eq(block(data), Stretch(stretch_axis)(data))

    # (axisin, partitionin) pairs; each partition sums to the length of the
    # addressed axis of the (3, 4, 5, 6) input.
    # NOTE(review): axis -4 is not exercised, although it would pair
    # naturally with the partition (1, 1, 1) -- confirm whether it should
    # be covered.
    cases = (
        (0, (1, 1, 1)),
        (1, (1, 2, 1)),
        (2, (2, 2, 1)),
        (3, (2, 3, 1)),
        (-1, (2, 3, 1)),
        (-2, (2, 2, 1)),
        (-3, (1, 2, 1)),
    )
    for partition_axis, partition in cases:
        for stretch_axis in range(4):
            yield check, partition_axis, partition, stretch_axis
 
 
def test_partition3():
    # TODO: placeholder, nothing is checked yet. Intended to test partitioned
    # block operators with axisin != axisout.
    pass
 
 
def test_partition4():
    # Combining a separable operator with a partitioned block-diagonal
    # operator through + and * must still give a BlockDiagonalOperator.
    @flags.separable
    class Op(Operator):
        pass

    # three homothety blocks of input sizes 1, 2 and 3
    blocks = [HomothetyOperator(k, shapein=k) for k in (1, 2, 3)]
    partitioned = BlockDiagonalOperator(blocks, axisin=0)
    separable = Op()
    combined = (separable + partitioned + separable) * partitioned
    assert isinstance(combined, BlockDiagonalOperator)
 
 
def test_block1():
    # Stacking three (2, 2) operators along a new axis must give input and
    # output shapes with a length-3 axis at the requested position, for
    # negative (-3..-1) as well as non-negative (0..2) axis values.
    operands = [HomothetyOperator(k, shapein=(2, 2)) for k in (1, 2, 3)]
    shapes = ((3, 2, 2), (2, 3, 2), (2, 2, 3))

    def check(new_axis, shape):
        block = BlockDiagonalOperator(operands, new_axisin=new_axis)
        assert_eq(block.shapein, shape)
        assert_eq(block.shapeout, shape)

    # the expected shapes repeat: axis -3 behaves like 0, -2 like 1, ...
    for new_axis, shape in zip(range(-3, 3), 2 * shapes):
        yield check, new_axis, shape
 
 
def test_block2():
    # Block-diagonal operators built along a new axis: the block operator made
    # of Stretch(axiss) operands must act on the 4-d input like a single
    # Stretch applied along the corresponding axis of the whole array.
    shape = (3, 4, 5, 6)
    # np.prod is used instead of np.product, a deprecated alias that was
    # removed in NumPy 2.0.
    i = np.arange(np.prod(shape)).reshape(shape)

    def func(axisp, axiss):
        # one operand per element along the new axis
        op = BlockDiagonalOperator(shape[axisp]*[Stretch(axiss)],
                                   new_axisin=axisp)
        # non-negative equivalent of the new axis for the 4-d input
        axisp_ = axisp if axisp >= 0 else axisp + 4
        # the operands' stretch axis is shifted by one in the full array when
        # the new axis sits at or before it
        axiss_ = axiss if axisp_ > axiss else axiss + 1
        assert_eq(op(i), Stretch(axiss_)(i))
    for axisp in (0, 1, 2, 3, -1, -2, -3):
        for axiss in (0, 1, 2):
            yield func, axisp, axiss
 
 
def test_block3():
    # TODO: placeholder, nothing is checked yet. Intended to test block
    # operators with new_axisin != new_axisout.
    pass
 
 
def test_block4():
    # Combining a separable operator with a block-diagonal operator stacked
    # along a new axis, through + and *, must still give a
    # BlockDiagonalOperator.
    @flags.separable
    class Op(Operator):
        pass

    # three homothety blocks sharing the same input size
    blocks = [HomothetyOperator(k, shapein=2) for k in (1, 2, 3)]
    stacked = BlockDiagonalOperator(blocks, new_axisin=0)
    separable = Op()
    combined = (separable + stacked + separable) * stacked
    assert isinstance(combined, BlockDiagonalOperator)
 
 
def test_block_column1():
    # Building a block column from IdentityOperator(2) and
    # 2*IdentityOperator(3) must raise ValueError, with axisout as well as
    # with new_axisout (presumably because the operands are incompatible --
    # the reason is not checked here).
    small = IdentityOperator(2)
    large = IdentityOperator(3)
    for keyword in ('axisout', 'new_axisout'):
        assert_raises(ValueError, BlockColumnOperator, [small, 2*large],
                      **{keyword: 0})
 
 
def test_block_column2():
    # The dense form of a block column made of o and 2*o must be the vertical
    # stack of p and 2*p, and the dense form of its transpose must be the
    # transpose of its dense form -- with axisout=0 and with new_axisout=0.
    p = np.matrix([[1, 0], [0, 2], [1, 0]])
    o = asoperator(np.matrix(p))
    expected = np.vstack([p, 2*p])
    for keywords in ({'axisout': 0}, {'new_axisout': 0}):
        column = BlockColumnOperator([o, 2*o], **keywords)
        assert_eq(column.todense(), expected)
        assert_eq(column.T.todense(), column.todense().T)
 
 
def test_block_row1():
    # Building a block row from IdentityOperator(2) and 2*IdentityOperator(3)
    # must raise ValueError, with axisin as well as with new_axisin
    # (presumably because the operands are incompatible -- the reason is not
    # checked here).
    small = IdentityOperator(2)
    large = IdentityOperator(3)
    for keyword in ('axisin', 'new_axisin'):
        assert_raises(ValueError, BlockRowOperator, [small, 2*large],
                      **{keyword: 0})
 
 
def test_block_row2():
    # The dense form of a block row made of o and 2*o must be the horizontal
    # stack of p and 2*p, and the dense form of its transpose must be the
    # transpose of its dense form -- with axisin=0 and with new_axisin=0.
    p = np.matrix([[1, 0], [0, 2], [1, 0]])
    o = asoperator(np.matrix(p))
    expected = np.hstack([p, 2*p])
    for keywords in ({'axisin': 0}, {'new_axisin': 0}):
        row = BlockRowOperator([o, 2*o], **keywords)
        assert_eq(row.todense(), expected)
        assert_eq(row.T.todense(), row.todense().T)
 
 
def test_partition_implicit_commutative():
    # Adding or multiplying two block operators of the same kind (row,
    # diagonal or column) must give a block operator of that kind whose
    # partitions equal merge_none() of the partitions the operands were
    # built from.
    partitions = (None, None), (2, None), (None, 3), (2, 3)
    ops = [I, 2*I]

    def func(operation, op1, op2, p1, p2, cls):
        # `operation` is received as an argument instead of being read from
        # the enclosing loop: each yielded case then keeps its own operation
        # even if the runner collects all cases before running any of them.
        op = operation([op1, op2])
        assert type(op) is cls
        if op.partitionin is None:
            # no input partition is only expected when neither operand has one
            assert op1.partitionin is op2.partitionin is None
        else:
            assert op.partitionin == merge_none(p1, p2)
        if op.partitionout is None:
            # likewise for the output partition
            assert op1.partitionout is op2.partitionout is None
        else:
            assert op.partitionout == merge_none(p1, p2)
    for operation in (AdditionOperator, MultiplicationOperator):
        for p1 in partitions:
            for p2 in partitions:
                # zipped columns, in order: expected class, axisout, axisin,
                # then partitionout/partitionin of op1 and of op2; a None
                # axis or partition marks the side the block kind lacks
                for cls, aout, ain, pout1, pin1, pout2, pin2 in zip(
                        (BlockRowOperator, BlockDiagonalOperator,
                         BlockColumnOperator),
                        (None, 0, 0), (0, 0, None), (None, p1, p1),
                        (p1, p1, None), (None, p2, p2), (p2, p2, None)):
                    op1 = BlockOperator(
                        ops, partitionout=pout1, partitionin=pin1, axisin=ain,
                        axisout=aout)
                    op2 = BlockOperator(
                        ops, partitionout=pout2, partitionin=pin2, axisin=ain,
                        axisout=aout)
                    yield func, operation, op1, op2, p1, p2, cls
 
 
def test_partition_implicit_composition():
    # The product op1 * op2 of two block operators must be of the expected
    # class and, when it is itself a block operator, its partitions must
    # equal merge_none() of op1's input partition and op2's output partition
    # (None on the side a block row or block column does not partition).
    partitions = (None, None), (2, None), (None, 3), (2, 3)
    ops = [I, 2*I]

    def check(left, right, pin_left, pout_right, expected_cls):
        product = left * right
        assert_is_instance(product, expected_cls)
        if not isinstance(product, BlockOperator):
            # nothing more to check for a non-block result
            return
        if isinstance(product, BlockRowOperator):
            expected_pout = None
        else:
            expected_pout = merge_none(pin_left, pout_right)
        if isinstance(product, BlockColumnOperator):
            expected_pin = None
        else:
            expected_pin = merge_none(pin_left, pout_right)
        assert expected_pout == product.partitionout
        assert expected_pin == product.partitionin

    for pin1 in partitions:
        for pout2 in partitions:
            # each row: expected class of op1 * op2, then axisout and axisin
            # of op1, axisout and axisin of op2, partitionout of op1 and
            # partitionin of op2
            rows = (
                # block row * block diagonal
                (BlockRowOperator, None, 0, 0, 0, None, pout2),
                # block row * block column
                (HomothetyOperator, None, 0, 0, None, None, None),
                # block diagonal * block diagonal
                (BlockDiagonalOperator, 0, 0, 0, 0, pin1, pout2),
                # block diagonal * block column
                (BlockColumnOperator, 0, 0, 0, None, pin1, None),
            )
            for cls, aout1, ain1, aout2, ain2, pout1, pin2 in rows:
                op1 = BlockOperator(ops, partitionin=pin1, partitionout=pout1,
                                    axisout=aout1, axisin=ain1)
                op2 = BlockOperator(ops, partitionout=pout2, partitionin=pin2,
                                    axisout=aout2, axisin=ain2)
                yield check, op1, op2, pin1, pout2, cls
 
 
def test_mul():
    # The product of two block operators built over axis 0 must have the
    # expected class, and its first operand must be a CompositionOperator
    # when the blocks are linear, a MultiplicationOperator otherwise.
    opnl = Operator(shapein=10, flags='square')
    oplin = Operator(flags='linear,square', shapein=10)
    # each row: class of the left factor, class of the right factor, expected
    # class of their product
    clss = (
        (BlockRowOperator, BlockDiagonalOperator, BlockRowOperator),
        (BlockDiagonalOperator, BlockDiagonalOperator, BlockDiagonalOperator),
        (BlockDiagonalOperator, BlockColumnOperator, BlockColumnOperator),
        (BlockRowOperator, BlockColumnOperator, AdditionOperator),
    )

    def check(op, cls1, cls2, cls3):
        if op.flags.linear:
            operation = CompositionOperator
        else:
            operation = MultiplicationOperator
        left = cls1(3*[op], axisin=0)
        right = cls2(3*[op], axisout=0)
        result = left * right
        assert_is_type(result, cls3)
        assert_is_type(result.operands[0], operation)

    for op in (opnl, oplin):
        for cls1, cls2, cls3 in clss:
            yield check, op, cls1, cls2, cls3