Set Implicit Arguments.

Require Import List.
Require String. Open Scope string_scope.

(* [move_to_top x]: reposition hypothesis [x] next to the oldest entry of
   the proof context.
   [match reverse goal] scans hypotheses starting from the first one
   introduced. So [H] is bound to the oldest hypothesis, and [x] is moved
   right after it. This places case markers near the top of the displayed
   context.
   The [try] makes the move best-effort. If [move] fails (e.g. [x] is
   itself the oldest hypothesis), the tactic still succeeds without
   changing anything. *)
Ltac move_to_top x :=
  match reverse goal with
  | H : _ |- _ => try move x after H
  end.

(* [assert_eq x v]: a pure check that [x = v] holds by [reflexivity]
   (i.e. [x] and [v] are convertible).
   The proof is stored under a fresh name and immediately cleared. On
   success the context is left unchanged; on failure the tactic fails. *)
Tactic Notation "assert_eq" ident(x) constr(v) :=
  let H := fresh in
  assert (x = v) as H by reflexivity;
  clear H.

(* [Case_aux x name]: label the current goal with the case marker [x := name].
   Alternatives, tried in order by [first]:
   1. [set (x := name)] introduces a new marker hypothesis named [x] (this
      fails if [x] is already in the context) and moves it to the top.
   2. If [x] already exists, [assert_eq] checks that its value is [name],
      i.e. we are in the expected case, and moves it to the top.
   3. Otherwise fail at level 1 with the message below. The level-1
      [fail] propagates past this [first] instead of being absorbed by it. *)
Tactic Notation "Case_aux" ident(x) constr(name) :=
  first [
    set (x := name); move_to_top x
  | assert_eq x name; move_to_top x
  | fail 1 "because we are working on a different case" ].

(* Case markers for structuring proofs. [Case] labels top-level cases,
   [SCase] sub-cases, [SSCase] sub-sub-cases, and so on; eight marker
   levels in total. Each marker is a hypothesis whose name is the marker
   itself (e.g. [Case], [SCase]) and whose value is the given label
   (a string, under [string_scope]). *)
Tactic Notation "Case" constr(name) := Case_aux Case name.
Tactic Notation "SCase" constr(name) := Case_aux SCase name.
Tactic Notation "SSCase" constr(name) := Case_aux SSCase name.
Tactic Notation "SSSCase" constr(name) := Case_aux SSSCase name.
Tactic Notation "SSSSCase" constr(name) := Case_aux SSSSCase name.
Tactic Notation "SSSSSCase" constr(name) := Case_aux SSSSSCase name.
Tactic Notation "SSSSSSCase" constr(name) := Case_aux SSSSSSCase name.
Tactic Notation "SSSSSSSCase" constr(name) := Case_aux SSSSSSSCase name.
(* Reserved infix notations. All six share level 42 and are
   left-associative.
   - [>>=] (bind), [>>-] (map) and [[]] (mplus) are given meanings later
     in this file.
   - [<<] and [<<=] are only reserved; no [Notation] defining them appears
     in the visible source. *)
Reserved Notation "x >>= f" (at level 42, left
associativity).

Reserved Notation "x >>- f" (at level 42, left
associativity).

Reserved Notation "m >> n"  (at level 42, left
associativity).

Reserved Notation "m [] n"  (at level 42, left
associativity).

Reserved Notation "m << n"   (at level 42, left
associativity).

Reserved Notation "m <<= n"   (at level 42, left
associativity).

(* A monad over the type constructor [M].
   - [bind m f] sequences the computation [m] with the continuation [f].
   - [unit x] is the trivial computation whose result is [x].
   Each instance must also prove the three monad laws below as
   (Leibniz) equalities. *)
