Require Coq.Setoids.Setoid.

(* Reserve the infix bind operator before it is given a meaning:
   level 42, left associative. *)
Reserved Notation "x >>= f" (at level 42, left
associativity).

(* Reserve the infix choice operator at the same level and with the same
   associativity as the bind operator. *)
Reserved Notation "m [] n"  (at level 42, left
associativity).

(* Syntax tree of monadic computations whose results have type X.
     ret    wraps a plain value of type X;
     mplus  combines two computations over the same X;
     mzero  is a constant computation taking no arguments;
     bind   pairs a computation over some intermediate type Y with a
            continuation from Y into computations over X.
   NOTE(review): the proof script of the theorem at the end of this file
   uses hypothesis names that induction derives from these constructor
   signatures (IHm1, IHm2, m0, Y, H), so changing the binders or the
   argument order here is likely to break that proof. *)
Inductive monad_syn (X : Type) : Type :=
  | ret : X -> monad_syn X
  | mplus : monad_syn X -> monad_syn X -> monad_syn X
  | mzero : monad_syn X
  | bind : forall (Y :Type), monad_syn Y -> (Y -> monad_syn X) -> monad_syn X.

(* Infix form of the bind constructor; both type arguments are left as
   underscores for Coq to infer. *)
Notation "x >>= f" := (bind _ _ x f).

(* Infix form of the mplus constructor; the type argument is left as an
   underscore for Coq to infer.
   NOTE(review): the notation string below starts with a space, whereas the
   reserved form earlier in the file does not -- confirm this is intended. *)
Notation " m [] n" := (mplus _ m n).

Require String. Open Scope string_scope.

(* Attempt to move hypothesis x next to the first hypothesis found when the
   context is scanned with match reverse goal. The move is wrapped in try,
   so a move that cannot be performed is silently ignored. *)
Ltac move_to_top x :=
  match reverse goal with
  | Hfirst : _ |- _ =>
      try move x after Hfirst
  end.

(* Check that the identifier and the given term are equal by reflexivity.
   The equation is asserted under a fresh name and cleared again at once,
   so on success the context is left unchanged; the tactic fails when
   reflexivity cannot close the equation. *)
Tactic Notation "assert_eq" ident(marker) constr(expected) :=
  let Hcheck := fresh in
  assert (marker = expected) as Hcheck by reflexivity;
  clear Hcheck.

(* Shared implementation of the Case markers. Alternatives are tried in
   order:
     1. introduce a local definition binding the marker identifier to the
        label, then move it with move_to_top;
     2. otherwise check with assert_eq that the marker already equals the
        label, then move it with move_to_top;
     3. otherwise fail at level 1 with the message below. *)
Tactic Notation "Case_aux" ident(marker) constr(label) :=
  first
    [ set (marker := label); move_to_top marker
    | assert_eq marker label; move_to_top marker
    | fail 1 "because we are working on a different case" ].

(* Case markers, one tactic per nesting depth. Each one forwards its label
   to Case_aux together with a marker identifier spelled like the tactic
   itself, so every depth uses its own hypothesis name. *)
Tactic Notation "Case" constr(label) :=
  Case_aux Case label.
Tactic Notation "SCase" constr(label) :=
  Case_aux SCase label.
Tactic Notation "SSCase" constr(label) :=
  Case_aux SSCase label.
Tactic Notation "SSSCase" constr(label) :=
  Case_aux SSSCase label.
Tactic Notation "SSSSCase" constr(label) :=
  Case_aux SSSSCase label.
Tactic Notation "SSSSSCase" constr(label) :=
  Case_aux SSSSSCase label.
Tactic Notation "SSSSSSCase" constr(label) :=
  Case_aux SSSSSSCase label.
Tactic Notation "SSSSSSSCase" constr(label) :=
  Case_aux SSSSSSSCase label.

(* Associativity of mplus.
   NOTE(review): the equality used here is Leibniz equality on the
   Inductive type monad_syn. Both sides are headed by the mplus
   constructor, so by constructor injectivity the statement forces
   m [] n = m, which looks unsatisfiable for an inductive term. Confirm
   whether an equivalence relation on terms was intended instead. *)
Axiom mplus_assoc : forall (X : Type) (m n p : monad_syn X),
  (m [] n)[] p =  m [] (n [] p).

(* mzero is a left identity for mplus.
   NOTE(review): monad_syn is an Inductive type, so mplus and mzero are
   distinct constructors. Taking m to be mzero X, this axiom equates an
   mplus-headed term with mzero X, which discriminate refutes. As stated
   with Leibniz equality the axiom therefore appears inconsistent --
   confirm whether an equivalence relation was intended. *)
Axiom mzero_left : forall (X: Type) (m : monad_syn X),
  (mzero X) [] m = m.

(* mzero is a right identity for mplus.
   NOTE(review): with m taken to be mzero X this equates an mplus-headed
   term with the distinct constructor mzero of the Inductive type
   monad_syn, which discriminate refutes. Under Leibniz equality the axiom
   therefore appears inconsistent -- confirm intent. *)
Axiom mzero_right : forall (X : Type) (m : monad_syn X),
   m [] (mzero X) = m.