Class Monad (M : Type -> Type) := {
  bind : forall A B,
    M A -> (A -> M B) -> M B;
  unit : forall A , A -> M A;
  (* Associativity: binding [f] then [g] equals binding the composite
     continuation [fun i => bind (f i) g]. *)
  bind_assoc :
    forall A B C
    (m : M A)
    (f : A -> M B) (g : B -> M C),
    bind (bind m f) g = bind m (fun i => bind (f i) g);
  (* [unit] is a right identity for [bind]. *)
  right_unit : forall A (m : M A),
    bind m (@unit A) = m;
  (* [unit] is a left identity for [bind]. *)
  left_unit : forall A B (x : A) (f : A -> M B),
    bind (unit x) f = f x
}.

(* Infix bind, at the level reserved above. The [_] is an explicit argument
   of [bind] that Coq infers; presumably the result type [B]
   (check with [About bind]). *)
Notation "x >>= f" := (bind _  x f).

(* [join _ m]: flatten a nested computation [m : M (M X)] into [M X] by
   binding it with the identity function. Proofs about [join] in this file
   [unfold join] to reach this [bind] form. *)
Definition join (M : Type -> Type) `{Monad M} (X : Type) (m : M (M X)): M X :=
  m >>= id.

(* [map f m]: apply the pure function [f] to the result of [m], defined as
   bind followed by re-wrapping with [unit]. Note the argument order:
   the function comes before the computation. *)
Definition map (M:Type->Type)(X Y:Type)(f:X->Y)`{Monad M}(m:M X):M Y :=
  m >>= (fun x => unit (f x)).

(* Infix map with the computation first: [x >>- f] is [map f x]. *)
Notation "x >>- f" := (map f x).


(* A monad equipped with a membership predicate.
   [epsilon x m] reads "[x] is a possible result of [m]". It must agree
   with the monad structure:
   - the only result of [unit y] is [y];
   - the results of [m >>= f] are exactly the results of [f x] for some
     result [x] of [m]. *)
Class E_Monad (M : Type -> Type) :=   
{
 emonad_monad :> Monad M;
 epsilon : forall X, X -> M X -> Prop;
 epsilon_unit : forall X (x y : X), epsilon x (unit y) <-> x = y;
 epsilon_bind : forall X Y 
                (m : M X)
                (f : X -> M Y)
                (y : Y),
                epsilon y (m >>= f) <-> 
                exists (x : X), and (epsilon x m) (epsilon y (f x)) 
 }.


(* [x] is a result of [join _ m] iff it is a result of some inner
   computation [y] that is itself a result of [m].
   Since [join _ m] unfolds to [m >>= id], both directions are instances
   of [epsilon_bind]. *)
Theorem epsilon_join:forall X (M:Type->Type)`{E_Monad M}(m:M(M X))(x:X),
  ( exists (y : M X), ((epsilon y m) /\ (epsilon x y) ) )
   <->
  (epsilon  x (join _ m)).
Proof. intros. split.
 (* [intros] names the [E_Monad] instance [H], so the user hypothesis in
    each direction is [H0]. *)
 Case "->".
   intros. apply epsilon_bind in H0. unfold join. apply H0.
 Case "<-".
   intros. apply epsilon_bind. unfold join in H0. apply H0. Qed.

(* The results of [m >>- f] are exactly the images [f x] of results [x]
   of [m]. Proved by unfolding [map] to [bind] + [unit] and combining
   [epsilon_bind] with [epsilon_unit]. *)
Theorem epsilon_map:forall X Y (M:Type->Type)`{E_Monad M}(m:M X)(f:X->Y)(y:Y),
  ( epsilon  y  ( m>>-f ) ) <->
   exists (x : X), (y = f x) /\ (epsilon x m).
Proof. intros. split.
  Case"->".
    (* Split membership in the bind, then read [y = f x] off the [unit]. *)
    intros. unfold map in H0. apply epsilon_bind in H0.
    destruct H0. exists x. destruct H0. split. apply epsilon_unit in H1.
    rewrite -> H1. reflexivity. apply H0.
  Case"<-".
    (* Use the given [x] as the witness required by [epsilon_bind]. *)
    intros. destruct H0. unfold map. apply epsilon_bind. exists x.
    destruct H0. split. apply H1. apply epsilon_unit. apply H0. Qed.