(* Binding out of mzero yields mzero, whatever the continuation f is.
   NOTE(review): the left side is headed by the bind constructor and the
   right side is the mzero constructor of the same Inductive type, so this
   Leibniz equality is refutable by discriminate -- confirm intent. *)
Axiom mzero_bind : forall (X Y : Type) (f : X -> monad_syn Y),
  (mzero X) >>= f = mzero Y.

(* A ret term is never equal to mzero.
   NOTE(review): ret and mzero are distinct constructors of the Inductive
   type monad_syn, so this should be provable with discriminate rather
   than assumed as an axiom. *)
Axiom ret_not_mzero : forall (X :Type) (x :X),
  not (ret X x = mzero X).

(* An mplus of two terms equals mzero exactly when both operands equal
   mzero. Stated as an iff, so it can be used in either direction. *)
Axiom mplus_mzero_both_mzero : forall X (m n : monad_syn X),
   m [] n = mzero X <-> (m = mzero X) /\ (n = mzero X).

(* Left unit law: binding a ret of x into f is the same as applying f to x.
   NOTE(review): the left side is always headed by the bind constructor,
   while f x may be headed by any constructor (for instance when f is
   ret Y), so as a Leibniz equality on an Inductive type this is refutable
   by discriminate -- confirm intent. *)
Axiom ret_bind : forall (X Y : Type) (x : X) (f : X -> monad_syn Y),
  (ret X x) >>=  f = f x.

(* Bind distributes over mplus in its first argument.
   NOTE(review): the left side is headed by the bind constructor and the
   right side by the mplus constructor of the same Inductive type, so this
   Leibniz equality is refutable by discriminate -- confirm intent. *)
Axiom mplus_bind : forall (X Y:Type)(m n:monad_syn X)(f:X->monad_syn Y),
  (m [] n) >>= f = (m >>= f) [] (n >>= f).

(* If m is not mzero and binding m into f gives mzero, then f is the
   constant function returning mzero.
   NOTE(review): this looks too strong next to ret_bind and ret_not_mzero.
   Take m to be ret X x and an f that sends x to mzero Y but some other
   input to a ret: both premises hold (the second through ret_bind), yet
   the conclusion makes f constantly mzero, contradicting ret_not_mzero.
   Confirm the intended statement. *)
Axiom bind_mzero_f_mzero:forall (X Y:Type)(m:monad_syn X)(f:X->monad_syn Y),
  not(m = mzero X) -> 
    (m >>= f = mzero Y) -> f = (fun z => (mzero Y)).

(* Every term either equals mzero or does not: an excluded-middle
   principle restricted to the proposition m = mzero X. *)
Axiom mzero_or_not_mzero : forall (X: Type) (m : monad_syn X),
  m = mzero X \/ not(m = mzero X).

(* Binding any term into the continuation that always returns mzero yields
   mzero.
   NOTE(review): the left side is headed by the bind constructor and the
   right side is the mzero constructor of the same Inductive type, so this
   Leibniz equality is refutable by discriminate -- confirm intent. *)
Axiom mzero_f_bind : forall (X Y : Type) (m : monad_syn X),
   m >>= (fun z => mzero Y) = mzero Y.

(* Shape lemma: every term m is equal to mzero, or to a ret of some value,
   or to an mplus of two terms neither of which equals mzero.
   The proof is by induction on m and uses the axioms declared earlier in
   this file.
   NOTE(review): the script refers to automatically generated names
   (H, H0, x, x0, IHm1, IHm2, IHm, m0, ...), so it is sensitive to any
   change in the statement, in monad_syn, or in the order of tactics. *)
Theorem x : forall (X : Type) ( m : monad_syn X),
  m = mzero X \/ (exists (x : X), m = ret X x)\/ 
  exists (n p : monad_syn X),
    (not (n = mzero X)) /\
    (not (p = mzero X)) /\
    m = mplus X n p.