(* For any two computations [l] and [m] there is a computation [n] whose
   results are exactly the pairs [(x, y)] with [x] a result of [l] and [y]
   a result of [m].
   The witness is the nested bind [l >>= fun x => m >>= fun y => unit (x, y)]. *)
Theorem all_pairs:forall X Y (M:Type->Type)`{E_Monad M}(l:M X)(m:M Y),
  exists (n : M  (prod X Y)),
  forall (z : X * Y),
      ( epsilon z n ) <-> ( epsilon (fst z) l) /\ (epsilon (snd z) m ).
Proof. intros. 
  exists ( l >>= (fun x => (m >>= (fun y => (unit (x , y)))))).
  intros. split.
  Case"->". 
    (* Decompose both binds and the [unit] to learn [z = (x, y)]; do it
       once for the [fst] component and once for the [snd] component. *)
    intros. split. apply epsilon_bind with (y :=z) in H0.
    destruct H0. destruct H0. apply epsilon_bind in H1. destruct H1.
    destruct H1. apply epsilon_unit in H2. rewrite -> H2. simpl.
    apply H0. apply epsilon_bind in H0. destruct H0. destruct H0.
    apply epsilon_bind in H1. destruct H1. destruct H1.
    apply epsilon_unit in H2. rewrite -> H2. simpl. apply H1.
  Case"<-".
    (* Witnesses [fst z] and [snd z]; close with surjective pairing. *)
    intros. destruct H0. apply epsilon_bind. exists (fst z).
    split. apply H0. apply epsilon_bind. exists (snd z). split.
    apply H1. assert(H2 : z = (fst z , snd z)). 
    apply surjective_pairing. rewrite <- H2. apply epsilon_unit. reflexivity.
  Qed.


(* A monad with a "failure" computation [mzero] and a choice operator
   [mplus]. Required laws:
   - [mzero] is a left zero for bind;
   - [mzero] is a left and right identity for [mplus];
   - [mplus] is associative;
   - bind distributes over [mplus] in its first argument. *)
Class Monad_Plus (M:Type->Type) :=   
{ 
  monad_plus_monad :> Monad M;
  mzero : forall X, M X;
  mplus : forall X, M X -> M X -> M X;
  mzero_bind : forall X Y (f : X -> M Y), (mzero X)>>= f = mzero Y;
  mzero_mplus_left : forall X (m : M X), mplus (mzero X) m = m;
  mzero_mplus_right : forall X (m : M X), mplus m (mzero X) = m;
  mplus_assoc : forall X (m n p : M X),
                  mplus m (mplus n p) = mplus (mplus m n) p ;
  mplus_bind : forall X Y (m n : M X) (f : X -> M Y),
                 (mplus m n) >>= f = mplus (m >>= f) ( n >>= f)
}.

(* Infix [mplus] at the level reserved above; the [_] is [mplus]'s element
   type argument, left for Coq to infer. *)
Notation " m [] n" := (mplus _ m n).

(* A [Monad_Plus] that is also an [E_Monad], with membership compatible
   with choice and failure:
   - the results of [mplus m n] are those of [m] together with those of [n];
   - a computation equals [mzero] exactly when it has no results.
   NOTE(review): [Monad_Plus M] and [E_Monad M] each embed their own
   [Monad M]. Nothing here requires the two to be the same structure, so
   [>>=] may resolve to either one. Confirm this diamond is intended. *)
Class E_Monad_Plus (M:Type->Type) :=   
{ 
  e_monad_plus_monad_plus :> Monad_Plus M;
  e_monad_plus_e_monad :> E_Monad M;
  mplus_epsilon : forall X (x : X)(m n : M X), 
                    epsilon x (mplus _ m  n) <->
                   ((epsilon x m) \/ (epsilon x n));
  mzero_epsilon :forall X (m : M X), 
                 (m = mzero X) <-> 
                 (forall (x : X), not (epsilon x m))
}.
(* A choice [m [] n] is [mzero] iff both alternatives are [mzero].
   "->" goes through [mzero_epsilon]: any result of [m] (or of [n]) would
   be a result of [m [] n] = [mzero X], which has none.
   "<-" rewrites both operands to [mzero X] and uses [mzero_mplus_left]. *)
Theorem mplus_mzero_both_mzero : 
  forall (X:Type) (M:Type->Type)`{E_Monad_Plus M}(m n:M X),
     m[]n = mzero X <-> (m = mzero X) /\ (n = mzero X).