Proof. intros. induction m.
  (* ret constructor: pick the middle disjunct with the same value. *)
  Case"m = return x".
    right. left. exists x. reflexivity.
  (* mplus constructor: split on the induction hypothesis for m1, and inside
     each sub-case on the induction hypothesis for m2. *)
  Case"m = m1 [] m2".
    destruct IHm1.
    (* m1 is mzero: mzero_left removes it, so the result has the shape that
       the induction hypothesis gives for m2 (mzero, ret, or mplus). *)
    SCase"m1 = mzero".
      destruct IHm2. rewrite -> H. rewrite -> H0.
      rewrite -> mzero_left. left. reflexivity. destruct H0. destruct H0.
      rewrite -> H. rewrite -> H0. rewrite -> mzero_left. right. left.
      exists x. reflexivity. destruct H0. destruct H0. rewrite -> H.
      destruct H0. destruct H1. rewrite -> mzero_left. right. right. 
      exists x. exists x0.  split.  apply H0. split.  apply H1. apply H2.
      (* Remaining goal of destruct IHm1: split the ret / mplus disjunction. *)
      destruct H.
    (* m1 is a ret. If m2 is mzero, mzero_right leaves the ret. If m2 is a
       ret, the two rets are the witnesses, non-mzero by ret_not_mzero. If
       m2 is an mplus, the term is reassociated with mplus_assoc and the
       left witness is shown non-mzero via mplus_mzero_both_mzero. *)
    SCase"m1 = return x".
      destruct IHm2. rewrite -> H0. rewrite -> mzero_right.
      destruct H. rewrite -> H. right. left. exists x.
      reflexivity. destruct H0. right. right. destruct H. destruct H0.
      exists (ret X x). exists (ret X x0). split. apply ret_not_mzero.
      split. apply ret_not_mzero. rewrite -> H. rewrite -> H0. reflexivity.
      destruct H0. destruct H0. destruct H0. destruct H1. rewrite -> H2.
      rewrite <- mplus_assoc. right. right. exists (mplus X m1 x).
      exists x0. split. rewrite -> mplus_mzero_both_mzero. unfold not. 
      intro. destruct H3. contradiction. split. apply H1. reflexivity.
    (* m1 is an mplus of two non-mzero terms. If m2 is mzero, mzero_right
       reduces the goal to the hypothesis about m1. Otherwise the witnesses
       are the mplus form of m1 and m2 (or its mplus form), each shown
       non-mzero via mplus_mzero_both_mzero or ret_not_mzero. *)
    SCase"m1 = n [] p".
      destruct IHm2. rewrite -> H0. rewrite -> mzero_right. right. right.
      apply H. destruct H0.  destruct H0. destruct H. destruct H.
      destruct H. destruct H1. rewrite -> H2. right. right. 
      exists (mplus X x0 x1). exists m2. split.
      rewrite -> mplus_mzero_both_mzero. unfold not. intro. destruct H3. 
      contradiction. split. rewrite -> H0. apply ret_not_mzero.
      reflexivity. destruct H. destruct H0. destruct H. destruct H0.
      destruct H. destruct H0. destruct H1. destruct H2. right. right.
      exists (mplus X x x1). exists (mplus X x0 x2). split. 
      rewrite -> mplus_mzero_both_mzero. unfold not. intro. 
      destruct H5. contradiction. split. rewrite -> mplus_mzero_both_mzero.
      unfold not. intro. destruct H5. contradiction. rewrite -> H4.
      rewrite -> H3. reflexivity.
  (* mzero constructor: the first disjunct holds by reflexivity. *)
  Case"m = mzero".
    left. reflexivity.
  (* bind constructor: split on the induction hypothesis for the bound
     term; H is the induction hypothesis for the continuation m0. *)
  Case"m = n >>= f".
    destruct IHm.
    (* Bound term is mzero: mzero_bind collapses the whole bind to mzero. *)
    SCase"n = mzero".
      rewrite -> H0. rewrite -> mzero_bind. left. reflexivity.
    destruct H0.
    (* Bound term is a ret: ret_bind reduces the bind to an application of
       the continuation, which the hypothesis H covers. *)
    SCase"n = return x".
      destruct H0. rewrite -> H0. rewrite -> ret_bind. apply H.
    destruct H0.
    (* Bound term is an mplus of non-mzero terms x and x0: distribute the
       bind with mplus_bind, then decide with mzero_or_not_mzero whether
       each of the two resulting binds is mzero. *)
    SCase"n = n1 [] n2".
      destruct H0. destruct H0. destruct H1. rewrite -> H2. 
      rewrite -> mplus_bind. 
      (* H3, H4: if either bind is mzero, the continuation is constantly
         mzero (bind_mzero_f_mzero, using that x0 resp. x is not mzero). *)
      assert(H3 : bind _ _ x0 m0 = (mzero X) -> (m0 = (fun x => (mzero X)))).
      SSCase"Proof of Assertion".
        apply bind_mzero_f_mzero. apply H1.
      assert(H4 : bind _ _ x m0 = (mzero X) -> (m0 = (fun x => (mzero X)))).
      SSCase"Proof of Assertion".
        apply bind_mzero_f_mzero. apply H0.
      (* H5, H6: each of the two binds is mzero or is not. *)
      assert(H5:bind _ _ x m0=(mzero X) \/ not(bind _ _ x m0 = (mzero X))).
      SSCase"Proof of Assertion".
        apply mzero_or_not_mzero.
      assert(H6 : bind _ _ x0 m0=(mzero X) \/ not(bind _ _ x0 m0=(mzero X))).
      SSCase"Proof of Assertion".
        apply mzero_or_not_mzero.
      (* If either bind is mzero the continuation is constantly mzero, so
         both binds collapse and the whole term is mzero. Otherwise the two
         binds themselves are the non-mzero witnesses. *)
      destruct H5. apply H4 in H5. rewrite -> H5. left. rewrite->mzero_f_bind.
      rewrite -> mzero_left. rewrite -> mzero_f_bind. reflexivity.
      destruct H6. apply H3 in H6. rewrite -> H6.
      left. rewrite -> mzero_f_bind. rewrite -> mzero_left.
      rewrite -> mzero_f_bind. reflexivity. right. right.
      exists (bind X Y x m0). exists (bind X Y x0 m0).
      split. apply H5. split. apply H6. reflexivity. Qed.