Proof. intros. split.
  Case"->".
    (* First conjunct: [m] has no results. *)
    intro. split. apply mzero_epsilon. intros.
    unfold not. intros. 
    assert(H2 : epsilon x (m[]n)).
    SCase"Proof of Assertion". 
      apply mplus_epsilon. left. apply H1.
    rewrite -> H0 in H2. apply mzero_epsilon in H2. apply H2.
    (* Second conjunct: [n] has no results, symmetric to the above. *)
    reflexivity. apply mzero_epsilon. intros. unfold not. intro.
    assert(H2 : epsilon x (m[]n)).
    SCase"Proof of Assertion".
      apply mplus_epsilon. right. apply H1.
    rewrite -> H0 in H2. apply mzero_epsilon in H2. apply H2. reflexivity.
  Case"<-".
    intro. destruct H0. rewrite -> H0. rewrite -> H1.
    apply mzero_mplus_left. Qed.

(* [join] distributes over [mplus]. This is [mplus_bind] specialised to
   the identity continuation, after unfolding [join]. *)
Theorem mplus_dist_join : forall X (M:Type->Type)`{Monad_Plus M}(m n:M(M X)), 
  join _ (m [] n) = (join _ m) [] (join _ n).
Proof. intros. unfold join. apply mplus_bind. Qed.

(* [map] distributes over [mplus]. Mapping [f] over the choice [m [] n]
   equals the choice between the two mapped computations.
   [f] may have any codomain [Y]. The former [f : X -> M Y] was a needless
   special case, and instantiating [Y := M Y] recovers it.
   Proof: unfold [map] into [bind] + [unit], then apply [mplus_bind]. *)
Theorem mplus_dist_map:
  forall X Y (M:Type->Type)`{Monad_Plus M}(m n:M X)(f:X->Y),
    (m [] n) >>- f = (m >>-f) [] (n >>- f).
Proof. intros. unfold map. apply mplus_bind. Qed.

(* A [Monad_Plus] with a well-founded relation [rel] on computations,
   presumably meant for well-founded recursion over [mplus] trees
   (to be confirmed against users of this class). Requirements:
   - [rel] is well-founded at every element type;
   - [mzero X] is [rel]-below every computation;
   - each operand of [m [] n] is [rel]-below the whole choice, provided
     the other operand is not [mzero]. *)
Class Monad_Plus_WF (M : Type -> Type) :=   
{ 
  monad_plus_wf_monad_plus :> Monad_Plus M;
  rel : forall X, M X -> M X -> Prop;
  well_found : forall X, well_founded (@rel X);
  mzero_min : forall X (m : M X), rel (mzero X) m;
  left_plus : forall X (m n: M X),
    not(n = mzero X) -> rel m (m [] n);
  right_plus : forall X (m n: M X), 
    not(m = mzero X) -> rel n (m[]n)
 }.

(* A monad with a fold over its results. [fold] has the shape of a right
   fold: a combining function [X -> Y -> Y], an initial value, and the
   computation. This class imposes no laws on [fold]. *)
Class Fold_Monad (M : Type -> Type) :=   
{
 foldmonad_monad :> Monad M;
 fold : forall X Y,(X->Y->Y)->Y->(M X)->Y 
 }.

(* A foldable [Monad_Plus] in which folding with [unit]/[mplus] starting
   from [mzero] rebuilds the original computation ([fold_prop]).
   NOTE(review): [Fold_Monad M] and [Monad_Plus M] each carry their own
   [Monad M], and nothing requires them to coincide. Which one [unit]
   refers to in [fold_prop] is decided by instance resolution.
   Confirm this is intended. *)
Class Fold_Monad_Plus (M : Type -> Type) :=
{
  fmonadplus_fmonad :> Fold_Monad M;
  fmonadplus_monadplus :> Monad_Plus M;
  fold_prop : forall X (m : M X),
    fold (fun x=>fun y=>((unit x)[]y))
         (mzero X)
          m = m
}.

(* The "rebuilding" fold commutes with [mplus]. By [fold_prop], each of the
   three folds equals its argument, so both sides reduce to [m [] n]. *)
Theorem fold_mplus:forall X (M:Type->Type)
            `{Fold_Monad_Plus M}(m n:M X),
            fold (fun x=>fun y=>((unit x)[]y)) 
                 (mzero X) (m[]n) =
            (fold (fun x=>fun y=>((unit x)[]y))
                  (mzero X) m) []
            (fold (fun x=>fun y=>((unit x)[]y))
                  (mzero X) n).
Proof. intros. repeat (rewrite -> fold_prop).
  reflexivity. Qed.

(* Instances for [list]:
   - [Monad]: bind is [flat_map] (via [my_flat_map]), unit is the
     singleton list;
   - [E_Monad]: membership is [In];
   - [Monad_Plus]: [mzero] is [nil] and [mplus] is [app].
   The instances are declared [Global]. A plain [Instance] inside a
   section is local to it and would no longer be available to typeclass
   resolution after [End lists]. *)
Section lists.

(* [flat_map] distributes over list concatenation. This is the
   [mplus_bind] law of [list_monad_plus]. *)
Theorem flat_map_app : forall (X Y:Type)(l:list X)(l':list X)(f:X->list Y),
     flat_map f (l ++ l') = (flat_map f l) ++ (flat_map f l').
Proof. intros. induction l as [| h t].
  Case"Base Case".
    auto.
  Case"Induction Step".
    simpl. rewrite -> IHt. apply app_assoc. Qed.

(* [flat_map] with the list first and the function second, matching the
   argument order of [bind]. *)
Definition my_flat_map (A B : Type):=
 fun (a: list A) => fun (b:A->list B) => flat_map b a.

(* Monad associativity law for lists. *)
Theorem list_bind_assoc : 
   forall X Y Z
         (l : list X)
         (f : X -> list Y)
         (g : Y -> list Z),
    my_flat_map (my_flat_map l f) g=my_flat_map l (fun i=> my_flat_map (f i) g).
Proof. intros. induction l as [| h t].
   Case"Base Case".
     auto.
   Case"Induction Step".
     simpl. rewrite <- IHt. apply flat_map_app. Qed.

(* Left-unit law: binding a singleton list just applies [f]. *)
Theorem list_bind_unit : forall X Y  (x : X) (f : X -> list Y),
   my_flat_map  (x::nil) f = f x. 
Proof. intros. simpl. apply app_nil_r. Qed.

(* Right-unit law: binding with the singleton constructor is the identity. *)
Theorem list_unit_bind : forall X (l : list X),
   my_flat_map l (fun x => x::nil) = l.
Proof. intros. induction l as [| h t].
  Case"Base Case".
    simpl. reflexivity.
  Case"Induction Step".
    simpl. rewrite -> IHt. reflexivity. Qed.

Global Instance list_Monad : Monad list :=
{
bind := my_flat_map;
unit := fun X => fun x => x::nil;
bind_assoc := list_bind_assoc;
left_unit := list_bind_unit;
right_unit := list_unit_bind
}.

(* [In] on a singleton list is equality with its element. *)
Theorem list_epsilon_unit : forall X (x y : X),
  In x (y::nil) <-> x = y.
Proof. intros. simpl. split.
  Case"->".
    intro. destruct H. rewrite -> H.
    reflexivity. inversion H.
  Case"<-".
    intros. left. rewrite -> H. reflexivity. Qed.

(* Membership in a bound list: [y] is in [my_flat_map m f] iff it is in
   [f x] for some [x] in [m]. Both directions are by induction on [m]. *)
Theorem list_epsilon_bind : 
 forall X Y (m : list X)(f : X -> list Y) (y : Y),
     In y (my_flat_map m f) <-> exists (x : X),
                                (In x m) /\ (In y (f x)).
Proof. intros. split.
  Case"->".
    intro. induction m as [| h t].
    SCase"Base Case".
      inversion H.
    SCase"Induction Step".
      simpl in H. apply in_app_or in H. destruct H.
      exists h. split. apply in_eq. apply H. apply IHt in H.
      destruct H. destruct H. exists x. split. apply in_cons.
      apply H. apply H0.
  Case"<-".
    intro. induction m as [| h t].
    SCase"Base Case".
      inversion H. destruct H0. inversion H0.
    SCase"Induction Step".
      simpl. destruct H. destruct H. apply in_or_app. simpl in H.
      destruct H. left. rewrite -> H. apply H0. right. apply IHt.
      exists x. split. apply H. apply H0. Qed.

Global Instance list_emonad : E_Monad list :={
 emonad_monad := list_Monad;
 epsilon := In;
 epsilon_bind := list_epsilon_bind;
 epsilon_unit := list_epsilon_unit
}.

(* [nil] is a left zero for list bind (holds by computation). *)
Theorem list_mzero_bind : forall X Y  (f : X -> list Y),
  nil >>= f = nil.
Proof. intros. reflexivity. Qed.

(* Membership in a concatenation, phrased with [epsilon]. Together with
   [list_mzero_epsilon] it resembles the [E_Monad_Plus] laws, but no
   [E_Monad_Plus list] instance is declared here. *)
Theorem list_mplus_epsilon : forall X (x : X) (m n:list X),
  epsilon x (m ++ n) <-> 
  epsilon x m \/ epsilon x n.
Proof. intros. split. apply in_app_or. apply in_or_app. Qed.

(* The empty list has no members. *)
Theorem list_mzero_epsilon : forall X (x : X), not (epsilon x nil).
Proof. auto. Qed.

Global Instance list_monad_plus : Monad_Plus list := {
  monad_plus_monad := list_Monad;
  mzero := fun X => nil;
  mplus := app;
  mzero_bind := list_mzero_bind;
  mzero_mplus_left := app_nil_l;
  mzero_mplus_right := app_nil_r;
  mplus_assoc := app_assoc;
  mplus_bind := flat_map_app
}.
End lists.

(* Binary trees with values at the leaves, as a monad (bind substitutes a
   tree for each leaf) and as an [E_Monad] (membership = occurs at a leaf).
   The instances are declared [Global]. A plain [Instance] inside a
   section is local to it and would be lost at [End Trees].
   Induction proofs name the branch components explicitly ([l]/[r] for the
   subtrees, [IHl]/[IHr] for their induction hypotheses). The previous
   pattern [[| l r]] bound [r] to the LEFT subtree's hypothesis and left
   the right one auto-named [IHt1]. *)
Section Trees.


Inductive Tree (A: Type): Type :=
  | Leaf : A -> Tree A
  | Branch : Tree A -> Tree A -> Tree A.

(* The unit of the tree monad: a single leaf. *)
Definition tree_unit (A : Type) (a : A) : Tree A := Leaf a.

(* [tree_bind m f] replaces every leaf [x] of [m] by the tree [f x],
   keeping the branching structure of [m]. *)
Fixpoint tree_bind (A B: Type) (m : Tree A) (f:A->Tree B) : Tree B :=
  match m with
  | Leaf x => f x
  | Branch l r => Branch (tree_bind l f) (tree_bind r f)
  end.

(* Left-unit law: holds by computation. *)
Theorem tree_bind_unit : forall X Y (x : X) (f : X -> Tree Y),
   tree_bind (Leaf x) f = f x.
Proof. intros. reflexivity. Qed.

(* Right-unit law: substituting each leaf by itself is the identity. *)
Theorem tree_unit_bind : forall X (t : Tree X),
   tree_bind t (fun x => (Leaf x)) = t.
Proof. intros.  induction t as [a | l IHl r IHr].
  Case"Base Case".
    reflexivity.
  Case"Induction Step".
    simpl. rewrite -> IHr. rewrite -> IHl. reflexivity. Qed.

(* Associativity law, by induction on the tree. *)
Theorem tree_bind_assoc : 
   forall X Y Z
         (t : Tree X)
         (f : X -> Tree Y)
         (g : Y -> Tree Z),
    tree_bind (tree_bind t f) g = tree_bind t (fun i=> tree_bind (f i) g).
Proof. intros. induction t as [a | l IHl r IHr].
   Case"Base Case".
     auto.
   Case"Induction Step".
     simpl. rewrite <- IHr. rewrite -> IHl. reflexivity. Qed.

(* [in_tree a t]: [a] occurs at some leaf of [t]. The leaf case is stated
   as [x = a], with the leaf value on the left. *)
Fixpoint in_tree (A : Type)(a : A)(t : Tree A) : Prop :=
  match t with
  | Leaf x => (x=a)
  | Branch l r => or (in_tree a l) (in_tree a r)
  end.

(* The only member of [Leaf y] is [y]. *)
Theorem tree_epsilon_unit : forall X (x y : X),
  in_tree x (Leaf y) <-> x = y.
Proof. intros. split.
  Case"->".
    intro. simpl in H. rewrite -> H. reflexivity.
  Case"<-".
    intro. rewrite H. reflexivity. Qed.

(* Introduction rule for membership in a branch. *)
Lemma in_branch_or : forall (A : Type)(a : A)(l r : Tree A),
  ( (in_tree a l) \/ (in_tree a r)) -> in_tree a (Branch l r).
Proof. intros. simpl. apply H. Qed.

(* Membership in [tree_bind t f]: [y] occurs in it iff [y] occurs in
   [f x] for some leaf value [x] of [t]. Both directions are by induction
   on [t]. *)
Theorem tree_epsilon_bind : 
 forall X Y (t : Tree X)(f : X -> Tree Y) (y : Y),
     in_tree y (tree_bind t f) <-> 
     exists (x : X), (in_tree x t) /\ (in_tree y (f x)).
Proof. intros. split.
  Case "->".
    induction t as [a | l IHl r IHr].
    SCase"Base Case".
      intros. simpl in H. exists a. split. reflexivity. apply H.
    SCase"Induction Step".
      intros. simpl in H. destruct H. apply IHl in H.
      destruct H as [z z_prop]. exists z. split. simpl. 
      left. apply z_prop. apply z_prop.
      apply IHr in H. destruct H as [z z_prop]. exists z.
      split. simpl. right. destruct z_prop. apply H.
      destruct z_prop. apply H0.
  Case"<-".
    (* [H] mentions [t], so [induction] generalises it and each induction
       hypothesis takes the existential as a premise. *)
    intros. induction t as [a | l IHl r IHr].
    SCase"Base Case".
      simpl. destruct H as [x x_prop]. destruct x_prop. simpl in H.
      rewrite -> H. apply H0.
    SCase"Induction Step". 
      simpl. destruct H as [x x_prop]. destruct x_prop. simpl in H.
      destruct H. left. apply IHl. exists x. split. apply H. apply H0.
      right. apply IHr. exists x. split. apply H. apply H0. Qed.

Global Instance tree_Monad : Monad Tree :=
{
  bind := tree_bind;
  unit := tree_unit;
  bind_assoc := tree_bind_assoc;
  left_unit := tree_bind_unit;
  right_unit := tree_unit_bind
}.

Global Instance tree_Emonad : E_Monad Tree:={
  emonad_monad := tree_Monad;
  epsilon := in_tree;
  epsilon_bind := tree_epsilon_bind;
  epsilon_unit := tree_epsilon_unit
}.

End Trees.





