diff --git a/_CoqProject b/_CoqProject index 4a35dbab1..e05597f3c 100644 --- a/_CoqProject +++ b/_CoqProject @@ -136,3 +136,11 @@ theories/showcase/pnt.v analysis_stdlib/Rstruct_topology.v analysis_stdlib/showcase/uniform_bigO.v +theories/prob_lang.v +theories/prob_lang_wip.v +theories/lang_syntax_util.v +theories/lang_syntax_toy.v +theories/lang_syntax.v +theories/lang_syntax_examples.v +theories/lang_syntax_table_game.v +theories/lang_syntax_noisy.v diff --git a/coq-mathcomp-analysis.opam b/coq-mathcomp-analysis.opam index a0890eb50..0d54ca9fd 100644 --- a/coq-mathcomp-analysis.opam +++ b/coq-mathcomp-analysis.opam @@ -19,6 +19,7 @@ depends: [ "coq-mathcomp-solvable" "coq-mathcomp-field" "coq-mathcomp-bigenough" { (>= "1.0.0") } + "coq-mathcomp-algebra-tactics" { (>= "1.2.4") } ] tags: [ diff --git a/coq-mathcomp-classical.opam b/coq-mathcomp-classical.opam index 423b07edc..eebea7728 100644 --- a/coq-mathcomp-classical.opam +++ b/coq-mathcomp-classical.opam @@ -20,8 +20,8 @@ depends: [ "coq-mathcomp-ssreflect" { (>= "2.4.0" & < "2.5~") | (= "dev") } "coq-mathcomp-fingroup" "coq-mathcomp-algebra" - "coq-mathcomp-finmap" { (>= "2.1.0") } - "coq-hierarchy-builder" { (>= "1.8.0") } + "coq-mathcomp-finmap" { (>= "2.2.0") } + "coq-hierarchy-builder" { (>= "1.8.1") } ] tags: [ diff --git a/theories/Make b/theories/Make index 98c1b9889..2e259cc60 100644 --- a/theories/Make +++ b/theories/Make @@ -97,3 +97,11 @@ pi_irrational.v gauss_integral.v all_analysis.v showcase/summability.v +prob_lang.v +prob_lang_wip.v +lang_syntax_util.v +lang_syntax_toy.v +lang_syntax.v +lang_syntax_examples.v +lang_syntax_table_game.v +lang_syntax_noisy.v diff --git a/theories/lang_syntax.v b/theories/lang_syntax.v new file mode 100644 index 000000000..9e5736004 --- /dev/null +++ b/theories/lang_syntax.v @@ -0,0 +1,1482 @@ +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval. +From mathcomp Require Import interval_inference. +From mathcomp Require Import unstable mathcomp_extra boolp classical_sets. +From mathcomp Require Import functions cardinality fsbigop. +From mathcomp Require Import reals ereal topology normedtype sequences exp. +From mathcomp Require Import esum measure lebesgue_measure numfun derive realfun. +From mathcomp Require Import lebesgue_integral probability ftc kernel charge. +From mathcomp Require Import prob_lang lang_syntax_util. + +(**md**************************************************************************) +(* # Syntax and Evaluation for a Probabilistic Programming Language *) +(* *) +(* Reference: *) +(* - R. Saito, R. Affeldt. Experimenting with an Intrinsically-Typed *) +(* Probabilistic Programming Language in Coq using s-finite kernels in Coq. *) +(* APLAS 2023 *) +(* *) +(* ``` *) +(* typ == syntax for types of data structures *) +(* measurable_of_typ t == the measurable type corresponding to type t *) +(* It is of type {d & measurableType d} *) +(* mtyp_disp t == the display corresponding to type t *) +(* mtyp t == the measurable type corresponding to type t *) +(* It is of type measurableType (mtyp_disp t) *) +(* measurable_of_seq s == the product space corresponding to the *) +(* list s : seq typ *) +(* It is of type {d & measurableType d} *) +(* acc_typ s n == function that access the nth element of s : seq typ *) +(* It is a function from projT2 (measurable_of_seq s) *) +(* to projT2 (measurable_of_typ (nth Unit s n)) *) +(* ctx == type of context *) +(* := seq (string * type) *) +(* mctx_disp g == the display corresponding to the context g *) +(* mctx g := the measurable type corresponding to the context g *) +(* It is formed of nested pairings of measurable *) +(* spaces. It is of type measurableType (mctx_disp g) *) +(* flag == a flag is either D (deterministic) or *) +(* P (probabilistic) *) +(* exp f g t == syntax of expressions with flag f of type t *) +(* context g *) +(* dval R g t == "deterministic value", i.e., *) +(* function from mctx g to mtyp t *) +(* pval R g t == "probabilistic value", i.e., *) +(* s-finite kernel, from mctx g to mtyp t *) +(* e -D> f ; mf == the evaluation of the deterministic expression e *) +(* leads to the deterministic value f *) +(* (mf is the proof that f is measurable) *) +(* e -P> k == the evaluation of the probabilistic function f *) +(* leads to the probabilistic value k *) +(* execD e == a dependent pair of a function corresponding to the *) +(* evaluation of e and a proof that this function is *) +(* measurable *) +(* execP e == a s-finite kernel corresponding to the evaluation *) +(* of the probabilistic expression e *) +(* ``` *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Import Order.TTheory GRing.Theory Num.Def Num.Theory. +Import numFieldTopology.Exports. + +Reserved Notation "e -D> f ; mf" (at level 40). +Reserved Notation "e -P> k" (at level 40). + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. + +(* In this module, we use the lemma continuous_FTC2 to compute the value of + integration of the indicator function over the interval [0, 1]. + We can use the lemma continuous_FTC2 because it requires continuity within + [0, 1], which the indicator function satisfies. + We also show that the indicator function is not continuous in [0, 1]. + This shows that the lemma continuous_FTC2 is * enough weak to be usable + in practice. *) +Module integral_indicator_function. +Section integral_indicator_function. + +Context {R : realType}. +Notation mu := lebesgue_measure. +Local Open Scope ereal_scope. +Implicit Types (f : R -> R) (a b : R). + +Local Import set_interval. + +Let uni := @indic R R `[0%R, 1%R]%classic. + +Let integrable_uni : mu.-integrable setT (EFin \o uni). +Proof. +apply/integrableP; split. + apply: measurableT_comp => //. + exact: measurable_indic. +apply/abse_integralP => //. + apply: measurableT_comp => //. + exact: measurable_indic. +rewrite -ge0_fin_numE; last exact: abse_ge0. +rewrite abse_fin_num integral_indic// setIT. +by rewrite /= lebesgue_measure_itv ifT. +Qed. + +Let cuni_within : {within `[0%R, 1%R], continuous uni}. +Proof. +apply/continuous_within_itvP => //; split. +- move=> x x01. + apply: (@near_cst_continuous R R 1%R). + near=> z. + rewrite /uni indic_restrict patchE ifT//. + rewrite inE/=. + apply: subset_itv_oo_cc. + near: z. + exact: near_in_itvoo. +- rewrite (_: uni 0 = 1%R); last first. + rewrite /uni indic_restrict patchE ifT//. + by rewrite inE/= boundl_in_itv bnd_simp/=. + apply: cvg_near_cst. + near=> z. + rewrite /uni indic_restrict patchE ifT// inE/= in_itv/=; apply/andP; split => //. + near: z. + exact: nbhs_right_le. +- rewrite (_:uni 1 = 1%R); last first. + rewrite /uni indic_restrict patchE ifT//. + by rewrite inE/= boundr_in_itv bnd_simp/=. + apply: cvg_near_cst. + near=> z. + rewrite /uni indic_restrict patchE ifT// inE/= in_itv/=; apply/andP; split => //. + near: z. + exact: nbhs_left_ge. +Unshelve. all: end_near. Qed. + +Example cuni : ~ {in `[0%R, 1%R], continuous uni}. +Proof. +rewrite -existsNE/=. +exists 0%R. +rewrite not_implyE; split; first by rewrite boundl_in_itv/= bnd_simp. +move/left_right_continuousP. +apply/not_andP; left. +move/(@cvgrPdist_le _ R^o). +apply/existsNP. +exists 2%:R^-1%R. +rewrite not_implyE; split; first by rewrite invr_gt0. +move=> [e /= e0]. +move/(_ (-(e / 2))%R). +apply/not_implyP; split. + rewrite /= sub0r opprK ger0_norm; last by rewrite divr_ge0// ltW. + rewrite -{1}(add0r e). + exact: (midf_lt e0).2. +apply/not_implyP; split. + rewrite oppr_lt0. + exact: divr_gt0. +apply/negP; rewrite -ltNge. +rewrite /uni !indic_restrict !patchE. +rewrite ifT; last by rewrite inE/= boundl_in_itv/= bnd_simp. +rewrite ifF; last first. + apply: negbTE; apply/negP. + rewrite inE/= in_itv/=. + apply/negP; rewrite negb_and; apply/orP; left. + by rewrite -ltNge oppr_lt0 divr_gt0. +rewrite /point/= {2}/1%R/= subr0. +rewrite ger0_norm//. +rewrite invf_lt1//. +rewrite {1}(_:1%R = 1%:R)//; apply: ltr_nat. +Qed. + +Let dintuni : derivable_oo_continuous_bnd (@id R^o) 0 1. +Proof. +split. +- move=> x _. + exact: derivable_id. +- exact: cvg_at_right_filter. +- exact: cvg_at_left_filter. +Qed. + +Let intuni'uni : {in `]0%R, 1%R[, (@id R^o)^`() =1 uni}. +Proof. +move=> x x01. +rewrite derive1E derive_id. +rewrite /uni indic_restrict patchE ifT// inE/=. +exact: subset_itv_oo_cc. +Qed. + +Lemma intuni1 : (\int[mu]_(x in `[0, 1]) uni x)%R = 1%R. +Proof. +rewrite [RHS](_:1%R = fine (1%:E))//; congr (fine _). +rewrite (continuous_FTC2 ltr01 cuni_within dintuni intuni'uni). +by rewrite sube0. +Qed. + +End integral_indicator_function. +End integral_indicator_function. + +Declare Scope lang_scope. +Delimit Scope lang_scope with P. + +Section syntax_of_types. +Import Notations. +Context {R : realType}. + +Inductive typ := +| Unit | Bool | Nat | Real +| Pair : typ -> typ -> typ +| Prob : typ -> typ. + +HB.instance Definition _ := gen_eqMixin typ. + +Fixpoint measurable_of_typ (t : typ) : {d & measurableType d} := + match t with + | Unit => existT _ _ munit + | Bool => existT _ _ mbool + | Nat => existT _ _ (nat : measurableType _) + | Real => existT _ _ + [the measurableType _ of (@measurableTypeR R)] + (* (Real_sort__canonical__measure_Measurable R) *) + | Pair A B => existT _ _ + [the measurableType (projT1 (measurable_of_typ A), + projT1 (measurable_of_typ B)).-prod%mdisp of + (projT2 (measurable_of_typ A) * + projT2 (measurable_of_typ B))%type] + | Prob A => existT _ _ (pprobability (projT2 (measurable_of_typ A)) R) + end. + +Definition mtyp_disp t : measure_display := projT1 (measurable_of_typ t). + +Definition mtyp t : measurableType (mtyp_disp t) := + projT2 (measurable_of_typ t). + +Definition measurable_of_seq (l : seq typ) : {d & measurableType d} := + iter_mprod (List.map measurable_of_typ l). + +End syntax_of_types. +Arguments measurable_of_typ {R}. +Arguments mtyp {R}. +Arguments measurable_of_seq {R}. + +Section accessor_functions. +Context {R : realType}. + +(* NB: almost the same as acc (map (@measurable_of_typ R) s) n l, + modulo commutativity of map and measurable_of_typ *) +Fixpoint acc_typ (s : seq typ) n : + projT2 (@measurable_of_seq R s) -> + projT2 (measurable_of_typ (nth Unit s n)) := + match s return + projT2 (measurable_of_seq s) -> projT2 (measurable_of_typ (nth Unit s n)) + with + | [::] => match n with | 0 => (fun=> tt) | m.+1 => (fun=> tt) end + | a :: l => match n with + | 0 => fst + | m.+1 => fun H => @acc_typ l m H.2 + end + end. + +(*Definition acc_typ : forall (s : seq typ) n, + projT2 (@measurable_of_seq R s) -> + projT2 (@measurable_of_typ R (nth Unit s n)). +fix H 1. +intros s n x. +destruct s as [|s]. + destruct n as [|n]. + exact tt. + exact tt. +destruct n as [|n]. + exact (fst x). +rewrite /=. +apply H. +exact: (snd x). +Show Proof. +Defined.*) + +Lemma measurable_acc_typ (s : seq typ) n : measurable_fun setT (@acc_typ s n). +Proof. +elim: s n => //= h t ih [|m]; first exact: measurable_fst. +by apply: (measurableT_comp (ih _)); exact: measurable_snd. +Qed. + +End accessor_functions. +Arguments acc_typ {R} s n. +Arguments measurable_acc_typ {R} s n. + +Section context. +Variables (R : realType). +Definition ctx := seq (string * typ). + +Definition mctx_disp (g : ctx) := projT1 (@measurable_of_seq R (map snd g)). + +Definition mctx (g : ctx) : measurableType (mctx_disp g) := + projT2 (@measurable_of_seq R (map snd g)). + +End context. +Arguments mctx {R}. + +Section syntax_of_expressions. +Context {R : realType}. + +Inductive flag := D | P. + +(* +Section uniop. + +Inductive uniop := +| uniop_not +| uniop_neg | uniop_inv. + +Definition type_of_uniop (u : uniop) : typ := +match u with +| uniop_not => Bool +| uniop_neg => Real +| uniop_inv => Real +end. + +Definition fun_of_uniop g (u : uniop) : (mctx g -> mtyp (type_of_uniop u)) -> + @mctx R g -> @mtyp R (type_of_uniop u) := +match u with +| uniop_not => (fun f x => f x && f x : mtyp Bool) +| uniop_neg => (fun f => (\- f)%R) +| uniop_inv => (fun f => (f ^-1)%R) +end. + +Definition mfun_of_uniop g b + (f : @mctx R g -> @mtyp R (type_of_uniop b)) (mf : measurable_fun setT f) + measurable_fun [set: @mctx R g] (fun_of_uniop f). +destruct b. +exact: measurable_and mf1 mf2. +exact: measurable_or mf1 mf2. +exact: measurable_funD. +exact: measurable_funB. +exact: measurable_funM. +Defined. + +End uniop. +*) + +Section binop. + +Inductive binop := +| binop_and | binop_or +| binop_add | binop_minus | binop_mult. + +Definition type_of_binop (b : binop) : typ := +match b with +| binop_and => Bool +| binop_or => Bool +| binop_add => Real +| binop_minus => Real +| binop_mult => Real +end. + +Definition fun_of_binop g (b : binop) : (mctx g -> mtyp (type_of_binop b)) -> + (mctx g -> mtyp (type_of_binop b)) -> @mctx R g -> @mtyp R (type_of_binop b) := +match b with +| binop_and => (fun f1 f2 x => f1 x && f2 x : mtyp Bool) +| binop_or => (fun f1 f2 x => f1 x || f2 x : mtyp Bool) +| binop_add => (fun f1 f2 => (f1 \+ f2)%R) +| binop_minus => (fun f1 f2 => (f1 \- f2)%R) +| binop_mult => (fun f1 f2 => (f1 \* f2)%R) +end. + +Definition mfun_of_binop g b + (f1 : @mctx R g -> @mtyp R (type_of_binop b)) (mf1 : measurable_fun setT f1) + (f2 : @mctx R g -> @mtyp R (type_of_binop b)) (mf2 : measurable_fun setT f2) : + measurable_fun [set: @mctx R g] (fun_of_binop f1 f2). +destruct b. +exact: measurable_and mf1 mf2. +exact: measurable_or mf1 mf2. +exact: measurable_funD. +exact: measurable_funB. +exact: measurable_funM. +Defined. + +End binop. + +(* TODO: rename, generalize? *) +Section relop. +Inductive relop := +| relop_le | relop_lt | relop_eq . + +Definition fun_of_relop g (r : relop) : (@mctx R g -> @mtyp R Nat) -> + (mctx g -> mtyp Nat) -> @mctx R g -> @mtyp R Bool := +match r with +| relop_le => (fun f1 f2 x => (f1 x <= f2 x)%N) +| relop_lt => (fun f1 f2 x => (f1 x < f2 x)%N) +| relop_eq => (fun f1 f2 x => (f1 x == f2 x)%N) +end. + +Definition mfun_of_relop g r + (f1 : @mctx R g -> @mtyp R Nat) (mf1 : measurable_fun setT f1) + (f2 : @mctx R g -> @mtyp R Nat) (mf2 : measurable_fun setT f2) : + measurable_fun [set: @mctx R g] (fun_of_relop r f1 f2). +destruct r. +exact: measurable_fun_leq. +exact: measurable_fun_ltn. +exact: measurable_fun_eqn. +Defined. + +End relop. + +Section relop_Real. +Inductive relop_real := +| relop_real_le | relop_real_lt | relop_real_eq . + +Definition fun_of_relop_real g (r : relop_real) : (@mctx R g -> @mtyp R Real) -> + (mctx g -> mtyp Real) -> @mctx R g -> @mtyp R Bool := +match r with +| relop_real_le => (fun f1 f2 x => (f1 x <= f2 x)%R) +| relop_real_lt => (fun f1 f2 x => (f1 x < f2 x)%R) +| relop_real_eq => (fun f1 f2 x => (f1 x == f2 x)%R) +end. + +Definition mfun_of_relop_real g r + (f1 : @mctx R g -> @mtyp R Real) (mf1 : measurable_fun setT f1) + (f2 : @mctx R g -> @mtyp R Real) (mf2 : measurable_fun setT f2) : + measurable_fun [set: @mctx R g] (fun_of_relop_real r f1 f2). +destruct r. +exact: measurable_fun_ler. +exact: measurable_fun_ltr. +exact: measurable_fun_eqr. +Defined. + +End relop_Real. + +Inductive exp : flag -> ctx -> typ -> Type := +| exp_unit g : exp D g Unit +| exp_bool g : bool -> exp D g Bool +| exp_nat g : nat -> exp D g Nat +| exp_real g : R -> exp D g Real +| exp_pow g : exp D g Real -> nat -> exp D g Real +| exp_pow_real g : R (* base *) -> exp D g Real -> exp D g Real +| exp_bin (b : binop) g : exp D g (type_of_binop b) -> + exp D g (type_of_binop b) -> exp D g (type_of_binop b) +| exp_rel (r : relop) g : exp D g Nat -> + exp D g Nat -> exp D g Bool +| exp_rel_real (r : relop_real) g : exp D g Real -> + exp D g Real -> exp D g Bool +| exp_pair g t1 t2 : exp D g t1 -> exp D g t2 -> exp D g (Pair t1 t2) +| exp_proj1 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t1 +| exp_proj2 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t2 +| exp_var g str t : t = lookup Unit g str -> exp D g t +| exp_bernoulli g : exp D g Real -> exp D g (Prob Bool) +| exp_binomial g (n : nat) : exp D g Real -> exp D g (Prob Nat) +| exp_uniform g (a b : R) (ab : (a < b)%R) : exp D g (Prob Real) +| exp_beta g (a b : nat) : exp D g (Prob Real) +| exp_poisson g : nat -> exp D g Real -> exp D g Real +| exp_normal g : exp D g Real -> forall (s : R), (s != 0)%R -> exp D g (Prob Real) +| exp_normalize g t : exp P g t -> exp D g (Prob t) +| exp_letin g t1 t2 str : exp P g t1 -> exp P ((str, t1) :: g) t2 -> + exp P g t2 +| exp_sample g t : exp D g (Prob t) -> exp P g t +| exp_score g : exp D g Real -> exp P g Unit +| exp_return g t : exp D g t -> exp P g t +| exp_if z g t : exp D g Bool -> exp z g t -> exp z g t -> exp z g t +| exp_weak z g h t x : exp z (g ++ h) t -> + x.1 \notin dom (g ++ h) -> exp z (g ++ x :: h) t. +Arguments exp_var {g} _ {t}. + +Definition exp_var' (str : string) (t : typ) (g : find str t) := + @exp_var (untag (ctx_of g)) str t (ctx_prf g). +Arguments exp_var' str {t} g. + +Lemma exp_var'E str t (f : find str t) H : + exp_var' str f = exp_var str H :> (@exp _ _ _). +Proof. by rewrite /exp_var'; congr exp_var. Qed. + +End syntax_of_expressions. +Arguments exp {R}. +Arguments exp_unit {R g}. +Arguments exp_bool {R g}. +Arguments exp_nat {R g}. +Arguments exp_real {R g}. +Arguments exp_pow {R g} &. +Arguments exp_pow_real {R g} &. +Arguments exp_bin {R} b {g} &. +Arguments exp_rel {R} r {g} &. +Arguments exp_rel_real {R} r {g} &. +Arguments exp_pair {R g} & {t1 t2}. +Arguments exp_var {R g} _ {t} & H. +Arguments exp_bernoulli {R g} &. +Arguments exp_binomial {R g} &. +Arguments exp_uniform {R g} &. +Arguments exp_beta {R g} &. +Arguments exp_poisson {R g}. +Arguments exp_normal {R g} &. +Arguments exp_normalize {R g _}. +Arguments exp_letin {R g} & {_ _}. +Arguments exp_sample {R g} & {t}. +Arguments exp_score {R g} &. +Arguments exp_return {R g} & {_}. +Arguments exp_if {R z g t} &. +Arguments exp_weak {R} z g h {t} x. +Arguments exp_var' {R} str {t} g &. + +Declare Custom Entry expr. +Notation "[ e ]" := e (e custom expr at level 5) : lang_scope. +Notation "'TT'" := (exp_unit) (in custom expr at level 1) : lang_scope. +Notation "b ':B'" := (@exp_bool _ _ b%bool) + (in custom expr at level 1) : lang_scope. +Notation "n ':N'" := (@exp_nat _ _ n%N) + (in custom expr at level 1) : lang_scope. +Notation "r ':R'" := (@exp_real _ _ r%R) + (in custom expr at level 1, format "r :R") : lang_scope. +Notation "e ^+ n" := (exp_pow e n) + (in custom expr at level 1) : lang_scope. +Notation "e `^ r" := (exp_pow_real e r) + (in custom expr at level 1) : lang_scope. +Notation "e1 && e2" := (exp_bin binop_and e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "e1 || e2" := (exp_bin binop_or e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "e1 + e2" := (exp_bin binop_add e1 e2) + (in custom expr at level 3) : lang_scope. +Notation "e1 - e2" := (exp_bin binop_minus e1 e2) + (in custom expr at level 3) : lang_scope. +Notation "e1 * e2" := (exp_bin binop_mult e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "e1 <= e2" := (exp_rel relop_le e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "e1 == e2" := (exp_rel relop_eq e1 e2) + (in custom expr at level 4) : lang_scope. +Notation "e1 <=R e2" := (exp_rel_real relop_real_le e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "e1 ==R e2" := (exp_rel_real relop_real_eq e1 e2) + (in custom expr at level 4) : lang_scope. +Notation "'return' e" := (@exp_return _ _ _ e) + (in custom expr at level 7) : lang_scope. +(*Notation "% str" := (@exp_var _ _ str%string _ erefl) + (in custom expr at level 1, format "% str") : lang_scope.*) +(* Notation "% str H" := (@exp_var _ _ str%string _ H) + (in custom expr at level 1, format "% str H") : lang_scope. *) +Notation "# str" := (@exp_var' _ str%string _ _) + (in custom expr at level 1, format "# str"). +Notation "e :+ str" := (exp_weak _ [::] _ (str, _) e erefl) + (in custom expr at level 1) : lang_scope. +Notation "( e1 , e2 )" := (exp_pair e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "\pi_1 e" := (exp_proj1 e) + (in custom expr at level 1) : lang_scope. +Notation "\pi_2 e" := (exp_proj2 e) + (in custom expr at level 1) : lang_scope. +Notation "'let' x ':=' e 'in' f" := (exp_letin x e f) + (in custom expr at level 6, + x constr, + f custom expr at level 6, + left associativity) : lang_scope. +Notation "{ c }" := c (in custom expr, c constr) : lang_scope. +Notation "x" := x + (in custom expr at level 0, x ident) : lang_scope. +Notation "'Sample' e" := (exp_sample e) + (in custom expr at level 6) : lang_scope. +Notation "'Score' e" := (exp_score e) + (in custom expr at level 6) : lang_scope. +Notation "'Normalize' e" := (exp_normalize e) + (in custom expr at level 0) : lang_scope. +Notation "'Bernoulli' p" := (exp_bernoulli p) + (in custom expr at level 6) : lang_scope. +Notation "'Binomial' n k" := (exp_binomial n k) + (in custom expr at level 6) : lang_scope. +Notation "'Uniform' a b ab" := (exp_uniform a b ab) + (in custom expr at level 6) : lang_scope. +Notation "'Beta' a b" := (exp_beta a b) + (in custom expr at level 6) : lang_scope. +Notation "'Normal' m s s0" := (exp_normal m s s0) + (in custom expr at level 6) : lang_scope. +Notation "'if' e1 'then' e2 'else' e3" := (exp_if e1 e2 e3) + (in custom expr at level 7) : lang_scope. +Notation "( e )" := e + (in custom expr at level 1) : lang_scope. + +Section free_vars. +Context {R : realType}. + +Fixpoint free_vars k g t (e : @exp R k g t) : seq string := + match e with + | exp_unit _ => [::] + | exp_bool _ _ => [::] + | exp_nat _ _ => [::] + | exp_real _ _ => [::] + | exp_pow _ e _ => free_vars e + | exp_pow_real _ _ e => free_vars e + | exp_bin _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_rel _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_rel_real _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_pair _ _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_proj1 _ _ _ e => free_vars e + | exp_proj2 _ _ _ e => free_vars e + | exp_var _ x _ _ => [:: x] + | exp_bernoulli _ e => free_vars e + | exp_binomial _ _ e => free_vars e + | exp_uniform _ _ _ _ => [::] + | exp_beta _ _ _ => [::] + | exp_poisson _ _ e => free_vars e + | exp_normal _ e _ _ => free_vars e + | exp_normalize _ _ e => free_vars e + | exp_letin _ _ _ x e1 e2 => free_vars e1 ++ rem x (free_vars e2) + | exp_sample _ _ _ => [::] + | exp_score _ e => free_vars e + | exp_return _ _ e => free_vars e + | exp_if _ _ _ e1 e2 e3 => free_vars e1 ++ free_vars e2 ++ free_vars e3 + | exp_weak _ _ _ _ x e _ => rem x.1 (free_vars e) + end. + +End free_vars. + +Definition dval R g t := @mctx R g -> @mtyp R t. +Definition pval R g t := R.-sfker @mctx R g ~> @mtyp R t. + +Section weak. +Context {R : realType}. +Implicit Types (g h : ctx) (x : string * typ). + +Fixpoint mctx_strong g h x (f : @mctx R (g ++ x :: h)) : @mctx R (g ++ h) := + match g as g0 return mctx (g0 ++ x :: h) -> mctx (g0 ++ h) with + | [::] => fun f0 : mctx ([::] ++ x :: h) => let (a, b) := f0 in (fun=> id) a b + | a :: t => uncurry (fun a b => (a, @mctx_strong t h x b)) + end f. + +Definition weak g h x t (f : dval R (g ++ h) t) : dval R (g ++ x :: h) t := + f \o @mctx_strong g h x. + +Lemma measurable_fun_mctx_strong g h x : + measurable_fun setT (@mctx_strong g h x). +Proof. +elim: g h x => [h x|x g ih h x0]; first exact: measurable_snd. +apply/measurable_fun_pairP; split. +- rewrite [X in measurable_fun _ X](_ : _ = fst)//. + by apply/funext => -[]. +- rewrite [X in measurable_fun _ X](_ : _ = @mctx_strong g h x0 \o snd). + apply: measurableT_comp; last exact: measurable_snd. + exact: ih. + by apply/funext => -[]. +Qed. + +Lemma measurable_weak g h x t (f : dval R (g ++ h) t) : + measurable_fun setT f -> measurable_fun setT (@weak g h x t f). +Proof. +move=> mf; apply: measurableT_comp; first exact: mf. +exact: measurable_fun_mctx_strong. +Qed. + +Definition kweak g h x t (f : pval R (g ++ h) t) + : @mctx R (g ++ x :: h) -> {measure set @mtyp R t -> \bar R} := + f \o @mctx_strong g h x. + +Section kernel_weak. +Context g h x t (f : pval R (g ++ h) t). + +Let mf U : measurable U -> measurable_fun setT (@kweak g h x t f ^~ U). +Proof. +move=> mU. +rewrite (_ : kweak _ ^~ U = f ^~ U \o @mctx_strong g h x)//. +apply: measurableT_comp => //; first exact: measurable_kernel. +exact: measurable_fun_mctx_strong. +Qed. + +HB.instance Definition _ := isKernel.Build _ _ _ _ _ (@kweak g h x t f) mf. +End kernel_weak. + +Section sfkernel_weak. +Context g h (x : string * typ) t (f : pval R (g ++ h) t). + +Let sf : exists2 s : (R.-ker @mctx R (g ++ x :: h) ~> @mtyp R t)^nat, + forall n, measure_fam_uub (s n) & + forall z U, measurable U -> (@kweak g h x t f) z U = kseries s z U . +Proof. +have [s hs] := sfinite_kernel f. +exists (fun n => @kweak g h x t (s n)). + by move=> n; have [M hM] := measure_uub (s n); exists M => x0; exact: hM. +by move=> z U mU; by rewrite /kweak/= hs. +Qed. + +HB.instance Definition _ := + isSFiniteKernel_subdef.Build _ _ _ _ _ (@kweak g h x t f) sf. + +End sfkernel_weak. + +Section fkernel_weak. +Context g h x t (f : R.-fker @mctx R (g ++ h) ~> @mtyp R t). + +Let uub : measure_fam_uub (@kweak g h x t f). +Proof. by have [M hM] := measure_uub f; exists M => x0; exact: hM. Qed. + +HB.instance Definition _ := @Kernel_isFinite.Build _ _ _ _ _ + (@kweak g h x t f) uub. +End fkernel_weak. + +End weak. +Arguments weak {R} g h x {t}. +Arguments measurable_weak {R} g h x {t}. +Arguments kweak {R} g h x {t}. + +Section eval. +Context {R : realType}. +Implicit Type (g : ctx) (str : string). +Local Open Scope lang_scope. + +Inductive evalD : forall g t, exp D g t -> + forall f : dval R g t, measurable_fun setT f -> Prop := +| eval_unit g : ([TT] : exp D g _) -D> cst tt ; ktt + +| eval_bool g b : ([b:B] : exp D g _) -D> cst b ; kb b + +| eval_nat g n : ([n:N] : exp D g _) -D> cst n; kn n + +| eval_real g r : ([r:R] : exp D g _) -D> cst r ; kr r + +| eval_pow g n (e : exp D g _) f mf : e -D> f ; mf -> + [e ^+ {n}] -D> (fun x => (f x ^+ n)%R) ; (measurable_funX n mf) + +| eval_pow_real g (e : exp D g _) r f mf : e -D> f ; mf -> + [{r} `^ e] -D> (fun x => (r `^ (f x))%R) ; measurableT_comp (measurable_powRr r) mf + +| eval_bin g bop (e1 : exp D g _) f1 mf1 e2 f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + exp_bin bop e1 e2 -D> fun_of_binop f1 f2 ; mfun_of_binop mf1 mf2 + +| eval_rel g rop (e1 : exp D g _) f1 mf1 e2 f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + exp_rel rop e1 e2 -D> fun_of_relop rop f1 f2 ; mfun_of_relop rop mf1 mf2 + +| eval_rel_real g rop (e1 : exp D g _) f1 mf1 e2 f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + exp_rel_real rop e1 e2 -D> fun_of_relop_real rop f1 f2 ; mfun_of_relop_real rop mf1 mf2 + +| eval_pair g t1 (e1 : exp D g t1) f1 mf1 t2 (e2 : exp D g t2) f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + [(e1, e2)] -D> fun x => (f1 x, f2 x) ; measurable_fun_pair mf1 mf2 + +| eval_proj1 g t1 t2 (e : exp D g (Pair t1 t2)) f mf : + e -D> f ; mf -> + [\pi_1 e] -D> fst \o f ; measurableT_comp measurable_fst mf + +| eval_proj2 g t1 t2 (e : exp D g (Pair t1 t2)) f mf : + e -D> f ; mf -> + [\pi_2 e] -D> snd \o f ; measurableT_comp measurable_snd mf + +(* | eval_var g str : let i := index str (dom g) in + [% str] -D> acc_typ (map snd g) i ; measurable_acc_typ (map snd g) i *) + +| eval_var g x H : let i := index x (dom g) in + exp_var x H -D> acc_typ (map snd g) i ; measurable_acc_typ (map snd g) i + +| eval_bernoulli g e r mr : + e -D> r ; mr -> ([Bernoulli e] : exp _ g _) -D> bernoulli_prob \o r ; + measurableT_comp measurable_bernoulli_prob mr + +| eval_binomial g n e r mr : + e -D> r ; mr -> ([Binomial n e] : exp _ g _) -D> binomial_prob n \o r ; + measurableT_comp (measurable_binomial_prob n) mr + +| eval_uniform g (a b : R) (ab : (a < b)%R) : + ([Uniform a b ab] : exp D g _) -D> cst (uniform_prob ab) ; + measurable_cst _ + +| eval_beta g (a b : nat) : + ([Beta a b] : exp D g _) -D> cst (beta_prob a b) ; measurable_cst _ + +| eval_poisson g n (e : exp D g _) f mf : + e -D> f ; mf -> + exp_poisson n e -D> poisson_pmf ^~ n \o f ; + measurableT_comp (measurable_poisson_pmf n measurableT) mf + +| eval_normal g s (s0 : (s != 0)%R) (e : exp D g _) r mr : + e -D> r ; mr -> + ([Normal e s s0] : exp D g _) -D> (fun x => @normal_prob _ (r x) s) ; + measurableT_comp (measurable_normal_prob2 s0) mr + +| eval_normalize g t (e : exp P g t) k : + e -P> k -> + [Normalize e] -D> normalize_pt k ; measurable_normalize_pt k + +| evalD_if g t e f mf (e1 : exp D g t) f1 mf1 e2 f2 mf2 : + e -D> f ; mf -> e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + [if e then e1 else e2] -D> fun x => if f x then f1 x else f2 x ; + measurable_fun_ifT mf mf1 mf2 + +| evalD_weak g h t e x (H : x.1 \notin dom (g ++ h)) f mf : + e -D> f ; mf -> + (exp_weak _ g h x e H : exp _ _ t) -D> weak g h x f ; + measurable_weak g h x f mf + +where "e -D> v ; mv" := (@evalD _ _ e v mv) + +with evalP : forall g t, exp P g t -> pval R g t -> Prop := + +| eval_letin g t1 t2 str (e1 : exp _ g t1) (e2 : exp _ _ t2) k1 k2 : + e1 -P> k1 -> e2 -P> k2 -> + [let str := e1 in e2] -P> letin' k1 k2 + +| eval_sample g t (e : exp _ _ (Prob t)) + (p : mctx g -> pprobability (mtyp t) R) mp : + e -D> p ; mp -> [Sample e] -P> sample p mp + +| eval_score g (e : exp _ g _) f mf : + e -D> f ; mf -> [Score e] -P> kscore mf + +| eval_return g t (e : exp D g t) f mf : + e -D> f ; mf -> [return e] -P> ret mf + +| evalP_if g t e f mf (e1 : exp P g t) k1 e2 k2 : + e -D> f ; mf -> e1 -P> k1 -> e2 -P> k2 -> + [if e then e1 else e2] -P> ite mf k1 k2 + +| evalP_weak g h t (e : exp P (g ++ h) t) x + (H : x.1 \notin dom (g ++ h)) f : + e -P> f -> + exp_weak _ g h x e H -P> kweak g h x f + +where "e -P> v" := (@evalP _ _ e v). + +End eval. + +Notation "e -D> v ; mv" := (@evalD _ _ _ e v mv) : lang_scope. +Notation "e -P> v" := (@evalP _ _ _ e v) : lang_scope. + +Scheme evalD_mut_ind := Induction for evalD Sort Prop +with evalP_mut_ind := Induction for evalP Sort Prop. + +(* properties of the evaluation relation *) +Section eval_prop. +Variables (R : realType). +Local Open Scope lang_scope. + +Lemma evalD_uniq g t (e : exp D g t) (u v : dval R g t) mu mv : + e -D> u ; mu -> e -D> v ; mv -> u = v. +Proof. +move=> hu. +apply: (@evalD_mut_ind R + (fun g t (e : exp D g t) f mf (h1 : e -D> f; mf) => + forall v mv, e -D> v; mv -> f = v) + (fun g t (e : exp P g t) u (h1 : e -P> u) => + forall v, e -P> v -> u = v)); last exact: hu. +all: (rewrite {g t e u v mu mv hu}). +- move=> g {}v {}mv. + inversion 1; subst g0. + by inj_ex H3. +- move=> g b {}v {}mv. + inversion 1; subst g0 b0. + by inj_ex H3. +- move=> g n {}v {}mv. + inversion 1; subst g0 n0. + by inj_ex H3. +- move=> g r {}v {}mv. + inversion 1; subst g0 r0. + by inj_ex H3. +- move=> g n e f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H4; subst v. + inj_ex H0; subst e0. + by move: H3 => /IH <-. +- move=> g e r f mf ev IH {}v {}mv. + inversion 1; subst g0 r0. + inj_ex H4; subst v. + inj_ex H2; subst e0. + by move: H3 => /IH <-. +- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 bop0. + inj_ex H10; subst v. + inj_ex H5; subst e1. + inj_ex H6; subst e5. + by move: H4 H11 => /IH1 <- /IH2 <-. +- move=> g rop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 rop0. + inj_ex H5; subst v. + inj_ex H1; subst e1. + inj_ex H3; subst e3. + by move: H6 H7 => /IH1 <- /IH2 <-. +- move=> g rop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 rop0. + inj_ex H5; subst v. + inj_ex H1; subst e1. + inj_ex H3; subst e3. + by move: H6 H7 => /IH1 <- /IH2 <-. +- move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + simple inversion 1 => //; subst g0. + case: H3 => ? ?; subst t0 t3. + inj_ex H4; case: H4 => He1 He2. + inj_ex He1; subst e0. + inj_ex He2; subst e3. + inj_ex H5; subst v. + by move=> /IH1 <- /IH2 <-. +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g str H n {}v {}mv. + inversion 1; subst g0. + inj_ex H9; rewrite -H9. + by inj_ex H10. +- move=> g e r mr ev IH {}v {}mv. + inversion 1; subst g0. + inj_ex H0; subst e0. + inj_ex H3; subst v. + by rewrite (IH _ _ H4). +- move=> g n e f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + by rewrite (IH _ _ H5). +- move=> g a b ab {}v {}mv. + inversion 1; subst g0 a0 b0. + inj_ex H4; subst v. + by have -> : ab = ab1. +- (* TODO: beta *) move=> g a b {}v {}mv. + inversion 1; subst g0 a0 b0. + by inj_ex H4; subst v. +- move=> g t e k mk ev IH {}v {}mv. + inversion 1; subst g0 t. + inj_ex H2; subst e0. + inj_ex H4; subst v. + by rewrite (IH _ _ H3). +- move=> g s s0 e r mr ev IH {}v {}mv. + inversion 1; subst g0 s1. + inj_ex H0; subst e0. + inj_ex H3; subst v. + by rewrite (IH _ _ H5). +- move=> g t e k ev IH f mf. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H4; subst f. + inj_ex H5; subst mf. + by rewrite (IH _ H3). +- move=> g t e f mf e1 f1 mf1 e2 f2 mf2 ev ih ev1 ih1 ev2 ih2 v m. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H6; subst e5. + inj_ex H7; subst e6. + inj_ex H9; subst v. + clear H11. + have ? := ih1 _ _ H12; subst f6. + have ? := ih2 _ _ H13; subst f7. + by rewrite (ih _ _ H5). +- move=> g h t e x H f mf ef ih {}v {}mv. + inversion 1; subst t0 g0 h0 x0. + inj_ex H12; subst e1. + inj_ex H14; subst v. + clear H16. + by rewrite (ih _ _ H5). +- move=> g t1 t2 x e1 e2 k1 k2 ev1 IH1 ev2 IH2 k. + inversion 1; subst g0 t0 t3 x. + inj_ex H7; subst k. + inj_ex H6; subst e5. + inj_ex H5; subst e4. + by rewrite (IH1 _ H4) (IH2 _ H8). +- move=> g t e p mp ev IH k. + inversion 1; subst g0. + inj_ex H5; subst t0. + inj_ex H5; subst e1. + inj_ex H7; subst k. + have ? := IH _ _ H3; subst p1. + by have -> : mp = mp1 by []. +- move=> g e f mf ev IH k. + inversion 1; subst g0. + inj_ex H0; subst e0. + inj_ex H4; subst k. + have ? := IH _ _ H2; subst f1. + by have -> : mf = mf0 by []. +- move=> g t e0 f mf ev IH k. + inversion 1; subst g0 t0. + inj_ex H5; subst e1. + inj_ex H7; subst k. + have ? := IH _ _ H3; subst f1. + by have -> : mf = mf1 by []. +- move=> g t e f mf e1 k1 e2 k2 ev ih ev1 ih1 ev2 ih2 k. + inversion 1; subst g0 t0. + inj_ex H0; subst e0. + inj_ex H1; subst e3. + inj_ex H5; subst k. + inj_ex H2; subst e4. + have ? := ih _ _ H6; subst f1. + have -> : mf = mf0 by []. + by rewrite (ih1 _ H7) (ih2 _ H8). +- move=> g h t e x xgh k ek ih. + inversion 1; subst x0 g0 h0 t0. + inj_ex H13; rewrite -H13. + inj_ex H11; subst e1. + by rewrite (ih _ H4). +Qed. + +Lemma evalP_uniq g t (e : exp P g t) (u v : pval R g t) : + e -P> u -> e -P> v -> u = v. +Proof. +move=> eu. +apply: (@evalP_mut_ind R + (fun g t (e : exp D g t) f mf (h : e -D> f; mf) => + forall v mv, e -D> v; mv -> f = v) + (fun g t (e : exp P g t) u (h : e -P> u) => + forall v, e -P> v -> u = v)); last exact: eu. +all: rewrite {g t e u v eu}. +- move=> g {}v {}mv. + inversion 1; subst g0. + by inj_ex H3. +- move=> g b {}v {}mv. + inversion 1; subst g0 b0. + by inj_ex H3. +- move=> g n {}v {}mv. + inversion 1; subst g0 n0. + by inj_ex H3. +- move=> g r {}v {}mv. + inversion 1; subst g0 r0. + by inj_ex H3. +- move=> g n e f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H4; subst v. + inj_ex H0; subst e0. + by move: H3 => /IH <-. +- move=> g e b f mf ev IH {}v {}mv. + inversion 1; subst g0 r. + inj_ex H4; subst v. + inj_ex H2; subst e0. + by move: H3 => /IH <-. +- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 bop0. + inj_ex H10; subst v. + inj_ex H5; subst e1. + inj_ex H6; subst e5. + by move: H4 H11 => /IH1 <- /IH2 <-. +- move=> g rop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 rop0. + inj_ex H5; subst v. + inj_ex H1; subst e1. + inj_ex H3; subst e3. + by move: H6 H7 => /IH1 <- /IH2 <-. +- move=> g rop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 rop0. + inj_ex H5; subst v. + inj_ex H1; subst e1. + inj_ex H3; subst e3. + by move: H6 H7 => /IH1 <- /IH2 <-. +- move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + simple inversion 1 => //; subst g0. + case: H3 => ? ?; subst t0 t3. + inj_ex H4; case: H4 => He1 He2. + inj_ex He1; subst e0. + inj_ex He2; subst e3. + inj_ex H5; subst v. + move=> e1f0 e2f3. + by rewrite (IH1 _ _ e1f0) (IH2 _ _ e2f3). +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g str H n {}v {}mv. + inversion 1; subst g0. + inj_ex H9; rewrite -H9. + by inj_ex H10. +- move=> g e r mr ev IH {}v {}mv. + inversion 1; subst g0. + inj_ex H0; subst e0. + inj_ex H3; subst v. + by rewrite (IH _ _ H4). +- move=> g n e f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + by rewrite (IH _ _ H5). +- move=> g a b ab {}v {}mv. + inversion 1; subst g0 a0 b0. + inj_ex H4; subst v. + by have -> : ab = ab1. +- (* TODO: beta case*) move=> g a b {}v {}mv. + inversion 1; subst g0 a0 b0. + by inj_ex H4; subst v. +- move=> g n e f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + inj_ex H5; subst mv. + by rewrite (IH _ _ H3). +- move=> g s s0 e r mr ev IH {}v {}mv. + inversion 1; subst g0 s1. + inj_ex H0; subst e0. + inj_ex H3; subst v. + by rewrite (IH _ _ H5). +- move=> g t e k ev IH {}v {}mv. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + inj_ex H5; subst mv. + by rewrite (IH _ H3). +- move=> g t e f mf e1 f1 mf1 e2 f2 mf2 ef ih ef1 ih1 ef2 ih2 {}v {}mv. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H6; subst e5. + inj_ex H7; subst e6. + inj_ex H9; subst v. + clear H11. + have ? := ih1 _ _ H12; subst f6. + have ? := ih2 _ _ H13; subst f7. + by rewrite (ih _ _ H5). +- move=> g h t e x H f mf ef ih {}v {}mv. + inversion 1; subst x0 g0 h0 t0. + inj_ex H12; subst e1. + inj_ex H14; subst v. + clear H16. + by rewrite (ih _ _ H5). +- move=> g t1 t2 x e1 e2 k1 k2 ev1 IH1 ev2 IH2 k. + inversion 1; subst g0 x t3 t0. + inj_ex H7; subst k. + inj_ex H5; subst e4. + inj_ex H6; subst e5. + by rewrite (IH1 _ H4) (IH2 _ H8). +- move=> g t e p mp ep IH v. + inversion 1; subst g0 t0. + inj_ex H7; subst v. + inj_ex H5; subst e1. + have ? := IH _ _ H3; subst p1. + by have -> : mp = mp1 by []. +- move=> g e f mf ev IH k. + inversion 1; subst g0. + inj_ex H0; subst e0. + inj_ex H4; subst k. + have ? := IH _ _ H2; subst f1. + by have -> : mf = mf0 by []. +- move=> g t e f mf ev IH k. + inversion 1; subst g0 t0. + inj_ex H7; subst k. + inj_ex H5; subst e1. + have ? := IH _ _ H3; subst f1. + by have -> : mf = mf1 by []. +- move=> g t e f mf e1 k1 e2 k2 ev ih ev1 ih1 ev2 ih2 k. + inversion 1; subst g0 t0. + inj_ex H0; subst e0. + inj_ex H1; subst e3. + inj_ex H5; subst k. + inj_ex H2; subst e4. + have ? := ih _ _ H6; subst f1. + have -> : mf0 = mf by []. + by rewrite (ih1 _ H7) (ih2 _ H8). +- move=> g h t e x xgh k ek ih. + inversion 1; subst x0 g0 h0 t0. + inj_ex H13; rewrite -H13. + inj_ex H11; subst e1. + by rewrite (ih _ H4). +Qed. + +Lemma eval_total z g t (e : @exp R z g t) : + (match z with + | D => fun e => exists f mf, e -D> f ; mf + | P => fun e => exists k, e -P> k + end) e. +Proof. +elim: e. +all: rewrite {z g t}. +- by do 2 eexists; exact: eval_unit. +- by do 2 eexists; exact: eval_bool. +- by do 2 eexists; exact: eval_nat. +- by do 2 eexists; exact: eval_real. +- move=> g e [f [mf H]] n. + by exists (fun x => (f x ^+ n)%R); eexists; exact: eval_pow. +- move=> g r e [f [mf H]]. + by exists (fun x => (r `^ (f x))%R); eexists; exact: eval_pow_real. +- move=> b g e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun_of_binop f1 f2); eexists; exact: eval_bin. +- move=> r g e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun_of_relop r f1 f2); eexists; exact: eval_rel. +- move=> r g e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun_of_relop_real r f1 f2); eexists; exact: eval_rel_real. +- move=> g t1 t2 e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun x => (f1 x, f2 x)); eexists; exact: eval_pair. +- move=> g t1 t2 e [f [mf H]]. + by exists (fst \o f); eexists; exact: eval_proj1. +- move=> g t1 t2 e [f [mf H]]. + by exists (snd \o f); eexists; exact: eval_proj2. +- by move=> g x t tE; subst t; eexists; eexists; exact: eval_var. +- move=> g e [p [mp H]]. + exists ((bernoulli_prob : R -> pprobability bool R) \o p). + by eexists; exact: eval_bernoulli. +- move=> g n e [p [mp H]]. + exists ((binomial_prob n : R -> pprobability nat R) \o p). + by eexists; exact: (eval_binomial n). +- by eexists; eexists; exact: eval_uniform. +- by eexists; eexists; exact: eval_beta. +- move=> g h e [f [mf H]]. + by exists (poisson_pmf ^~ h \o f); eexists; exact: eval_poisson. +- move=> g e [r [mr H]] s s0. + exists (fun x => @normal_prob _ (r x) s : pprobability _ _). + by eexists; exact: eval_normal. +- move=> g t e [k ek]. + by exists (normalize_pt k); eexists; exact: eval_normalize. +- move=> g t1 t2 x e1 [k1 ev1] e2 [k2 ev2]. + by exists (letin' k1 k2); exact: eval_letin. +- move=> g t e [f [/= mf ef]]. + by eexists; exact: (@eval_sample _ _ _ _ _ mf). +- move=> g e [f [mf f_mf]]. + by exists (kscore mf); exact: eval_score. +- by move=> g t e [f [mf f_mf]]; exists (ret mf); exact: eval_return. +- case. + + move=> g t e1 [f [mf H1]] e2 [f2 [mf2 H2]] e3 [f3 [mf3 H3]]. + by exists (fun g => if f g then f2 g else f3 g), + (measurable_fun_ifT mf mf2 mf3); exact: evalD_if. + + move=> g t e1 [f [mf H1]] e2 [k2 H2] e3 [k3 H3]. + by exists (ite mf k2 k3); exact: evalP_if. +- case=> [g h t x e [f [mf ef]] xgh|g h st x e [k ek] xgh]. + + by exists (weak _ _ _ f), (measurable_weak _ _ _ _ mf); exact/evalD_weak. + + by exists (kweak _ _ _ k); exact: evalP_weak. +Qed. + +Lemma evalD_total g t (e : @exp R D g t) : exists f mf, e -D> f ; mf. +Proof. exact: (eval_total e). Qed. + +Lemma evalP_total g t (e : @exp R P g t) : exists k, e -P> k. +Proof. exact: (eval_total e). Qed. + +End eval_prop. + +Section execution_functions. +Local Open Scope lang_scope. +Context {R : realType}. +Implicit Type g : ctx. + +Definition execD g t (e : exp D g t) : + {f : dval R g t & measurable_fun setT f} := + let: exist _ H := cid (evalD_total e) in + existT _ _ (projT1 (cid H)). + +Lemma eq_execD g t (p1 p2 : @exp R D g t) : + projT1 (execD p1) = projT1 (execD p2) -> execD p1 = execD p2. +Proof. +rewrite /execD /=. +case: cid => /= f1 [mf1 ev1]. +case: cid => /= f2 [mf2 ev2] f12. +subst f2. +have ? : mf1 = mf2 by []. +subst mf2. +congr existT. +rewrite /sval. +case: cid => mf1' ev1'. +have ? : mf1 = mf1' by []. +subst mf1'. +case: cid => mf2' ev2'. +have ? : mf1 = mf2' by []. +by subst mf2'. +Qed. + +Definition execP g t (e : exp P g t) : pval R g t := + projT1 (cid (evalP_total e)). + +Lemma execD_evalD g t e x mx: + @execD g t e = existT _ x mx <-> e -D> x ; mx. +Proof. +rewrite /execD; split. + case: cid => x' [mx' H] [?]; subst x'. + have ? : mx = mx' by []. + by subst mx'. +case: cid => f' [mf' f'mf']/=. +move/evalD_uniq => /(_ _ _ f'mf') => ?; subst f'. +by case: cid => //= ? ?; congr existT. +Qed. + +Lemma evalD_execD g t (e : exp D g t) : + e -D> projT1 (execD e); projT2 (execD e). +Proof. +by rewrite /execD; case: cid => // x [mx xmx]/=; case: cid. +Qed. + +Lemma execP_evalP g t (e : exp P g t) x : + execP e = x <-> e -P> x. +Proof. +rewrite /execP; split; first by move=> <-; case: cid. +case: cid => // x0 Hx0. +by move/evalP_uniq => /(_ _ Hx0) ?; subst x. +Qed. + +Lemma evalP_execP g t (e : exp P g t) : e -P> execP e. +Proof. by rewrite /execP; case: cid. Qed. + +Lemma execD_unit g : @execD g _ [TT] = existT _ (cst tt) ktt. +Proof. exact/execD_evalD/eval_unit. Qed. + +Lemma execD_bool g b : @execD g _ [b:B] = existT _ (cst b) (kb b). +Proof. exact/execD_evalD/eval_bool. Qed. + +Lemma execD_nat g n : @execD g _ [n:N] = existT _ (cst n) (kn n). +Proof. exact/execD_evalD/eval_nat. Qed. + +Lemma execD_real g r : @execD g _ [r:R] = existT _ (cst r) (kr r). +Proof. exact/execD_evalD/eval_real. Qed. + +Lemma execD_pow g (e : exp D g _) n : + let f := projT1 (execD e) in let mf := projT2 (execD e) in + execD (exp_pow e n) = + @existT _ _ (fun x => (f x ^+ n)%R) (measurable_funX n mf). +Proof. +by move=> f mf; apply/execD_evalD/eval_pow/evalD_execD. +Qed. + +Lemma execD_pow_real g r (e : exp D g _) : + let f := projT1 (execD e) in let mf := projT2 (execD e) in + execD (exp_pow_real r e) = + @existT _ _ (fun x => (r `^ f x)%R) (measurableT_comp (measurable_powRr r) mf). +Proof. +by move=> f mf; apply/execD_evalD/eval_pow_real/evalD_execD. +Qed. + +Lemma execD_bin g bop (e1 : exp D g _) (e2 : exp D g _) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD (exp_bin bop e1 e2) = + @existT _ _ (fun_of_binop f1 f2) (mfun_of_binop mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_bin; exact/evalD_execD. +Qed. + +Lemma execD_rel g rop (e1 : exp D g _) (e2 : exp D g _) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD (exp_rel rop e1 e2) = + @existT _ _ (fun_of_relop rop f1 f2) (mfun_of_relop rop mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_rel; exact: evalD_execD. +Qed. + +Lemma execD_rel_real g rop (e1 : exp D g _) (e2 : exp D g _) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD (exp_rel_real rop e1 e2) = + @existT _ _ (fun_of_relop_real rop f1 f2) (mfun_of_relop_real rop mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_rel_real; exact: evalD_execD. +Qed. + +Lemma execD_pair g t1 t2 (e1 : exp D g t1) (e2 : exp D g t2) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD [(e1, e2)] = + @existT _ _ (fun z => (f1 z, f2 z)) + (@measurable_fun_pair _ _ _ (mctx g) (mtyp t1) (mtyp t2) + f1 f2 mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_pair; exact: evalD_execD. +Qed. + +Lemma execD_proj1 g t1 t2 (e : exp D g (Pair t1 t2)) : + let f := projT1 (execD e) in + let mf := projT2 (execD e) in + execD [\pi_1 e] = @existT _ _ (fst \o f) + (measurableT_comp measurable_fst mf). +Proof. +by move=> f mf; apply/execD_evalD/eval_proj1; exact: evalD_execD. +Qed. + +Lemma execD_proj2 g t1 t2 (e : exp D g (Pair t1 t2)) : + let f := projT1 (execD e) in let mf := projT2 (execD e) in + execD [\pi_2 e] = @existT _ _ (snd \o f) + (measurableT_comp measurable_snd mf). +Proof. +by move=> f mf; apply/execD_evalD/eval_proj2; exact: evalD_execD. +Qed. + +Lemma execD_var_erefl g str : let i := index str (dom g) in + @execD g _ (exp_var str erefl) = existT _ (acc_typ (map snd g) i) + (measurable_acc_typ (map snd g) i). +Proof. by move=> i; apply/execD_evalD; exact: eval_var. Qed. + +Lemma execD_var g x (H : nth Unit (map snd g) (index x (dom g)) = lookup Unit g x) : + let i := index x (dom g) in + @execD g _ (exp_var x H) = existT _ (acc_typ (map snd g) i) + (measurable_acc_typ (map snd g) i). +Proof. by move=> i; apply/execD_evalD; exact: eval_var. Qed. + +Lemma execD_bernoulli g e : + @execD g _ [Bernoulli e] = + existT _ ((bernoulli_prob : R -> pprobability bool R) \o projT1 (execD e)) + (measurableT_comp measurable_bernoulli_prob (projT2 (execD e))). +Proof. exact/execD_evalD/eval_bernoulli/evalD_execD. Qed. + +Lemma execD_binomial g n e : + @execD g _ [Binomial n e] = + existT _ ((binomial_prob n : R -> pprobability nat R) \o projT1 (execD e)) + (measurableT_comp (measurable_binomial_prob n) (projT2 (execD e))). +Proof. exact/execD_evalD/eval_binomial/evalD_execD. Qed. + +Lemma execD_uniform g a b ab0 : + @execD g _ [Uniform a b ab0] = + existT _ (cst (uniform_prob ab0 : pprobability _ R)) (measurable_cst _). +Proof. exact/execD_evalD/eval_uniform. Qed. + +Lemma execD_beta g a b : + @execD g _ [Beta a b] = + existT _ (cst (beta_prob a b : pprobability _ R)) (measurable_cst _). +Proof. exact/execD_evalD/eval_beta. Qed. + +Lemma execD_normal g s s0 e : + let f := projT1 (execD e) in let mf := projT2 (execD e) in + @execD g _ [Normal e s s0] = + existT _ (fun x => @normal_prob _ (f x) s : pprobability _ R) + (measurableT_comp (measurable_normal_prob2 s0) mf). +Proof. exact/execD_evalD/eval_normal/evalD_execD. Qed. + +Lemma execD_normalize_pt g t (e : exp P g t) : + @execD g _ [Normalize e] = + existT _ (normalize_pt (execP e) : _ -> pprobability _ _) + (measurable_normalize_pt (execP e)). +Proof. exact/execD_evalD/eval_normalize/evalP_execP. Qed. + +Lemma execD_poisson g n (e : exp D g Real) : + execD (exp_poisson n e) = + existT _ (poisson_pmf ^~ n \o projT1 (execD e)) + (measurableT_comp (measurable_poisson_pmf n measurableT) + (projT2 (execD e))). +Proof. exact/execD_evalD/eval_poisson/evalD_execD. Qed. + +Lemma execP_if g st e1 e2 e3 : + @execP g st [if e1 then e2 else e3] = + ite (projT2 (execD e1)) (execP e2) (execP e3). +Proof. +by apply/execP_evalP/evalP_if; [apply: evalD_execD| exact: evalP_execP..]. +Qed. + +Lemma execP_letin g x t1 t2 (e1 : exp P g t1) (e2 : exp P ((x, t1) :: g) t2) : + execP [let x := e1 in e2] = letin' (execP e1) (execP e2) :> (R.-sfker _ ~> _). +Proof. by apply/execP_evalP/eval_letin; exact: evalP_execP. Qed. + +Lemma execP_sample g t (e : @exp R D g (Prob t)) : + let x := execD e in + execP [Sample e] = sample (projT1 x) (projT2 x). +Proof. exact/execP_evalP/eval_sample/evalD_execD. Qed. + +Lemma execP_score g (e : exp D g Real) : + execP [Score e] = score (projT2 (execD e)). +Proof. exact/execP_evalP/eval_score/evalD_execD. Qed. + +Lemma execP_return g t (e : exp D g t) : + execP [return e] = ret (projT2 (execD e)). +Proof. exact/execP_evalP/eval_return/evalD_execD. Qed. + +Lemma execP_weak g h x t (e : exp P (g ++ h) t) + (xl : x.1 \notin dom (g ++ h)) : + execP (exp_weak P g h _ e xl) = kweak _ _ _ (execP e). +Proof. exact/execP_evalP/evalP_weak/evalP_execP. Qed. + +End execution_functions. +Arguments execD_var_erefl {R g} str. +Arguments execP_weak {R} g h x {t} e. +Arguments exp_var'E {R} str. + +Local Open Scope lang_scope. +Lemma congr_letinl {R : realType} g t1 t2 str (e1 e2 : @exp _ _ g t1) + (e : @exp _ _ (_ :: g) t2) x U : + (forall y V, execP e1 y V = execP e2 y V) -> + measurable U -> + @execP R g t2 [let str := e1 in e] x U = + @execP R g t2 [let str := e2 in e] x U. +Proof. by move=> + mU; move/eq_sfkernel => He; rewrite !execP_letin He. Qed. + +Lemma congr_letinr {R : realType} g t1 t2 str (e : @exp _ _ _ t1) + (e1 e2 : @exp _ _ (_ :: g) t2) x U : + (forall y V, execP e1 (y, x) V = execP e2 (y, x) V) -> + @execP R g t2 [let str := e in e1] x U = @execP R g t2 [let str := e in e2] x U. +Proof. +by move=> He; rewrite !execP_letin !letin'E; apply: eq_integral => ? _; exact: He. +Qed. + +Lemma congr_normalize {R : realType} g t (e1 e2 : @exp R _ g t) : + (forall x U, execP e1 x U = execP e2 x U) -> + execD [Normalize e1] = execD [Normalize e2]. +Proof. +move=> He; apply: eq_execD. +rewrite !execD_normalize_pt /=. +f_equal. +apply: eq_kernel => y V. +exact: He. +Qed. +Local Close Scope lang_scope. diff --git a/theories/lang_syntax_examples.v b/theories/lang_syntax_examples.v new file mode 100644 index 000000000..2e74afcd5 --- /dev/null +++ b/theories/lang_syntax_examples.v @@ -0,0 +1,1041 @@ +(* mathcomp analysis (c) 2025 Inria and AIST. License: CeCILL-C. *) +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval. +From mathcomp Require Import interval_inference. +From mathcomp Require Import unstable mathcomp_extra boolp classical_sets. +From mathcomp Require Import functions cardinality fsbigop. +From mathcomp Require Import reals ereal topology normedtype sequences esum. +From mathcomp Require Import measure lebesgue_measure numfun lebesgue_integral. +From mathcomp Require Import kernel prob_lang lang_syntax_util lang_syntax. +From mathcomp Require Import probability. +From mathcomp Require Import ring lra. + +(**md**************************************************************************) +(* # Examples using the probabilistic Programming language of lang_syntax.v *) +(* *) +(* sample_pair1213 := normalize ( *) +(* let x := sample (bernoulli 1/2) in *) +(* let y := sample (bernoulli 1/3) in *) +(* return (x, y)) *) +(* *) +(* sample_and1213 := normalize ( *) +(* let x := sample (bernoulli 1/2) in *) +(* let y := sample (bernoulli 1/3) in *) +(* return (x && y)) *) +(* *) +(* bernoulli13_score := normalize ( *) +(* let x := sample (bernoulli 1/3) in *) +(* let _ := if x then score (1/3) else score (2/3) in *) +(* return x) *) +(* *) +(* sample_binomial3 := *) +(* let x := sample (binomial 3 1/2) in *) +(* return x *) +(* *) +(* hard_constraint := let x := Score {0}:R in return TT *) +(* *) +(* guard := *) +(* let p := sample (bernoulli (1 / 3)) in *) +(* let _ := if p then return TT else score 0 in *) +(* return p *) +(* *) +(* more examples about uniform, beta, and bernoulli distributions *) +(* *) +(* associativity of let-in expressions *) +(* *) +(* staton_bus_syntax == example from [Staton, ESOP 2017] *) +(* *) +(* staton_busA_syntax == same as staton_bus_syntax module associativity of *) +(* let-in expression *) +(* *) +(* commutativity of let-in expressions *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Import Order.TTheory GRing.Theory Num.Def Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. +Local Open Scope string_scope. + +Local Open Scope string_scope. + +(* simple tests to check bidirectional hints *) +Module bidi_tests. +Section bidi_tests. +Local Open Scope lang_scope. +Import Notations. +Context (R : realType). + +Definition bidi_test1 x : @exp R P [::] _ := [ + let x := return {1}:R in + return #x]. + +Definition bidi_test2 (a b : string) + (a := "a") (b := "b") + (* (ba : infer (b != a)) *) + : @exp R P [::] _ := [ + let a := return {1}:R in + let b := return {true}:B in + (* let c := return {3}:R in + let d := return {4}:R in *) + return (#a, #b)]. + +Definition bidi_test3 (a b c d : string) + (ba : infer (b != a)) (ca : infer (c != a)) + (cb : infer (c != b)) (ab : infer (a != b)) + (ac : infer (a != c)) (bc : infer (b != c)) : @exp R P [::] _ := [ + let a := return {1}:R in + let b := return {2}:R in + let c := return {3}:R in + (* let d := return {4}:R in *) + return (#b, #a)]. + +Definition bidi_test4 (a b c d : string) + (ba : infer (b != a)) (ca : infer (c != a)) + (cb : infer (c != b)) (ab : infer (a != b)) + (ac : infer (a != c)) (bc : infer (b != c)) : @exp R P [::] _ := [ + let a := return {1}:R in + let b := return {2}:R in + let c := return {3}:R in + (* let d := return {4}:R in *) + return {exp_poisson O [#c(*{exp_var c erefl}*)]}]. + +End bidi_tests. +End bidi_tests. + +Section trivial_example. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Lemma exec_normalize_return g x r : + projT1 (@execD _ g _ [Normalize return r:R]) x = + @dirac _ (measurableTypeR R) r _ :> probability _ R. + (* NB: \d_r notation? *) +Proof. +by rewrite execD_normalize_pt execP_return execD_real//=; exact: normalize_kdirac. +Qed. + +End trivial_example. + +Section sample_pair. +Local Open Scope lang_scope. +Local Open Scope ring_scope. +Import Notations. +Context {R : realType}. + +Definition sample_pair1213' : @exp R _ [::] _ := + [let "x" := Sample Bernoulli {1 / 2}:R in + let "y" := Sample Bernoulli {1 / 3}:R in + return (#{"x"}, #{"y"})]. + +Definition sample_pair1213 : exp _ [::] _ := [Normalize {sample_pair1213'}]. + +Lemma exec_sample_pair1213' (A : set (bool * bool)) : + @execP R [::] _ sample_pair1213' tt A = + ((1 / 2)%:E * + ((1 / 3)%:E * \d_(true, true) A + + (1 - 1 / 3)%:E * \d_(true, false) A) + + (1 - 1 / 2)%:E * + ((1 / 3)%:E * \d_(false, true) A + + (1 - 1 / 3)%:E * \d_(false, false) A))%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli !execP_return /=. +rewrite execD_pair !exp_var'E (execD_var_erefl "x") (execD_var_erefl "y") /=. +rewrite !execD_real//=. +do 2 (rewrite letin'E/= integral_bernoulli_prob//=; last lra). +by rewrite letin'E/= integral_bernoulli_prob//=; lra. +Qed. + +Lemma exec_sample_pair1213'_TandT : + @execP R [::] _ sample_pair1213' tt [set (true, true)] = (1 / 6)%:E. +Proof. +rewrite exec_sample_pair1213' !diracE mem_set//; do 3 rewrite memNset//=. +by rewrite /= !mule0 mule1 !add0e mule0 adde0; congr (_%:E); lra. +Qed. + +Lemma exec_sample_pair1213'_TandT' : + @execP R [::] _ sample_pair1213' tt [set p | p.1 && p.2] = (1 / 6)%:E. +Proof. +rewrite exec_sample_pair1213' !diracE mem_set//; do 3 rewrite memNset//=. +by rewrite /= !mule0 mule1 !add0e mule0 adde0; congr (_%:E); lra. +Qed. + +Lemma exec_sample_pair1213'_TandF : + @execP R [::] _ sample_pair1213' tt [set (true, false)] = (1 / 3)%:E. +Proof. +rewrite exec_sample_pair1213' !diracE memNset// mem_set//; do 2 rewrite memNset//. +by rewrite /= !mule0 mule1 !add0e mule0 adde0; congr (_%:E); lra. +Qed. + +Lemma exec_sample_pair1213_TorT : + (projT1 (execD sample_pair1213)) tt [set p | p.1 || p.2] = (2 / 3)%:E. +Proof. +rewrite execD_normalize_pt normalizeE/= exec_sample_pair1213'. +rewrite !diracE; do 4 rewrite mem_set//=. +rewrite eqe ifF; last by apply/negbTE/negP => /orP[/eqP|//]; lra. +rewrite exec_sample_pair1213' !diracE; do 3 rewrite mem_set//; rewrite memNset//=. +by rewrite !mule1; congr (_%:E); field. +Qed. + +End sample_pair. + +Section sample_and. +Local Open Scope lang_scope. +Local Open Scope ring_scope. +Import Notations. +Context {R : realType}. + +Definition sample_and1213' : @exp R _ [::] _ := + [let "x" := Sample Bernoulli {1 / 2}:R in + let "y" := Sample Bernoulli {1 / 3}:R in + return #{"x"} && #{"y"}]. + +Lemma exec_sample_and1213' (A : set bool) : + @execP R [::] _ sample_and1213' tt A = ((1 / 6)%:E * \d_true A + + (1 - 1 / 6)%:E * \d_false A)%E. +Proof. +rewrite !execP_letin !execP_sample/= !execD_bernoulli execP_return /=. +rewrite !(@execD_bin _ _ binop_and) !exp_var'E (execD_var_erefl "x"). +rewrite (execD_var_erefl "y") /= !letin'E/= !execD_real/=. +rewrite integral_bernoulli_prob//=; last lra. +rewrite !letin'E/= integral_bernoulli_prob//=; last lra. +rewrite integral_bernoulli_prob//=; last lra. +rewrite /onem muleDr// -addeA; congr (_ + _)%E. + by rewrite !muleA; congr (_%:E); congr (_ * _); field. +rewrite -muleDl// !muleA -muleDl//. +by congr (_%:E); congr (_ * _); field. +Qed. + +Definition sample_and121212 : @exp R _ [::] _ := + [let "x" := Sample Bernoulli {1 / 2}:R in + let "y" := Sample Bernoulli {1 / 2}:R in + let "z" := Sample Bernoulli {1 / 2}:R in + return #{"x"} && #{"y"} && #{"z"}]. + +Lemma exec_sample_and121212 t U : + execP sample_and121212 t U = ((1 / 8)%:E * \d_true U + + (1 - 1 / 8)%:E * \d_false U)%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli !execP_return /=. +rewrite !(@execD_bin _ _ binop_and) !exp_var'E (execD_var_erefl "x"). +rewrite (execD_var_erefl "y") (execD_var_erefl "z") /= !execD_real/=. +do 3 (rewrite !letin'E/= integral_bernoulli_prob//=; last lra). +do 2 (rewrite integral_bernoulli_prob//=; last lra). +rewrite !letin'E/= integral_bernoulli_prob//=; last lra. +rewrite !muleDr// -!addeA; congr (_ + _)%E. + by rewrite !muleA; congr *%E; congr EFin; field. +rewrite !muleA -!muleDl//; congr *%E; congr EFin. +by rewrite /onem; field. +Qed. + +End sample_and. + +Section sample_score. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Definition bernoulli13_score : @exp R _ [::] _ := [Normalize + let "x" := Sample Bernoulli {1 / 3}:R in + let "_" := if #{"x"} then Score {1 / 3}:R else Score {2 / 3}:R in + return #{"x"}]. + +Lemma exec_bernoulli13_score : + execD bernoulli13_score = execD [Bernoulli {1 / 5}:R]. +Proof. +apply: eq_execD. +rewrite execD_bernoulli/= /bernoulli13_score execD_normalize_pt 2!execP_letin. +rewrite execP_sample/= execD_bernoulli/= execP_if /= exp_var'E. +rewrite (execD_var_erefl "x")/= !execP_return/= 2!execP_score !execD_real/=. +apply: funext=> g; apply: eq_probability => U. +rewrite normalizeE !letin'E/=. +under eq_integral. + move=> x _. + rewrite !letin'E. + under eq_integral do rewrite retE /=. + over. +rewrite /=. +rewrite integral_bernoulli_prob//=; [|lra|by move=> b; rewrite integral_ge0]. +rewrite iteE/= !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_indic//= !iteE/= /mscale/=. +rewrite setTI !diracT !mule1. +rewrite ger0_norm//. +rewrite -EFinD/= eqe ifF; last first. + by apply/negbTE/negP => /orP[/eqP|//]; rewrite /onem; lra. +rewrite integral_bernoulli_prob//=; last lra. +rewrite !letin'E/= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_dirac//= !diracT !mul1e ger0_norm//. +rewrite exp_var'E (execD_var_erefl "x")/=. +rewrite !indicT/= !mulr1. +rewrite bernoulli_probE//=; last lra. +by rewrite muleDl//; congr (_ + _)%E; + rewrite -!EFinM; congr (_%:E); + rewrite !indicE /onem /=; case: (_ \in _); field. +Qed. + +Definition bernoulli12_score : @exp R _ [::] _ := [Normalize + let "x" := Sample Bernoulli {1 / 2}:R in + let "r" := if #{"x"} then Score {1 / 3}:R else Score {2 / 3}:R in + return #{"x"}]. + +Lemma exec_bernoulli12_score : + execD bernoulli12_score = execD [Bernoulli {1 / 3}:R]. +Proof. +apply: eq_execD. +rewrite execD_bernoulli/= /bernoulli12_score execD_normalize_pt 2!execP_letin. +rewrite execP_sample/= execD_bernoulli/= execP_if /= exp_var'E. +rewrite (execD_var_erefl "x")/= !execP_return/= 2!execP_score !execD_real/=. +apply: funext=> g; apply: eq_probability => U. +rewrite normalizeE !letin'E/=. +under eq_integral. + move=> x _. + rewrite !letin'E. + under eq_integral do rewrite retE /=. + over. +rewrite /= integral_bernoulli_prob//=; [|lra|by move=> b; rewrite integral_ge0]. +rewrite iteE/= !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_indic//= !iteE/= /mscale/=. +rewrite setTI !diracT !mule1. +rewrite ger0_norm//. +rewrite -EFinD/= eqe ifF; last first. + apply/negbTE/negP => /orP[/eqP|//]. + by rewrite /onem; lra. +rewrite integral_bernoulli_prob//=; last lra. +rewrite !letin'E/= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_dirac//= !diracT !mul1e ger0_norm//. +rewrite exp_var'E (execD_var_erefl "x")/=. +rewrite bernoulli_probE//=; last lra. +rewrite !mul1r. +rewrite muleDl//; congr (_ + _)%E; + rewrite -!EFinM; + congr (_%:E); + by rewrite !indicT !indicE /onem /=; case: (_ \in _); field. +Qed. + +(* https://dl.acm.org/doi/pdf/10.1145/2933575.2935313 (Sect. 4) *) +Definition bernoulli14_score : @exp R _ [::] _ := [Normalize + let "x" := Sample Bernoulli {1 / 4}:R in + let "r" := if #{"x"} then Score {5}:R else Score {2}:R in + return #{"x"}]. + +Lemma exec_bernoulli14_score : + execD bernoulli14_score = execD [Bernoulli {5%:R / 11%:R}:R]. +Proof. +apply: eq_execD. +rewrite execD_bernoulli/= execD_normalize_pt 2!execP_letin. +rewrite execP_sample/= execD_bernoulli/= execP_if /= !exp_var'E. +rewrite !execP_return/= 2!execP_score !execD_real/=. +rewrite !(execD_var_erefl "x")/=. +apply: funext=> g; apply: eq_probability => U. +rewrite normalizeE !letin'E/=. +under eq_integral. + move=> x _. + rewrite !letin'E. + under eq_integral do rewrite retE /=. + over. +rewrite /= integral_bernoulli_prob//=; [|lra|by move=> b; rewrite integral_ge0]. +rewrite iteE/= !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_cst//= !diracT !(mule1,mul1e). +rewrite !indicT/= !mule1. +rewrite !iteE/= /mscale/=. +rewrite ger0_norm//. +rewrite !diracT/= !mul1r. +rewrite -EFinD/= eqe ifF; last first. + apply/negbTE/negP => /orP[/eqP|//]. + by rewrite /onem; lra. +rewrite integral_bernoulli_prob//=; last lra. +rewrite !letin'E/= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_dirac//= !diracT !mul1e ger0_norm//. +rewrite bernoulli_probE//=; last lra. +rewrite !indicT. +rewrite muleDl//; congr (_ + _)%E; + rewrite -!EFinM; + congr (_%:E); + by rewrite !indicE /onem /=; case: (_ \in _); field. +Qed. + +End sample_score. + +Section sample_binomial. +Context {R : realType}. +Open Scope lang_scope. +Open Scope ring_scope. + +Definition sample_binomial3 : @exp R _ [::] _ := + [let "x" := Sample Binomial {3} {1 / 2}:R in + return #{"x"}]. + +Lemma exec_sample_binomial3 t U : measurable U -> + execP sample_binomial3 t U = ((1 / 8)%:E * \d_0%N U + + (3 / 8)%:E * \d_1%N U + + (3 / 8)%:E * \d_2%N U + + (1 / 8)%:E * \d_3%N U)%E. +Proof. +move=> mU; rewrite /sample_binomial3 execP_letin execP_sample execP_return. +rewrite exp_var'E (execD_var_erefl "x") !execD_binomial/= execD_real//=. +rewrite letin'E/= /= integral_binomial//=; [lra|move=> _]. +rewrite !big_ord_recl big_ord0/=. +rewrite /bump. +rewrite !binS/= !bin0 bin1 bin2 bin_small// addn0. +rewrite expr0 mulr1 mul1r subn0. +rewrite -2!addeA !mul1r. +congr _%:E. +rewrite !indicE /onem !addrA addr0 expr1/=. +by congr (_ + _ + _ + _); congr (_ * _); field. +Qed. + +End sample_binomial. + +Section nondeterminism_and_weights. +Context {R : realType}. +Open Scope lang_scope. +Open Scope ring_scope. + +Definition binomial2p (p : R) : @exp R _ [::] _ := + [let "x" := Sample Binomial {2} {p}:R in + return #{"x"}]. + +Definition return2 (p : R) : @exp R _ [::] _ := + [let "_" := Score {p ^+ 2}:R in return {2}:N]. + +Definition return1 (p : R) : @exp R _ [::] _ := + [let "_" := Score {p * `1-p *+ 2}:R in return {1}:N]. + +Definition return0 (p : R) : @exp R _ [::] _ := + [let "_" := Score {`1-p^+2}:R in return {0}:N]. + +Lemma exec_binomial2p (p : R) t U : 0 <= p <= 1 -> measurable U -> + execP (binomial2p p) t U = + execP (return2 p) t U + + execP (return1 p) t U + + execP (return0 p) t U. +Proof. +move=> /= /andP[p0 p1] mU. +(* simplify the lhs *) +rewrite [in LHS]execP_letin execP_sample/= execD_binomial/=. +rewrite execP_return/= !execD_real/= exp_var'E (execD_var_erefl "x")/=. +rewrite letin'E/= integral_binomial//=. +rewrite !big_ord_recr big_ord0//=. +rewrite !(bin0,bin1,bin2). +rewrite !(add0r,expr0,mul1r,mulr1,subn0,mulr1n,expr1). +(* simplify the rhs *) +rewrite /return2 /return1 /return0. +rewrite ![in RHS]execP_letin !execP_score/= !execD_real/=. +rewrite !execP_return/= !execD_nat/=. +rewrite !letin'E/=. +rewrite !ge0_integral_mscale//=. +rewrite !integral_dirac//=. +rewrite !diracT. +rewrite !(mul1e) ger0_norm ?sqr_ge0//. +rewrite ger0_norm ?mulrn_wge0 ?mulr_ge0 ?onem_ge0//. +rewrite ger0_norm ?sqr_ge0//. +by rewrite -addeA addeC -addeA addeCA addeA. +Qed. + +End nondeterminism_and_weights. + +Section hard_constraint. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType} {str : string}. + +Definition hard_constraint g : @exp R _ g _ := + [let str := Score {0}:R in return TT]. + +Lemma exec_hard_constraint g mg U : + execP (hard_constraint g) mg U = fail' (false, tt) U. +Proof. +rewrite execP_letin execP_score execD_real execP_return execD_unit/=. +rewrite letin'E integral_indic//= /mscale/= normr0 mul0e. +by rewrite /fail' letin'E/= ge0_integral_mscale//= normr0 mul0e. +Qed. + +Lemma exec_score_fail (r : R) (r01 : (0 <= r <= 1)%R) : + execP (g := [::]) [Score {r}:R] = + execP [let str := Sample Bernoulli {r}:R in + if #str then return TT else {hard_constraint _}]. +Proof. +move: r01 => /andP[r0 r1]//. +rewrite execP_score execD_real /= score_fail' ?r0 ?r1//. +rewrite execP_letin execP_sample/= execD_bernoulli execP_if execP_return. +rewrite execD_unit/= exp_var'E /=. + exact/ctx_prf_head (* TODO *). +move=> h. +apply: eq_sfkernel=> /= -[] U. +rewrite [LHS]letin'E/= [RHS]letin'E/=. +rewrite execD_real/=. +apply: eq_integral => b _. +rewrite 2!iteE//=. +case: b => //=. +- suff : projT1 (@execD R _ _ (exp_var str h)) (true, tt) = true by move=> ->. + set g := [:: (str, Bool)]. + have /= := @execD_var R [:: (str, Bool)] str. + by rewrite eqxx => /(_ h) ->. +- have -> : projT1 (@execD R _ _ (exp_var str h)) (false, tt) = false. + set g := [:: (str, Bool)]. + have /= := @execD_var R [:: (str, Bool)] str. + by rewrite eqxx /= => /(_ h) ->. + by rewrite (@exec_hard_constraint [:: (str, Bool)]). +Qed. + +End hard_constraint. + +Section test_uniform. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Context (R : realType). + +Definition uniform_syntax : @exp R _ [::] _ := + [let "p" := Sample Uniform {0} {1} {ltr01} in + return #{"p"}]. + +Lemma exec_uniform_syntax t U : measurable U -> + execP uniform_syntax t U = uniform_prob (@ltr01 R) U. +Proof. +move=> mU. +rewrite /uniform_syntax execP_letin execP_sample execP_return !execD_uniform. +rewrite exp_var'E (execD_var_erefl "p")/=. +rewrite letin'E /=. +rewrite integral_uniform//=; last exact: measurable_fun_dirac. +rewrite subr0 invr1 mul1e. +rewrite {1}/uniform_prob. +rewrite integral_mkcond//=. +rewrite [in RHS]integral_mkcond//=. +apply: eq_integral => x _. +rewrite !patchE. +case: ifPn => //; case: ifPn => //. +- move=> xU. + rewrite inE/= in_itv/= => x01. + by rewrite /uniform_pdf x01 diracE xU subr0 invr1. +- by rewrite diracE => /negbTE ->. +- move=> xU. + rewrite notin_setE/= in_itv/= => /negP/negbTE x01. + by rewrite /uniform_pdf x01. +Qed. + +End test_uniform. + +Section guard. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Context (R : realType). + +Definition guard : @exp R _ [::] _ := [ + let "p" := Sample Bernoulli {1 / 3}:R in + let "_" := if #{"p"} then return TT else Score {0}:R in + return #{"p"} +]. + +Lemma exec_guard t U : execP guard t U = ((1 / 3)%:E * \d_true U)%E. +Proof. +rewrite /guard 2!execP_letin execP_sample execD_bernoulli execD_real. +rewrite execP_if/= !execP_return !exp_var'E !(execD_var_erefl "p") execD_unit. +rewrite execP_score execD_real/=. +rewrite letin'E/= integral_bernoulli_prob//=; last lra. +rewrite !letin'E !iteE/= integral_dirac// ge0_integral_mscale//=. +by rewrite normr0 mul0e !mule0 !adde0 !diracT !mul1e. +Qed. + +End guard. + +Section test_binomial. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Context (R : realType). + +Definition binomial_le : @exp R _ [::] Bool := + [let "a2" := Sample Binomial {3} {1 / 2}:R in + return {1}:N <= #{"a2"}]. + +Lemma exec_binomial_le t U : + execP binomial_le t U = ((7 / 8)%:E * \d_true U + + (1 / 8)%:E * \d_false U)%E. +Proof. +rewrite /binomial_le execP_letin execP_sample execP_return execD_rel execD_nat. +rewrite exp_var'E (execD_var_erefl "a2") execD_binomial/= !execD_real/=. +rewrite letin'E//= integral_binomial//=; [lra|move=> _]. +rewrite !big_ord_recl big_ord0//=. +rewrite /bump. +rewrite !binS/= !bin0 bin1 bin2 bin_small// addn0. +rewrite addeC adde0. +congr (_ + _)%:E. + rewrite !indicE !(mul0n,add0n,lt0n,mul1r)/=. + rewrite -!mulrDl; congr (_ * _). + rewrite /onem. + lra. +rewrite !expr0 ltnn indicE/= !(mul1r,mul1e) /onem. +lra. +Qed. + +Definition binomial_guard : @exp R _ [::] Nat := + [let "a1" := Sample Binomial {3} {1 / 2}:R in + let "_" := if #{"a1"} == {1}:N then return TT else Score {0}:R in + return #{"a1"}]. + +Lemma exec_binomial_guard t U : + execP binomial_guard t U = ((3 / 8)%:E * \d_1%N U)%E. +Proof. +rewrite /binomial_guard !execP_letin execP_sample execP_return execP_if. +rewrite !exp_var'E execD_rel !(execD_var_erefl "a1") execP_return. +rewrite execD_unit execD_binomial execD_nat execP_score !execD_real. +rewrite !letin'E//=. +rewrite integral_binomial//=; [lra|move=> _]. +rewrite !big_ord_recl big_ord0. +rewrite /bump/=. +rewrite !binS/= !bin0 bin1 bin2 bin_small//. +rewrite !letin'E//= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite !integral_dirac//= !diracE/=. +rewrite /bump/=. +rewrite !(normr0,mul0e,mule0,add0e,add0n,mul1e,adde0). +rewrite mem_set//=. +rewrite /onem mul1e. +congr (_%:E * _)%E. +lra. +Qed. + +End test_binomial. + +Section beta_bernoulli_bernoulli. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Context (R : realType). +Local Notation mu := (@lebesgue_measure R). + +(* TODO: move? *) +Lemma integrable_bernoulli_XMonemX01 a b U + (mu : {measure set (g_sigma_algebraType R.-ocitv.-measurable) -> \bar R}) : + measurable U -> (mu `[0%R, 1%R]%classic < +oo)%E -> + mu.-integrable `[0, 1] (fun x => bernoulli_prob (@XMonemX R a b \_`[0,1] x) U). +Proof. +move=> mU mu01oo. +apply/integrableP; split. + apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. + apply/measurable_restrict => //=; rewrite setIidr//. + exact: measurable_XMonemX. +apply: (@le_lt_trans _ _ (\int[mu]_(x in `[0%R, 1%R]) cst 1 x)%E). + apply: ge0_le_integral => //=. + apply/measurable_funTS/measurableT_comp => //=. + apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. + apply/measurable_restrict => //=; rewrite setIidr//. + exact: measurable_XMonemX. + by move=> x _; rewrite gee0_abs// probability_le1. +by rewrite integral_cst//= mul1e. +Qed. + +Let measurable_bernoulli_XMonemX01 U : + measurable_fun setT + (fun x : R => bernoulli_prob (@XMonemX R 1 0 \_`[0,1] x) U). +Proof. +apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. +apply/measurable_restrict => //=; rewrite setIidr//. +exact: measurable_XMonemX. +Qed. + +Lemma beta_bernoulli_bernoulli U : measurable U -> + @execP R [::] _ [let "p" := Sample Beta {6} {4} in + Sample Bernoulli #{"p"}] tt U = + @execP R [::] _ [Sample Bernoulli {3 / 5}:R] tt U. +Proof. +move=> mU. +rewrite execP_letin !execP_sample execD_beta !execD_bernoulli/=. +rewrite !execD_real/= exp_var'E (execD_var_erefl "p")/=. +transitivity (beta_prob_bernoulli_prob 6 4 1 0 U : \bar R). + rewrite /beta_prob_bernoulli_prob !letin'E/=. + rewrite integral_beta_prob//=; last 2 first. + exact: measurable_bernoulli_prob2. + exact: integral_beta_prob_bernoulli_prob_lty. + rewrite integral_beta_prob//=; last 2 first. + by apply: measurable_funTS => /=; exact: measurable_bernoulli_XMonemX01. + rewrite integral_beta_prob//=. + + suff: mu.-integrable `[0%R, 1%R] + (fun x => bernoulli_prob (@XMonemX R 1 0 \_`[0,1] x)%R U + * (beta_pdf 6 4 x)%:E)%E. + move=> /integrableP[_]. + under eq_integral. + move=> x _. + rewrite gee0_abs//; last first. + by rewrite mule_ge0// lee_fin beta_pdf_ge0. + over. + move=> ?. + by under eq_integral do rewrite gee0_abs//. + + apply: integrableMl => //=. + * apply: integrable_bernoulli_XMonemX01 => //=. + by rewrite lebesgue_measure_itv//= lte01 EFinN sube0 ltry. + * by apply: measurable_funTS; exact: measurable_beta_pdf. + * exact: bounded_beta_pdf_01. + + apply/measurableT_comp => //; apply: measurable_funTS => /=. + exact: measurable_bernoulli_XMonemX01. + + under eq_integral do rewrite gee0_abs//=. + have : (beta_prob 6 4 `[0%R, 1%R] < +oo :> \bar R)%E. + by rewrite -ge0_fin_numE// beta_prob_fin_num. + by move=> /(@integrable_bernoulli_XMonemX01 1 0 _ (beta_prob 6 4) mU) /integrableP[]. + rewrite [RHS]integral_mkcond. + apply: eq_integral => x _ /=. + rewrite patchE. + case: ifPn => x01. + by rewrite patchE x01 XMonemX0' expr1. + by rewrite /beta_pdf patchE (negbTE x01) mul0r mule0. +rewrite beta_prob_bernoulli_probE// !bernoulli_probE//=; last 2 first. + lra. + by rewrite div_beta_fun_ge0 div_beta_fun_le1. +(*by congr (_ * _ + _ * _)%:E; + rewrite /div_beta_fun/= /onem !beta_funE/=; repeat rewrite !factE/=; field.*) + (* temporary measure to avoid stack overflow *) +suff : div_beta_fun 6 4 1 0 = 3 / 5 :> R by move->. +rewrite /div_beta_fun/= /onem !beta_funE. +rewrite addn0 invfM mulrCA invrK. +rewrite addn1 8!addnS 2!addn0. +by rewrite (factS 9) !factS fact0; field. +Qed. +(* +congr (_ * _ + _ * _)%:E. +rewrite !factE/= !factE; field. +Qed. +*) + +End beta_bernoulli_bernoulli. + +Section letinA. +Local Open Scope lang_scope. +Variable R : realType. + +Lemma letinA g x y t1 t2 t3 (xyg : x \notin dom ((y, t2) :: g)) + (e1 : @exp R P g t1) + (e2 : exp P [:: (x, t1) & g] t2) + (e3 : exp P [:: (y, t2) & g] t3) : + forall U, measurable U -> + execP [let x := e1 in + let y := e2 in + {@exp_weak _ _ [:: (y, t2)] _ _ (x, t1) e3 xyg}] ^~ U = + execP [let y := + let x := e1 in e2 in + e3] ^~ U. +Proof. +move=> U mU; apply/funext=> z1. +rewrite !execP_letin. +rewrite (execP_weak [:: (y, t2)]). +apply: letin'A => //= z2 z3. +rewrite /kweak /mctx_strong /=. +by destruct z3. +Qed. + +Example letinA12 : forall U, measurable U -> + @execP R [::] _ [let "y" := return {1}:R in + let "x" := return {2}:R in + return #{"x"}] ^~ U = + @execP R [::] _ [let "x" := + let "y" := return {1}:R in return {2}:R in + return #{"x"}] ^~ U. +Proof. +move=> U mU. +rewrite !execP_letin !execP_return !execD_real. +apply: funext=> x. +rewrite !exp_var'E /= !(execD_var_erefl "x")/=. +exact: letin'A. +Qed. + +End letinA. + +Section staton_bus. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Definition staton_bus_syntax0 : @exp R _ [::] _ := + [let "x" := Sample Bernoulli {2 / 7}:R in + let "r" := if #{"x"} then return {3}:R else return {10}:R in + let "_" := Score {exp_poisson 4 [#{"r"}]} in + return #{"x"}]. + +Definition staton_bus_syntax := [Normalize {staton_bus_syntax0}]. + +Let sample_bern : R.-sfker munit ~> mbool := + sample _ (measurableT_comp measurable_bernoulli_prob (measurable_cst (2 / 7 : R)%R)). + +Let ite_3_10 : R.-sfker mbool * munit ~> measurableTypeR R := + ite macc0of2 (@ret _ _ _ (measurableTypeR R) R _ (kr 3)) (@ret _ _ _ (measurableTypeR R) R _ (kr 10)). + +Let score_poisson4 : R.-sfker measurableTypeR R * (mbool * munit) ~> munit := + score (measurableT_comp (measurable_poisson_pmf 4 measurableT) + (@macc0of2 _ _ (measurableTypeR R) _)). + +Let kstaton_bus' := + letin' sample_bern + (letin' ite_3_10 + (letin' score_poisson4 (ret macc2of4'))). + +Lemma eval_staton_bus0 : staton_bus_syntax0 -P> kstaton_bus'. +Proof. +apply: eval_letin. + by apply: eval_sample; apply: eval_bernoulli; exact: eval_real. +apply: eval_letin. + apply/evalP_if; [|exact/eval_return/eval_real..]. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "x")/=; congr existT. +apply: eval_letin. + apply/eval_score/eval_poisson. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "r")/=; congr existT. +apply/eval_return/execD_evalD. +by rewrite exp_var'E (execD_var_erefl "x")/=; congr existT. +Qed. + +Lemma exec_staton_bus0' : execP staton_bus_syntax0 = kstaton_bus'. +Proof. +rewrite 3!execP_letin execP_sample/= execD_bernoulli/= !execD_real. +rewrite /kstaton_bus'; congr letin'. +rewrite !execP_if !execP_return !execD_real/=. +rewrite exp_var'E (execD_var_erefl "x")/=. +have -> : measurable_acc_typ [:: Bool] 0 = macc0of2 by []. +congr letin'. +rewrite execP_score execD_poisson/=. +rewrite exp_var'E (execD_var_erefl "r")/=. +have -> : measurable_acc_typ [:: Real; Bool] 0 = macc0of2 by []. +congr letin'. +by rewrite exp_var'E (execD_var_erefl "x") /=; congr ret. +Qed. + +Lemma exec_staton_bus : execD staton_bus_syntax = + existT _ (normalize_pt kstaton_bus') (measurable_normalize_pt _). +Proof. by rewrite execD_normalize_pt exec_staton_bus0'. Qed. + +Let poisson4 := @poisson_pmf R ^~ 4%N. + +Let staton_bus_probability U := + ((2 / 7)%:E * (poisson4 3)%:E * \d_true U + + (5 / 7)%:E * (poisson4 10)%:E * \d_false U)%E. + +Lemma exec_staton_bus0 (U : set bool) : + execP staton_bus_syntax0 tt U = staton_bus_probability U. +Proof. +rewrite exec_staton_bus0' /staton_bus_probability /kstaton_bus'. +rewrite /sample_bern. +rewrite letin'E/=. +rewrite integral_bernoulli_prob//=; last lra. +rewrite -!muleA; congr (_ * _ + _ * _)%E. +- rewrite letin'_iteT//. + rewrite letin'_retk//. + rewrite letin'_kret//. + rewrite /score_poisson4. + by rewrite /score/= /mscale/= ger0_norm//= poisson_pmf_ge0. +- by rewrite onem27. +- rewrite letin'_iteF//. + rewrite letin'_retk//. + rewrite letin'_kret//. + rewrite /score_poisson4. + by rewrite /score/= /mscale/= ger0_norm//= poisson_pmf_ge0. +Qed. + +End staton_bus. + +(* same as staton_bus module associativity of letin *) +Section staton_busA. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Definition staton_busA_syntax0 : @exp R _ [::] _ := + [let "x" := Sample Bernoulli {2 / 7}:R in + let "_" := + let "r" := if #{"x"} then return {3}:R else return {10}:R in + Score {exp_poisson 4 [#{"r"}]} in + return #{"x"}]. + +Definition staton_busA_syntax : exp _ [::] _ := + [Normalize {staton_busA_syntax0}]. + +Let sample_bern : R.-sfker munit ~> mbool := + sample _ (measurableT_comp measurable_bernoulli_prob (measurable_cst (2 / 7 : R)%R)). + +Let ite_3_10 : R.-sfker mbool * munit ~> measurableTypeR R := + ite macc0of2 (@ret _ _ _ (measurableTypeR R) R _ (kr 3)) + (@ret _ _ _ (measurableTypeR R) R _ (kr 10)). + +Let score_poisson4 : R.-sfker measurableTypeR R * (mbool * munit) ~> munit := + score (measurableT_comp (measurable_poisson_pmf 4 measurableT) + (@macc0of3' _ _ _ (measurableTypeR R) _ _)). + +(* same as kstaton_bus _ (measurable_poisson 4) but expressed with letin' + instead of letin *) +Let kstaton_busA' := + letin' sample_bern + (letin' + (letin' ite_3_10 + score_poisson4) + (ret macc1of3')). + +Lemma eval_staton_busA0 : staton_busA_syntax0 -P> kstaton_busA'. +Proof. +apply: eval_letin. + by apply: eval_sample; apply: eval_bernoulli; exact: eval_real. +apply: eval_letin. + apply: eval_letin. + apply/evalP_if; [|exact/eval_return/eval_real..]. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "x")/=; congr existT. + apply/eval_score/eval_poisson. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "r")/=; congr existT. +apply/eval_return. +by apply/execD_evalD; rewrite exp_var'E (execD_var_erefl "x")/=; congr existT. +Qed. + +Lemma exec_staton_busA0' : execP staton_busA_syntax0 = kstaton_busA'. +Proof. +rewrite 3!execP_letin execP_sample/= execD_bernoulli execD_real. +rewrite /kstaton_busA'; congr letin'. +rewrite !execP_if !execP_return !execD_real/=. +rewrite exp_var'E (execD_var_erefl "x")/=. +have -> : measurable_acc_typ [:: Bool] 0 = macc0of2 by []. +congr letin'. + rewrite execP_score execD_poisson/=. + rewrite exp_var'E (execD_var_erefl "r")/=. + by have -> : measurable_acc_typ [:: Real; Bool] 0 = macc0of3' by []. +by rewrite exp_var'E (execD_var_erefl "x") /=; congr ret. +Qed. + +Lemma exec_statonA_bus : execD staton_busA_syntax = + existT _ (normalize_pt kstaton_busA') (measurable_normalize_pt _). +Proof. by rewrite execD_normalize_pt exec_staton_busA0'. Qed. + +(* equivalence between staton_bus and staton_busA *) +Lemma staton_bus_staton_busA : + execP staton_bus_syntax0 = @execP R _ _ staton_busA_syntax0. +Proof. +rewrite /staton_bus_syntax0 /staton_busA_syntax0. +rewrite execP_letin. +rewrite [in RHS]execP_letin. +congr (letin' _). +set e1 := exp_if _ _ _. +set e2 := exp_score _. +set e3 := (exp_return _ in RHS). +pose f := @found _ Unit "x" Bool [::]. +have r_f : "r" \notin [seq i.1 | i <- ("_", Unit) :: untag (ctx_of f)] by []. +have H := @letinA _ _ _ _ _ _ + (lookup Unit (("_", Unit) :: untag (ctx_of f)) "x") + r_f e1 e2 e3. +apply/eq_sfkernel => /= x U. +have mU : + (@mtyp_disp R (lookup Unit (("_", Unit) :: untag (ctx_of f)) "x")).-measurable U. + by []. +move: H => /(_ U mU) /(congr1 (fun f => f x)) <-. +set e3' := exp_return _. +set e3_weak := exp_weak _ _ _ _. +rewrite !execP_letin. +suff: execP e3' = execP (e3_weak e3 r_f) by move=> <-. +rewrite execP_return/= exp_var'E (execD_var_erefl "x") /= /e3_weak. +rewrite (@execP_weak R [:: ("_", Unit)] _ ("r", Real) _ e3 r_f). +rewrite execP_return exp_var'E/= (execD_var_erefl "x") //=. +by apply/eq_sfkernel => /= -[[] [a [b []]]] U0. +Qed. + +Let poisson4 := @poisson_pmf R ^~ 4%N. + +Lemma exec_staton_busA0 U : execP staton_busA_syntax0 tt U = + ((2 / 7%:R)%:E * (poisson4 3%:R)%:E * \d_true U + + (5%:R / 7%:R)%:E * (poisson4 10%:R)%:E * \d_false U)%E. +Proof. by rewrite -staton_bus_staton_busA exec_staton_bus0. Qed. + +End staton_busA. + +Section letinC. +Local Open Scope lang_scope. +Variable (R : realType). + +Let weak_head g {t1 t2} x (e : @exp R P g t2) (xg : x \notin dom g) := + exp_weak P [::] _ (x, t1) e xg. + +Lemma letinC g t1 t2 (e1 : @exp R P g t1) (e2 : exp P g t2) + (x y : string) + (xy : infer (x != y)) (yx : infer (y != x)) + (xg : x \notin dom g) (yg : y \notin dom g) : + forall U, measurable U -> + execP [ + let x := e1 in + let y := {weak_head e2 xg} in + return (#x, #y)] ^~ U = + execP [ + let y := e2 in + let x := {weak_head e1 yg} in + return (#x, #y)] ^~ U. +Proof. +move=> U mU; apply/funext => z. +rewrite 4!execP_letin. +rewrite 2!(execP_weak [::] g). +rewrite 2!execP_return/=. +rewrite 2!execD_pair/=. +rewrite !exp_var'E. +- exact/(ctx_prf_tail _ yx)/ctx_prf_head. +- exact/ctx_prf_head. +- exact/ctx_prf_head. +- exact/(ctx_prf_tail _ xy)/ctx_prf_head. +- move=> h1 h2 h3 h4. + set g1 := [:: (y, t2), (x, t1) & g]. + set g2 := [:: (x, t1), (y, t2) & g]. + have /= := @execD_var R g1 x. + rewrite (negbTE yx) eqxx => /(_ h4) ->. + have /= := @execD_var R g2 x. + rewrite (negbTE yx) eqxx => /(_ h2) ->. + have /= := @execD_var R g1 y. + rewrite eqxx => /(_ h3) ->. + have /= := @execD_var R g2 y. + rewrite (negbTE xy) eqxx => /(_ h1) -> /=. + have -> : measurable_acc_typ [:: t2, t1 & map snd g] 0 = macc0of3' by []. + have -> : measurable_acc_typ [:: t2, t1 & map snd g] 1 = macc1of3' by []. + rewrite (letin'C _ _ (execP e2) + [the R.-sfker _ ~> _ of @kweak _ [::] _ (y, t2) _ (execP e1)]); + [ |by [] | by [] |by []]. + have -> : measurable_acc_typ [:: t1, t2 & map snd g] 0 = macc0of3' by []. + by have -> : measurable_acc_typ [:: t1, t2 & map snd g] 1 = macc1of3' by []. +Qed. + +Example letinC_ground_variables g t1 t2 (e1 : @exp R P g t1) (e2 : exp P g t2) + (x := "x") (y := "y") + (xg : x \notin dom g) (yg : y \notin dom g) : + forall U, measurable U -> + execP [ + let x := e1 in + let y := {exp_weak _ [::] _ (x, t1) e2 xg} in + return (#x, #y)] ^~ U = + execP [ + let y := e2 in + let x := {exp_weak _ [::] _ (y, t2) e1 yg} in + return (#x, #y)] ^~ U. +Proof. by move=> U mU; rewrite letinC. Qed. + +Example letinC_ground (g := [:: ("a", Unit); ("b", Bool)]) t1 t2 + (e1 : @exp R P g t1) + (e2 : exp P g t2) : + forall U, measurable U -> + execP [let "x" := e1 in + let "y" := e2 :+ {"x"} in + return (#{"x"}, #{"y"})] ^~ U = + execP [let "y" := e2 in + let "x" := e1 :+ {"y"} in + return (#{"x"}, #{"y"})] ^~ U. +Proof. by move=> U mU; exact: letinC. Qed. + +End letinC. diff --git a/theories/lang_syntax_noisy.v b/theories/lang_syntax_noisy.v new file mode 100644 index 000000000..cad600659 --- /dev/null +++ b/theories/lang_syntax_noisy.v @@ -0,0 +1,805 @@ +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval. +From mathcomp.classical Require Import mathcomp_extra boolp. +From mathcomp Require Import ring lra. +From mathcomp Require Import classical_sets functions cardinality fsbigop. +From mathcomp Require Import interval_inference reals ereal topology normedtype. +From mathcomp Require Import sequences esum measure lebesgue_measure numfun exp. +From mathcomp Require Import trigo realfun charge lebesgue_integral kernel. +From mathcomp Require Import probability prob_lang. +From mathcomp Require Import lang_syntax_util lang_syntax lang_syntax_examples. + +(**md**************************************************************************) +(* # Observing a noisy draw from a normal distribution *) +(* *) +(* Formalization of Shan's HelloRight example (Sec. 2.3 of [Shan, POPL 2018]).*) +(* *) +(* ref: *) +(* - Chung-chieh Shan, Equational reasoning for probabilistic programming, *) +(* POPL TutorialFest 2018 *) +(* https://homes.luddy.indiana.edu/ccshan/rational/equational-handout.pdf *) +(* - Praveen Narayanan. Verifiable and reusable conditioning. PhD thesis, *) +(* Indiana University, 2019. *) +(* *) +(* ``` *) +(* noisyA == distribution of the next noisy measurement of a normally *) +(* distributed quantity *) +(* ``` *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Import Order.TTheory GRing.Theory Num.Def Num.ExtraDef Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. +Local Open Scope string_scope. + +(* TODO: PR? *) +Section ge0_bounded_measurable_probability_integrable. +Context d {T : measurableType d} {R : realType} {p : probability T R} + {f : T -> \bar R}. + +Lemma ge0_bounded_measurable_probability_integrable : + (forall x, 0 <= f x) -> (exists M : R, forall x, f x <= M%:E) -> + measurable_fun setT f -> p.-integrable setT f. +Proof. +move=> f0 [M fleM] mf. +apply/integrableP; split => //. +rewrite (@le_lt_trans _ _ (\int[p]_x M%:E))//. + apply: ge0_le_integral => //=. + - exact: measurableT_comp. + - by move=> t _; exact: (@le_trans _ _ (f t)). + - by move=> t _; rewrite gee0_abs. +by rewrite integral_cst// probability_setT mule1 ltry. +Qed. + +End ge0_bounded_measurable_probability_integrable. + +(* TODO: PR *) +Section pkernel_probability_integrable. +Context d d' {T : measurableType d} {T' : measurableType d'} {R : realType} + {p : probability T R} {f : R.-pker T ~> T'}. + +Lemma pkernel_probability_integrable V : measurable V -> + p.-integrable setT (fun x => f x V). +Proof. +move=> mV. +apply: ge0_bounded_measurable_probability_integrable => //. + exists 1%R => x. + rewrite (@le_trans _ _ (f x setT))//. + by rewrite le_measure ?inE. + by rewrite prob_kernel. +exact: measurable_kernel. +Qed. + +End pkernel_probability_integrable. + +(* TODO: move to probability.v? *) +Section normal_prob_lemmas. +Context {R: realType}. +Local Notation mu := lebesgue_measure. + +Local Open Scope charge_scope. + +Lemma normal_pdf_uniq_ae (m s : R) (s0 : (s != 0)%R) : + ae_eq mu setT + ('d ((charge_of_finite_measure (@normal_prob R m s))) '/d mu) + (EFin \o (@normal_pdf R m s)). +Proof. +apply: integral_ae_eq => //. +- by apply: Radon_Nikodym_integrable => /=; exact: normal_prob_dominates. +- by apply/measurable_EFinP; exact: measurable_normal_pdf. +- move=> /= E _ mE. + by rewrite -Radon_Nikodym_integral//=; exact: normal_prob_dominates. +Qed. + +Local Close Scope charge_scope. + +Lemma measurable_normal_prob (s : R) U : s != 0%R -> measurable U -> + measurable_fun setT (fun x => normal_prob x s U). +Proof. +move=> s0 mU. +under [X in _ _ X]eq_fun. + move=> /= x. + rewrite -(@fineK _ (_ x _ _)); last first. + rewrite ge0_fin_numE//. + rewrite (@le_lt_trans _ _ (normal_prob x s setT))//. + by rewrite le_measure ?inE. + by rewrite probability_setT ltry. + over. +apply/measurable_EFinP. +apply: (continuous_measurable_fun). +exact: normal_prob_continuous. +Qed. + +Lemma integral_normal_prob (m s : R) (s0 : (s != 0)%R) f U : + measurable U -> + (normal_prob m s).-integrable U f -> + \int[@normal_prob _ m s]_(x in U) f x = + \int[mu]_(x in U) (f x * (normal_pdf m s x)%:E). +Proof. +move=> mU intf. +move/integrableP : (intf) => [mf intf_lty]. +rewrite -(Radon_Nikodym_change_of_variables (normal_prob_dominates m s))//=. +apply: ae_eq_integral => //=. +- apply: emeasurable_funM => //. + apply: measurable_funTS. + have : charge_of_finite_measure (normal_prob m s) `<< mu. + exact: normal_prob_dominates m s. + by move=> /Radon_Nikodym_integrable /integrableP[]. +- apply: emeasurable_funM => //. + apply/measurable_EFinP. + apply: measurable_funTS. + exact: measurable_normal_pdf. +- apply: ae_eqe_mul2l. + apply: (ae_eq_subset (@subsetT _ U)). + exact: (normal_pdf_uniq_ae m s0). +Qed. + +Lemma normal_prob_integrable_dirac (m s : R) (V : set R): measurable V -> + (normal_prob m s).-integrable setT (fun x => \d_x V). +Proof. +move=> mV. +apply/integrableP; split; first exact: measurable_fun_dirac. +rewrite -(setUv V) ge0_integral_setU//; last 3 first. + exact: measurableC. + rewrite setUv. + apply: measurableT_comp => //. + exact: measurable_fun_dirac. + exact/disj_setPCl. +under eq_integral. + move=> x Vx. + rewrite diracE Vx/= normr1. + over. +under [X in _ + X < _]eq_integral. + move=> /= x. + rewrite inE/= => nVx. + have {}nVx := memNset nVx. + rewrite indicE nVx/= normr0. + over. +rewrite !integral_cst//=; last exact: measurableC. +rewrite mul1e mul0e adde0. +apply: (le_lt_trans (probability_le1 (normal_prob m s) mV)). +exact: ltey. +Qed. + +Lemma integral_normal_prob_dirac (s : R) (m : R) V : + (s != 0)%R -> measurable V -> + \int[normal_prob m s]_x (\d_x V) = normal_prob m s V. +Proof. +move=> s0 mV. +rewrite integral_normal_prob//; last exact: normal_prob_integrable_dirac. +under eq_integral do rewrite diracE. +rewrite /= /normal_prob [in RHS]integral_mkcond. +under [in RHS]eq_integral do rewrite patchE. +rewrite /=. +apply: eq_integral => x _. +by case: ifP => xV/=; rewrite ?mul1e ?mul0e. +Qed. + +End normal_prob_lemmas. + +(* TODO: move to probability.v *) +Section normal_probD. +Context {R : realType}. +Local Notation mu := lebesgue_measure. + +Let normal_pdf0 m s x : R := normal_peak s * normal_fun m s x. + +Let measurable_normal_pdf0 m s : measurable_fun setT (normal_pdf0 m s). +Proof. by apply: measurable_funM => //=; exact: measurable_normal_fun. Qed. + +Lemma normal_probD1 (m1 m2 s1 s2 : R) V : measurable V -> + s1 != 0%R -> s2 != 0%R -> + \int[normal_prob m1 s1]_x normal_prob (m2 + x) s2 V = + \int[mu]_(y in V) \int[mu]_x (normal_pdf (m2 + x) s2 y * normal_pdf m1 s1 x)%:E. +Proof. +move=> mV s10 s20; rewrite integral_normal_prob//; last first. + apply: ge0_bounded_measurable_probability_integrable => //=. + by exists 1%R => ?; exact: probability_le1. + apply: (@measurableT_comp _ _ _ _ _ _ + (fun x => normal_prob x s2 V) _ (fun x => m2 + x)). + exact: measurable_normal_prob. + exact: measurable_funD. +transitivity (\int[mu]_x \int[mu]_y + ((normal_pdf (m2 + x) s2 y * normal_pdf m1 s1 x)%:E * (\1_V y)%:E)). + apply: eq_integral => y _. + rewrite /normal_prob -integralZr//; last first. + by apply: (integrableS measurableT) => //; exact: integrable_normal_pdf. + transitivity (\int[mu]_(x in V) + (normal_pdf (m2 + y) s2 x * normal_pdf m1 s1 y)%:E). + by apply: eq_integral => z _; rewrite -EFinM. + by rewrite integral_mkcond epatch_indic. +rewrite (@fubini_tonelli _ _ _ _ _ mu mu (EFin \o + ((fun xz : R * R => (normal_pdf (m2 + xz.1) s2 xz.2 * + normal_pdf m1 s1 xz.1)%R * \1_V xz.2)%R)))/=; last 2 first. + apply/measurable_EFinP; apply: measurable_funM => /=; last first. + apply: measurable_indic; rewrite -[X in measurable X]setTX. + exact: measurableX. + apply: measurable_funM => /=. + rewrite [X in measurable_fun _ X](_ : _ = (fun x => + normal_pdf0 0 s2 (x.2 - (m2 + x.1)%E)))/=; last first. + by apply/funext=> x0; rewrite /normal_pdf0 normal_pdfE// normal_fun_center. + apply: measurableT_comp => /=; first exact: measurable_normal_pdf0. + under eq_fun do rewrite opprD. + by apply: measurable_funD => //=; exact: measurable_funB. + by apply: measurableT_comp => //; exact: measurable_normal_pdf. + by move=> x/=; rewrite lee_fin !mulr_ge0 ?normal_pdf_ge0. +transitivity (\int[mu]_x \int[mu]_y + ((fun y => (normal_pdf (m2 + y) s2 x * normal_pdf m1 s1 y)%:E) \_ (fun=> V x)) y). + apply: eq_integral => x0 _ /=. + under eq_integral do rewrite EFinM. + by rewrite -epatch_indic. +rewrite [RHS]integral_mkcond/=. +apply: eq_integral => /= x0 _. +rewrite patchE; case: ifPn => xV. + by apply: eq_integral => z _/=; rewrite patchE ifT. +apply: integral0_eq => /= z _. +rewrite patchE ifF//; apply/negbTE; rewrite notin_setE/=. +by move/negP : xV; rewrite inE. +Qed. + +Lemma normal_probD2 (y m1 m2 s1 s2 : R) : s1 != 0%R -> s2 != 0%R -> + \int[mu]_x (normal_pdf (m1 + x)%E s1 y * normal_pdf m2 s2 x)%:E = + (normal_peak s1 * normal_peak s2)%:E * + \int[mu]_z (normal_fun (m1 + z) s1 y * normal_fun m2 s2 z)%:E. +Proof. +move=> s10 s20. +rewrite -ge0_integralZl//=; last 3 first. + apply/measurable_EFinP => //=; apply: measurable_funM => //=. + - rewrite /normal_fun. + under eq_fun do rewrite -(sqrrN (y - _)) opprB (addrC m1) -addrA -opprB. + exact: measurable_normal_fun. + - exact: measurable_normal_fun. + by move=> z _; rewrite lee_fin mulr_ge0// expR_ge0. + by rewrite lee_fin mulr_ge0// ?normal_peak_ge0. +apply: eq_integral => /= z _. +by rewrite 2?normal_pdfE// /normal_pdf0 mulrACA /normal_fun. +Qed. + +Lemma normal_peak1 : normal_peak 1 = (Num.sqrt (pi *+ 2))^-1%R :> R. +Proof. by rewrite /normal_peak expr1n mul1r. Qed. + +(* Variable elimination and integration [Shan, Section 3.5, (9)], + * also known as the reproductive property of normal distribution. + *) +Lemma normal_probD (m1 s1 m2 s2 : R) V : s1 != 0%R -> s2 != 0%R -> + measurable V -> + \int[normal_prob m1 s1]_x normal_prob (m2 + x) s2 V = + normal_prob (m1 + m2) (Num.sqrt (s1 ^+ 2 + s2 ^+ 2)) V. +Proof. +move=> s10 s20 mV. +rewrite normal_probD1//; apply: eq_integral => y _. +clear V mV. +rewrite normal_probD2//. +have s1s20 : (s1 ^+ 2 + s2 ^+ 2 != 0)%R. + by rewrite lt0r_neq0// addr_gt0// exprn_even_gt0. +have sqs1s20 : Num.sqrt (s1 ^+ 2 + s2 ^+ 2) != 0%R. + by rewrite lt0r_neq0// sqrtr_gt0 addr_gt0// exprn_even_gt0. +rewrite normal_pdfE /normal_pdf0//. +set S1 := (s1 ^+ 2)%R. +set S2 := (s2 ^+ 2)%R. +transitivity (((Num.sqrt S1 * Num.sqrt S2 * pi *+ 2)^-1)%:E * + \int[mu]_x (expR + (- (x - (y * s1 ^+ 2 + m1 * s2 ^+ 2 - m2 * s1 ^+ 2) + / (s1 ^+ 2 + s2 ^+ 2)%R ) ^+ 2 + / ((Num.sqrt ((s1 ^+ 2 * s2 ^+ 2) / (s1 ^+ 2 + s2 ^+ 2)%R) ^+ 2) *+ 2) + - (y - (m1 + m2)) ^+ 2 / ((s1 ^+ 2 + s2 ^+ 2) *+ 2)))%:E). + congr *%E. + rewrite /normal_peak. + congr EFin. + rewrite -2!(mulr_natr (_ * pi)). + rewrite !(sqrtrM 2) ?(@mulr_ge0 _ _ pi) ?sqr_ge0 ?pi_ge0//. + rewrite !(sqrtrM pi) ?sqr_ge0//. + rewrite ![in LHS]invfM. + rewrite mulrACA -(@sqrtrV _ 2)// -(expr2 (_ _^-1)%R). + rewrite (@sqr_sqrtr _ 2^-1) ?invr_ge0//. + rewrite mulrACA -(@sqrtrV _ pi) ?pi_ge0//. + rewrite -(expr2 (_ _^-1)%R) (@sqr_sqrtr _ pi^-1) ?invr_ge0// ?pi_ge0//. + rewrite -!invfM; congr GRing.inv. + by rewrite -[in RHS]mulr_natr (mulrC _ (Num.sqrt _)). + apply: eq_integral. + move=> x _. + rewrite -expRD. + congr ((expR _)%:E). + rewrite sqr_sqrtr; last first. + rewrite mulr_ge0 ?invr_ge0// ?addr_ge0 ?(@mulr_ge0 _ (_ ^+ 2))// ?sqr_ge0//. + by field; do ?[apply/and3P; split]. +set DS12 := S1 + S2. +set MS12 := (S1 * S2)%R. +set C := ((((y * s1 ^+ 2)%R + (m1 * s2 ^+ 2)%R)%E - m2 * s1 ^+ 2) / DS12)%R. +under eq_integral do rewrite expRD EFinM. +rewrite ge0_integralZr//=; last first. + apply/measurable_EFinP. + apply: measurableT_comp => //. + apply: measurable_funM => //. + apply: measurableT_comp => //. + apply: (@measurableT_comp _ _ _ _ _ _ (fun t : R => t ^+ 2)%R) => //. + exact: measurable_funD. +rewrite /normal_peak /normal_fun. +rewrite [in RHS]EFinM. +rewrite [in RHS]sqr_sqrtr//; last first. + by rewrite addr_ge0// sqr_ge0. +rewrite muleA; congr *%E; last by rewrite -mulNr. +(* gauss integral *) +have MS12DS12_gt0 : (0 < MS12 / DS12)%R. + rewrite divr_gt0//. + by rewrite mulr_gt0// exprn_even_gt0. + by rewrite addr_gt0// exprn_even_gt0. +transitivity (((Num.sqrt S1 * Num.sqrt S2 * pi *+ 2)^-1)%:E + * \int[mu]_x ((normal_peak (Num.sqrt (MS12 / DS12)))^-1%:E + * (normal_pdf C (Num.sqrt (MS12 / DS12)) x)%:E)). + congr *%E. + apply: eq_integral => x _. + rewrite -EFinM; congr EFin. + rewrite normal_pdfE; last first. + apply: lt0r_neq0. + by rewrite sqrtr_gt0. + rewrite mulrA mulVf// ?mul1r//. + rewrite lt0r_neq0// invr_gt0 sqrtr_gt0 pmulrn_lgt0// mulr_gt0// ?pi_gt0//. + rewrite exprn_even_gt0//=. + by rewrite lt0r_neq0// sqrtr_gt0. +rewrite ge0_integralZl//; last 3 first. + apply/measurable_EFinP. + exact: measurable_normal_pdf. + move=> x _. + rewrite lee_fin. + exact: normal_pdf_ge0. + rewrite lee_fin invr_ge0. + exact: normal_peak_ge0. +rewrite integral_normal_pdf. +rewrite mule1 -EFinM; congr EFin. +rewrite -invfM; congr GRing.inv. +rewrite -sqrtrM ?sqr_ge0//. +rewrite /normal_peak sqr_sqrtr; last by rewrite ltW. +rewrite -3!mulrnAr. +rewrite (sqrtrM (pi *+ 2)); last by rewrite ltW. +rewrite invfM mulrCA. +rewrite -{1}(@sqr_sqrtr _ (pi *+ 2)); last by rewrite pmulrn_lge0 ?pi_ge0. +rewrite -2!(mulrA (Num.sqrt _)) divff// ?mulr1; last first. + by rewrite lt0r_neq0// sqrtr_gt0 pmulrn_lgt0 ?pi_gt0. +rewrite (sqrtrM (DS12^-1)); last by rewrite mulr_ge0 ?sqr_ge0. +rewrite sqrtrV; last by rewrite addr_ge0 ?sqr_ge0. +rewrite invfM invrK. +rewrite mulrAC mulrA mulVf ?mul1r; last first. + by rewrite lt0r_neq0// sqrtr_gt0 mulr_gt0 ?exprn_even_gt0. +rewrite sqrtrM; last by rewrite addr_ge0 ?sqr_ge0. +by rewrite mulrC. +Qed. + +End normal_probD. + +Section noisy_programs. +Local Open Scope lang_scope. +Context {R : realType}. +Local Notation mu := lebesgue_measure. + +Definition exp_normal1 {g} (e : exp D g Real) := + [Normal e {1%R} {oner_neq0 R}]. + +(* NB: exp_powR level setting is mistaken? *) +(* ((_ `^ _) * _) cannot write as (_ `^ _ * _) *) +Definition noisyA' : @exp R P [:: ("y0", Real)] Real := + [let "x" := Sample {exp_normal1 [{0}:R]} in + let "_" := Score ({expR 1} `^ + ({0}:R - (#{"y0"} - #{"x"}) ^+ {2%R} * {2^-1}:R)) + * {(Num.sqrt (2 * pi))^-1}:R in + let "z" := Sample {exp_normal1 [#{"x"}]} in + return #{"z"}]. + +Definition noisyA : @exp R _ [:: ("y0", Real)] _ := [Normalize {noisyA'}]. + +(* other programs from Sect. 2.3 of [Shan, POPL 2018], + nothing proved about them yet, just for the sake of completeness *) +Definition guard_real {g} str (r : R) : + @exp R P [:: (str, _) ; g] _ := + [if #{str} ==R {r}:R then return TT else Score {0}:R]. + +Definition helloWrong (y0 : R) : @exp R _ [::] _ := + [Normalize + let "x" := Sample {exp_normal1 (exp_real 0)} in + let "y" := Sample {exp_normal1 [#{"x"}]} in + let "_" := {guard_real "y" y0} in + let "z" := Sample {exp_normal1 [#{"x"}]} in + return #{"z"}]. + +Definition helloJoint : @exp R _ [::] _ := + [Normalize + let "x" := Sample {exp_normal1 (exp_real 0)} in + let "y" := Sample {exp_normal1 [#{"x"}]} in + let "z" := Sample {exp_normal1 [#{"x"}]} in + return (#{"y"}, #{"z"})]. + +End noisy_programs. + +(* The following section contains the mathematical facts that are used + to verify the noisy program. They are proved beforehand as an attempt + to optimize the time spent by the Qed command of Rocq. *) +Section noisy_subproofs. +Local Open Scope lang_scope. +Context {R : realType}. +Local Notation mu := lebesgue_measure. + +Local Definition noisyA_semantics_normal + (y : (@mctx R [:: ("y0", Real)])) (V : set (@mtyp R Real)) := + \int[normal_prob 0 1]_x (fun z => + (expR (- ((y.1 - z) ^+ 2%R / 2)) / Num.sqrt (2 * pi))%:E * + normal_prob z 1 V) x. + +Lemma noisyA_semantics_normalE y V : measurable V -> + noisyA_semantics_normal y V = + \int[mu]_x + ((expR (- ((y.1 - x) ^+ 2%R / 2)) / Num.sqrt (2 * pi))%:E * + normal_prob x 1 V * + (normal_pdf 0 1 x)%:E). +Proof. +move=> mV; rewrite /noisyA_semantics_normal. +rewrite integral_normal_prob//. +apply: integrableMr => //. + apply: measurable_funM => //. + under eq_fun do rewrite -sqrrN opprB -mulNr. + rewrite [X in fun _ => expR (_ / X)%R](_:2 = 1 ^+ 2 *+ 2)%R; last first. + by rewrite expr1n. + exact: measurable_normal_fun. + exists (Num.sqrt (2 * pi))^-1%R; split; first exact: num_real. + move=> x x_gt p _. + rewrite /= ger0_norm; last by rewrite mulr_ge0// expR_ge0. + apply/ltW; apply: le_lt_trans x_gt. + rewrite -[leRHS]mul1r ler_pM ?expR_ge0//. + by rewrite -expR0 ler_expR oppr_le0 mulr_ge0// ?sqr_ge0// expR0 invr_ge0. +apply/integrableP; split; first exact: measurable_normal_prob. +apply/abse_integralP => //; first exact: measurable_normal_prob. +rewrite gee0_abs//; last exact: integral_ge0. +apply: (@le_lt_trans _ _ (\int[normal_prob 0 1]_x (cst 1%R x)%:E)). + apply: ge0_le_integral => //; first exact: measurable_normal_prob. + by apply/measurable_EFinP; exact: measurable_cst. + by move=> x _; exact: probability_le1. +by rewrite /= integral_cst// mul1e probability_setT ltry. +Qed. + +Local Definition noisyB_semantics_normal + (y : (@mctx R [:: ("y0", Real)])) (V : set (@mtyp R Real)) := + \int[normal_prob (y.1 / 2) (Num.sqrt 2)^-1]_x + ((expR (- (y.1 ^+ 2%R / 4)) / + Num.sqrt (4 * pi))%:E * + normal_prob x 1 V)%E. + +Lemma noisyB_semantics_normalE y V : measurable V -> + noisyB_semantics_normal y V = + \int[mu]_x + ((expR (- (y.1 ^+ 2%R / 4)) / Num.sqrt (4 * pi))%:E * + (normal_prob x 1 V * (normal_pdf (y.1 / 2) (Num.sqrt 2)^-1 x)%:E)). +Proof. +move=> mV; rewrite /noisyB_semantics_normal integral_normal_prob//. + by under eq_integral do rewrite -muleA. +apply/integrableP; split. + by apply: measurable_funeM; exact: measurable_normal_prob. +apply: (@le_lt_trans _ _ (\int[normal_prob (y.1 / 2) (Num.sqrt 2)^-1]_x + ((NngNum (normr_ge0 (expR (- (y.1 ^+ 2%R / 4)) / Num.sqrt (4 * pi))))%:num%:E + * (cst 1%R x)%:E))). +apply: ge0_le_integral; [by []|by []| | | |]. +- apply: measurableT_comp => //. + by apply: measurable_funeM; exact: measurable_normal_prob. +- by move=> x _; rewrite /= mule1. +- by rewrite /= mule1; exact/measurable_EFinP. +- move=> x _/=. + rewrite abseM/= lee_pmul//. + rewrite gee0_abs; last exact: measure_ge0. + by rewrite probability_le1. +- rewrite integralZr//; last exact: finite_measure_integrable_cst. + by rewrite integral_cst// mule1 probability_setT ltry. +Qed. + +Local Definition noisyA'_part + (y : (@mctx R [:: ("y0", Real)])) (x : R) := + ((expR (- ((y.1 - x) ^+ 2%R / 2)) / Num.sqrt (2 * pi)) * (normal_pdf 0 1 x))%R. + +Local Definition noisyB'_part + (y : (@mctx R [:: ("y0", Real)])) (x : R) := + ((expR (- (y.1 ^+ 2%R / 4)) / Num.sqrt (4 * pi)) * (normal_pdf (y.1 / 2) (Num.sqrt 2)^-1 x))%R. + +Lemma noisyAB'_rearrange (y : (@mctx R [:: ("y0", Real)])) x : + noisyA'_part y x = noisyB'_part y x. +Proof. +rewrite /noisyA'_part/noisyB'_part !normal_pdfE//. +rewrite mulrA mulrAC -(@sqrtrV _ 2)//. +rewrite /normal_peak sqr_sqrtr; last by rewrite invr_ge0. +rewrite /normal_fun subr0 sqr_sqrtr; last by rewrite invr_ge0. +rewrite -mulrA mulrAC mulrA. +rewrite [X in (X / _ / _)%R = _](_ : _ = + expR (- x ^+ 2 / (1 ^+ 2 *+ 2) - (y.1 - x) ^+ 2%R / 2)); last first. + by rewrite mulrC -expRD -mulrA. +rewrite [RHS]mulrAC [X in _ = (_ * X / _)%R]mulrC mulrA -mulrA -[RHS]mulrA. +rewrite -expRD -2!invfM. +congr ((expR _) * _^-1)%R. + lra. +rewrite -2?sqrtrM; last 2 first. + by rewrite -mulr_natr mulrAC mulVf// mul1r pi_ge0. + by rewrite expr1n mul1r mulrn_wge0// pi_ge0. +congr Num.sqrt. +lra. +Qed. + +Lemma noisyC_semanticsE (y : R) V : measurable V -> + \int[normal_prob y (Num.sqrt 2)^-1]_x normal_prob x 1 V = + normal_prob y (Num.sqrt (3 / 2)) V. +Proof. +move=> mV. +have := @normal_probD R (y ) (Num.sqrt 2)^-1 0 1 _ _ _ mV. +under eq_integral do rewrite add0r. +rewrite addr0. +rewrite (_ : ((Num.sqrt 2)^-1 ^+ 2 + 1 ^+ 2 = 3 / 2)%R)//; last first. + rewrite exprVn sqr_sqrtr// expr1n -[in LHS]div1r -{3}(@divff _ 1%R)//. + rewrite addf_div// 2!mulr1 mul1r (_:1%R = 1%:R)// -natrD. +exact. +Qed. + +End noisy_subproofs. + +(* this section contains the verification of the noisy program per se *) +Section noisy_verification. +Local Open Scope lang_scope. +Context {R : realType}. +Local Notation mu := lebesgue_measure. + +(* definition of intermediate programs *) +Definition neq0Vsqrt2 : ((Num.sqrt 2)^-1 != 0 :> R)%R. +Proof. exact: lt0r_neq0. Qed. + +Definition exp_normal_Vsqrt2 {g} (e : exp D g Real) := + [Normal e {(Num.sqrt 2)^-1%R} {neq0Vsqrt2}]. + +Definition tailB : @exp R _ [:: ("_", Unit); ("y0", Real)] Real := + [let "x" := Sample {exp_normal_Vsqrt2 [#{"y0"} * {2^-1}:R]} in + let "z" := Sample {exp_normal1 [#{"x"}]} in + return #{"z"}]. +Definition noisyB' : @exp R _ [:: ("y0", Real)] Real := + [let "_" := Score ({expR 1} `^ ({0}:R - #{"y0"} ^+ {2%R} * {4^-1}:R)) * + {(Num.sqrt (4 * pi))^-1}:R in + {tailB}]. +Definition noisyB : @exp R _ [:: ("y0", Real)] _ := [Normalize {noisyB'}]. + +Definition neq0sqrt32 : (Num.sqrt (3 / 2) != 0 :> R)%R. +Proof. exact: lt0r_neq0. Qed. +Definition exp_normal_sqrt32 {g} (e : exp D g Real) := + [Normal e {Num.sqrt (3 / 2)} {neq0sqrt32}]. + +Definition tailC : @exp R _ [:: ("_", Unit); ("y0", Real)] Real := + [Sample {exp_normal_sqrt32 [#{"y0"} * {2^-1}:R]}]. +Definition noisyC' : @exp R _ [:: ("y0", Real)] _ := + [let "_" := Score ({expR 1} `^ ({0}:R - #{"y0"} ^+ {2%R} * {4^-1}:R)) * + {(Num.sqrt (4 * pi))^-1}:R in + {tailC}]. +Definition noisyC : @exp R _ [:: ("y0", Real)] _ := [Normalize {noisyC'}]. + +Lemma noisyAB' y V : measurable V -> execP noisyA' y V = execP noisyB' y V. +Proof. +move=> mV. +(* reduce the lhs *) +rewrite 3![in LHS]execP_letin. +rewrite execP_sample. +rewrite execD_normal/=. +rewrite execD_real/=. +rewrite execP_score. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_pow_real/=. +rewrite (@execD_bin _ _ binop_minus)/=. +rewrite execD_real/=. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_pow/=. +rewrite (@execD_bin _ _ binop_minus)/=. +rewrite exp_var'E/= (execD_var_erefl "y0")/=. +rewrite exp_var'E/= (execD_var_erefl "x")/=. +rewrite 2!execD_real/=. +rewrite execP_sample/=. +rewrite execD_normal/=. +rewrite exp_var'E/= (execD_var_erefl "x")/=. +rewrite execP_return/=. +rewrite exp_var'E/= (execD_var_erefl "z")/=. +(* reduce the rhs *) +rewrite [in RHS]execP_letin. +rewrite execP_score. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_pow_real/=. +rewrite (@execD_bin _ _ binop_minus)/=. +rewrite execD_real/=. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_pow/=. +rewrite exp_var'E/= (execD_var_erefl "y0")/=. +rewrite execD_real/=. +rewrite execD_real/=. +rewrite execP_letin. +rewrite execP_sample. +rewrite execD_normal/=. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite exp_var'E/= (execD_var_erefl "y0")/=. +rewrite execD_real/=. +rewrite execP_letin. +rewrite execP_sample. +rewrite execD_normal/=. +rewrite exp_var'E/= (execD_var_erefl "x")/=. +rewrite execP_return. +rewrite exp_var'E/= (execD_var_erefl "z")/=. +(* semantics *) +transitivity (noisyA_semantics_normal y V). + rewrite [in LHS]letin'E/=. + apply: eq_integral => x _. + rewrite letin'E/=. + under eq_integral. + move=> u _. + rewrite letin'E/= integral_normal_prob_dirac//. + over. + rewrite /= ge0_integral_mscale//. + rewrite integral_dirac// diracT mul1e sub0r -expRM mul1r/=. + by rewrite ger0_norm; last by rewrite mulr_ge0// expR_ge0. +transitivity (noisyB_semantics_normal y V); last first. + rewrite letin'E/=. + under eq_integral. + move=> u _. + rewrite letin'E/=. + under eq_integral. + move=> x _. + rewrite letin'E/=. + rewrite integral_normal_prob_dirac//. + over. + over. + rewrite /= ge0_integral_mscale//; first last. + by move=> ? _; exact: integral_ge0. + rewrite integral_dirac// diracT mul1e sub0r -expRM mul1r/=. + rewrite ger0_norm; last by rewrite mulr_ge0// expR_ge0. + rewrite -(@ge0_integralZl _ _ R (normal_prob _ _) _ measurableT _ _ + (expR (- (y.1 ^+ 2%R / 4)) / Num.sqrt (4 * pi))%:E)//. + exact: measurable_normal_prob. +rewrite noisyA_semantics_normalE//. +rewrite noisyB_semantics_normalE//. +apply: eq_integral => x _. +transitivity ((noisyA'_part y x)%:E * normal_prob x 1 V). + by rewrite muleAC. +transitivity ((noisyB'_part y x)%:E * normal_prob x 1 V); last first. + by rewrite (muleC (normal_prob x 1 V)) muleA. +congr (fun t => t%:E * normal_prob x 1 V)%E. +exact: noisyAB'_rearrange. +Qed. + +(* from (7) to (9) in [Shan, POPL 2018] *) +Lemma tailBC y V : measurable V -> + @execP R [:: ("_", Unit); ("y0", Real)] _ tailB y V = + @execP R [:: ("_", Unit); ("y0", Real)] _ tailC y V. +Proof. +move=> mV. +(* execute lhs *) +rewrite 2![in LHS]execP_letin. +rewrite 2![in LHS]execP_sample. +rewrite 2!execD_normal/=. +rewrite (@execD_bin _ _ binop_mult) execD_real/=. +rewrite execP_return. +rewrite exp_var'E (execD_var_erefl "y0")/=. +rewrite exp_var'E (execD_var_erefl "x")/=. +rewrite exp_var'E (execD_var_erefl "z")/=. +rewrite ![in LHS]letin'E/=. +under eq_integral do rewrite letin'E/=. +(* execute rhs *) +rewrite [in RHS]execP_sample/=. +rewrite execD_normal/=. +rewrite (@execD_bin _ _ binop_mult) execD_real/=. +rewrite exp_var'E (execD_var_erefl "y0")/=. +(* prove semantics *) +under eq_integral do rewrite integral_normal_prob_dirac//=. +by rewrite noisyC_semanticsE. +Qed. + +Lemma noisyBC' y V : measurable V -> execP noisyB' y V = execP noisyC' y V. +Proof. +move=> mV. +rewrite /noisyB' /noisyC'. +rewrite execP_letin/= [in RHS]execP_letin/=. +rewrite letin'E [RHS]letin'E. +by under eq_integral do rewrite tailBC//. +Qed. + +Lemma noisyAC' y V : measurable V -> execP noisyA' y V = execP noisyC' y V. +Proof. by move=> mV; rewrite noisyAB'// noisyBC'. Qed. + +End noisy_verification. + +(* Trying to show why rewriting noisyB to noisyC is reproductive property *) +(* +Section rewrite noisyB'_to_variable_addition. + +Definition noisyB'_alter : @exp R _ [:: ("y0", Real)] Real := + [let "_" := Score ({expR 1} `^ ({0}:R - #{"y0"} ^+ {2} * {4^-1}:R)) * + {(Num.sqrt (4 * pi))^-1}:R in + let "x" := Sample {exp_normal_Vsqrt2 [#{"y0"} * {2^-1}:R]} in + let "x1" := Sample {exp_normal1 [{0}:R]} in + return #{"x"} + #{"x1"}]. + +Definition noisyB_alter : @exp R _ [:: ("y0", Real)] _ := + [Normalize {noisyB'_alter}]. + +Lemma execP_noisyB'_alterE y V : +@execP R [:: ("y0", Real)] Real noisyB'_alter y V = executed_noisyB'_alter y V. +Proof. +rewrite 3!execP_letin. +rewrite execP_return/=. +rewrite (@execD_bin _ _ binop_add)/=. +rewrite (exp_var'E "x") (exp_var'E "x1"). +rewrite (execD_var_erefl "x") (execD_var_erefl "x1")/=. +rewrite 2!execP_sample. +rewrite 2!execD_normal/=. +rewrite execD_real/=. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_real/=. +rewrite execP_score. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_pow_real/=. +rewrite (@execD_bin _ _ binop_minus)/=. +rewrite execD_real/=. +rewrite (@execD_bin _ _ binop_mult)/=. +rewrite execD_pow/=. +rewrite 2!execD_real/=. +rewrite 2!(exp_var'E "y0") 2!(execD_var_erefl "y0")/=. +rewrite /executed_noisyB'_alter. +rewrite /=. +Abort. + +Lemma noisyB_alterE y V : measurable V -> + @execP R [:: ("y0", Real)] Real noisyB' y V = + @execP R [:: ("y0", Real)] Real noisyB'_alter y V. +Proof. +move=> mV. +rewrite 3![in RHS]execP_letin. +rewrite [in RHS]execP_return/=. +rewrite [in RHS](@execD_bin _ _ binop_add)/=. +rewrite [in RHS](exp_var'E "x") (exp_var'E "x1"). +rewrite [in RHS](execD_var_erefl "x") (execD_var_erefl "x1")/=. +rewrite 2![in RHS]execP_sample. +rewrite 2![in RHS]execD_normal/=. +rewrite [in RHS]execD_real/=. +rewrite [in RHS](@execD_bin _ _ binop_mult)/=. +rewrite [in RHS]execD_real/=. +rewrite [in RHS]execP_score. +rewrite [in RHS](@execD_bin _ _ binop_mult)/=. +rewrite [in RHS]execD_pow_real/=. +rewrite [in RHS](@execD_bin _ _ binop_minus)/=. +rewrite [in RHS]execD_real/=. +rewrite [in RHS](@execD_bin _ _ binop_mult)/=. +rewrite [in RHS]execD_pow/=. +rewrite 2![in RHS]execD_real/=. +rewrite 2![in RHS](exp_var'E "y0") 2!(execD_var_erefl "y0")/=. +rewrite [in RHS]letin'E/=. +under [RHS]eq_integral. + move=> x _. + rewrite letin'E/=. + under eq_integral. + move=> z _. + rewrite letin'E/=. + rewrite integral_normal_prob_dirac//; last first. + admit. (* standard property of measurable sets *) + rewrite (_ : (fun x0 => V (z + x0)) = ((fun x0 => z + x0) @^-1` V)); last by []. + rewrite (_ : normal_prob 0 1 ((fun x0 => z + x0) @^-1` V) = normal_prob z 1 V); last first. + admit. (* general version of integration by substitution *) + over. + over. +rewrite ge0_integral_mscale//; last first. + move=> x _. + apply: integral_ge0. + by move=> z _. +rewrite integral_dirac// diracT mul1e. +rewrite -ge0_integralZl//; last first. + exact: emeasurable_normal_prob. +rewrite sub0r. +rewrite -expRM mul1r. + +rewrite execP_noisyB'E. +by rewrite executed_noisyB'_semantics. +Abort. + +End rewrite noisyB'_to_variable_addition. +*) diff --git a/theories/lang_syntax_table_game.v b/theories/lang_syntax_table_game.v new file mode 100644 index 000000000..1533e5204 --- /dev/null +++ b/theories/lang_syntax_table_game.v @@ -0,0 +1,734 @@ +(* mathcomp analysis (c) 2025 Inria and AIST. License: CeCILL-C. *) +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval. +From mathcomp Require Import ring lra. +From mathcomp Require Import unstable mathcomp_extra boolp classical_sets. +From mathcomp Require Import functions cardinality fsbigop interval_inference. +From mathcomp Require Import reals ereal topology normedtype sequences. +From mathcomp Require Import esum measure charge lebesgue_measure numfun. +From mathcomp Require Import lebesgue_integral probability kernel prob_lang. +From mathcomp Require Import lang_syntax_util lang_syntax lang_syntax_examples. + +(**md**************************************************************************) +(* # Eddy's table game example *) +(* *) +(* Formalization of the Eddy's table game by equational reasoning. See *) +(* Sections 2.1, 3.2, 3.4, 3.5 of [Shan, POPL 2018]. The final statement of *) +(* equivalence is Lemma from_prog0_to_prog5. Intermediate steps are lemmas *) +(* named progij that turn the program progi into the program progj. *) +(* *) +(* ref: *) +(* - Chung-chieh Shan, Equational reasoning for probabilistic programming, *) +(* POPL TutorialFest 2018 *) +(* https://homes.luddy.indiana.edu/ccshan/rational/equational-handout.pdf *) +(* - Sean R Eddy, What is Bayesian statistics?, Nature Biotechnology 22(9), *) +(* 1177--1178 (2004) *) +(* *) +(* ``` *) +(* table0 == Eddy's table game represented as a probabilistic program *) +(* ``` *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Import Order.TTheory GRing.Theory Num.Def Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope string_scope. + +Local Open Scope ereal_scope. +Local Open Scope string_scope. + +Section execP_letin_uniform. +Local Open Scope ereal_scope. + +Let letin'_sample_uniform {R : realType} d d' (T : measurableType d) + (T' : measurableType d') (a b : R) (ab : (a < b)%R) + (u : R.-sfker [the measurableType _ of (_ * T)%type] ~> T') x y : + measurable y -> + letin' (sample_cst (uniform_prob ab)) u x y = + (b - a)^-1%:E * \int[lebesgue_measure]_(x0 in `[a, b]) u (x0, x) y. +Proof. +move=> my; rewrite letin'E/=. +rewrite integral_uniform//=. +move => _ /= Y mY /=. +have /= := measurable_kernel u _ my measurableT _ mY. +move/measurable_ysection => /(_ x) /=. +set A := (X in measurable X). +set B := (X in _ -> measurable X). +suff : A = B by move=> ->. +by rewrite {}/A {}/B !setTI ysectionE. +Qed. + +Local Open Scope lang_scope. +(* NB: consider moving to lang_syntax.v *) +Lemma execP_letin_uniform {R : realType} + g t str (s0 s1 : exp P ((str, Real) :: g) t) : + (forall (p : R) x U, (0 <= p <= 1)%R -> + execP s0 (p, x) U = execP s1 (p, x) U) -> + forall x U, measurable U -> + execP [let str := Sample Uniform {0%R} {1%R} {@ltr01 R} in {s0}] x U = + execP [let str := Sample Uniform {0%R} {1%R} {@ltr01 R} in {s1}] x U. +Proof. +move=> s01 x U mU. +rewrite !execP_letin execP_sample execD_uniform/=. +rewrite !letin'_sample_uniform//. +congr *%E. +apply: eq_integral => p p01. +apply: s01. +by rewrite inE in p01. +Qed. + +End execP_letin_uniform. + +(* NB: generic lemmas about bounded, to be moved *) +Section bounded. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Local Open Scope ereal_scope. +Context {R : realType}. + +Lemma bounded_id_01 : [bounded x0 : R^o | x0 in `[0%R, 1%R]%classic : set R]. +Proof. +exists 1%R; split => // y y1. +near=> M => /=. +rewrite (le_trans _ (ltW y1))//. +near: M. +move=> M /=. +rewrite in_itv/= => /andP[M0 M1]. +by rewrite ler_norml M1 andbT (le_trans _ M0). +Unshelve. all: by end_near. Qed. + +Lemma bounded_onem_01 : [bounded (`1- x : R^o) | x in `[0%R, 1%R]%classic : set R]. +Proof. +exists 1%R; split => // y y1. +near=> M => /=. +rewrite (le_trans _ (ltW y1))//. +near: M. +move=> M /=. +rewrite in_itv/= => /andP[M0 M1]. +rewrite ler_norml (@le_trans _ _ 0%R)//=. + by rewrite lerBlDr addrC -lerBlDr subrr. +by rewrite onem_ge0. +Unshelve. all: by end_near. Qed. + +Lemma bounded_cst_01 (x : R^o) : [bounded x | _ in `[0%R, 1%R]%classic : set R]. +Proof. +exists `|x|%R; split. + by rewrite num_real. +move=> y y1/= z. +rewrite in_itv/= => /andP[z0 z1]. +by rewrite (le_trans _ (ltW y1)). +Qed. + +Lemma bounded_norm (f : R -> R) : + [bounded f x : R^o | x in (`[0%R, 1%R]%classic : set R)] <-> + [bounded `|f x|%R : R^o | x in (`[0%R, 1%R]%classic : set R)]. +Proof. +split. + move=> [M [Mreal HM]]. + exists `|M|%R; split; first by rewrite normr_real. + move=> r Mr x/= x01. + by rewrite ger0_norm// HM// (le_lt_trans _ Mr)// ler_norm. +move=> [M [Mreal HM]]. +exists `|M|%R; split; first by rewrite normr_real. +move=> r Mr x/= x01. +rewrite -[leLHS]ger0_norm// HM//. +by rewrite (le_lt_trans _ Mr)// ler_norm. +Qed. + +Lemma boundedMl k (f : R -> R) : + [bounded f x : R^o | x in (`[0%R, 1%R]%classic : set R)] -> + [bounded (k * f x)%R : R^o | x in (`[0%R, 1%R]%classic : set R)]. +Proof. +move=> [M [Mreal HM]]. +exists `|k * M|%R; split; first by rewrite normr_real. +move=> r kMr x/= x01. +rewrite normrM. +have [->|k0] := eqVneq k 0%R. + by rewrite normr0 mul0r (le_trans _ (ltW kMr)). +rewrite -ler_pdivlMl ?normr_gt0//. +apply: HM => //. +rewrite ltr_pdivlMl ?normr_gt0//. +rewrite (le_lt_trans _ kMr)//. +by rewrite normrM ler_pM2l ?normr_gt0// ler_norm. +Qed. + +End bounded. + +(* TODO: move? *) +Lemma measurable_bernoulli_expn {R : realType} U n : + measurable_fun [set: g_sigma_algebraType R.-ocitv.-measurable] + (fun x : R => bernoulli_prob (`1-x ^+ n) U). +Proof. +apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. +by apply: measurable_funX => //=; exact: measurable_funB. +Qed. + +Lemma integrable_bernoulli_beta_pdf {R : realType} U : measurable U -> + (@lebesgue_measure R).-integrable [set: g_sigma_algebraType R.-ocitv.-measurable] + (fun x => bernoulli_prob (1 - `1-x ^+ 3) U * (beta_pdf 6 4 x)%:E)%E. +Proof. +move=> mU. +have ? : measurable_fun [set: g_sigma_algebraType R.-ocitv.-measurable] + (fun x => bernoulli_prob (1 - `1-x ^+ 3) U). + apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. + apply: measurable_funB => //; apply: measurable_funX => //. + exact: measurable_funB. +apply/integrableP; split => /=. + apply: emeasurable_funM => //=. + by apply/measurable_EFinP; exact: measurable_beta_pdf. +apply: (@le_lt_trans _ _ (\int[lebesgue_measure]_(x in `[0%R, 1%R]) (beta_pdf 6 4 x)%:E))%E. + rewrite [leRHS]integral_mkcond /=. + apply: ge0_le_integral => //=. + - apply: measurableT_comp => //; apply: emeasurable_funM => //. + by apply/measurable_EFinP; exact: measurable_beta_pdf. + - move=> x _ /=; rewrite patchE; case: ifPn => // _. + by rewrite lee_fin beta_pdf_ge0. + - apply/(measurable_restrict _ _ _) => //. + apply/measurable_funTS/measurableT_comp => //. + exact: measurable_beta_pdf. + - move=> x _. + rewrite patchE; case: ifPn => x01. + rewrite gee0_abs//. + rewrite gee_pMl// ?probability_le1//. + by rewrite ge0_fin_numE// (le_lt_trans (probability_le1 _ _))// ltry. + by rewrite lee_fin beta_pdf_ge0. + by rewrite mule_ge0// lee_fin beta_pdf_ge0. + by rewrite /beta_pdf patchE (negbTE x01) mul0r mule0 abse0. +apply: (@le_lt_trans _ _ + (\int[lebesgue_measure]_(x in `[0%R, 1%R]) (beta_fun 6 4)^-1%:E)%E); last first. + by rewrite integral_cst//= lebesgue_measure_itv/= lte01 EFinN sube0 mule1 ltry. +apply: ge0_le_integral => //=. +- by move=> ? _; rewrite lee_fin beta_pdf_ge0. +- by apply/measurable_funTS/measurableT_comp => //; exact: measurable_beta_pdf. +- by move=> ? _; rewrite lee_fin invr_ge0// beta_fun_ge0. +- by move=> x _; rewrite lee_fin beta_pdf_le_beta_funV. +Qed. + +Section table_game_programs. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Context (R : realType). +Local Notation mu := lebesgue_measure. + +Definition guard {g} str n : @exp R P [:: (str, _) ; g] _ := + [if #{str} == {n}:N then return TT else Score {0}:R]. + +Definition table0 : @exp R _ [::] _ := [Normalize + let "p" := Sample Uniform {0} {1} {ltr01} in + let "x" := Sample Binomial {8} #{"p"} in + let "_" := {guard "x" 5} in + let "y" := Sample Binomial {3} #{"p"} in + return {1}:N <= #{"y"}]. + +Definition tail1 : @exp R _ [:: ("_", Unit); ("x", Nat) ; ("p", Real)] _ := + [Sample Bernoulli {1}:R - {[{1}:R - #{"p"}]} ^+ {3}]. + +Definition tail2 : @exp R _ [:: ("_", Unit); ("p", Real)] _ := + [Sample Bernoulli {1}:R - {[{1}:R - #{"p"}]} ^+ {3}]. + +Definition tail3 : @exp R _ [:: ("p", Real); ("_", Unit)] _ := + [Sample Bernoulli {1}:R - {[{1}:R - #{"p"}]} ^+ {3}]. + +Definition table1 : @exp R _ [::] _ := [Normalize + let "p" := Sample Uniform {0} {1} {ltr01} in + let "x" := Sample Binomial {8} #{"p"} in + let "_" := {guard "x" 5} in + {tail1}]. + +Definition table2 : @exp R _ [::] _ := [Normalize + let "p" := Sample Uniform {0} {1} {ltr01} in + let "_" := + Score {[{56}:R * #{"p"} ^+ {5%R} * {[{1}:R - #{"p"}]} ^+ {3%R}]} in + {tail2}]. + +Definition table2' : @exp R _ [::] _ := [Normalize + let "p" := Sample Beta {1} {1} in + let "_" := Score + {[{56}:R * #{"p"} ^+ {5%R} * {[{1}:R - #{"p"}]} ^+ {3%R}]} in + {tail2}]. + +Definition table3 : @exp R _ [::] _ := [Normalize + let "_" := Score {1 / 9}:R in + let "p" := Sample Beta {6} {4} in + {tail3}]. + +Definition table4 : @exp R _ [::] _ := [Normalize + let "_" := Score {1 / 9}:R in + Sample Bernoulli {10 / 11}:R]. + +Definition table5 : @exp R _ [::] _ := [Normalize Sample Bernoulli {10 / 11}:R]. + +End table_game_programs. +Arguments tail1 {R}. +Arguments tail2 {R}. +Arguments guard {R g}. + +Section from_table0_to_table1. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. +Local Open Scope lang_scope. +Context (R : realType). +Local Notation mu := lebesgue_measure. + +Let table01_subproof + (x : mctx (untag (ctx_of (recurse Unit (recurse Nat (found "p" Real [::])))))) + U : (0 <= x.2.2.1 <= 1)%R -> + execP [let "y" := Sample Binomial {3%R} #{"p"} in + return {1}:N <= #{"y"}] x U = + execP (@tail1 R) x U. +Proof. +move=> x01. +rewrite /tail1. +(* reduce lhs *) +rewrite execP_letin execP_sample execD_binomial/= execP_return/= execD_rel/=. +rewrite exp_var'E (execD_var_erefl "p")/=. +rewrite exp_var'E (execD_var_erefl "y")/=. +rewrite execD_nat/=. +rewrite [LHS]letin'E/=. +(* reduce rhs *) +rewrite execP_sample/= execD_bernoulli/= (@execD_bin _ _ binop_minus)/=. +rewrite execD_real/= execD_pow/= (@execD_bin _ _ binop_minus)/= execD_real/=. +rewrite (execD_var_erefl "p")/=. +exact/integral_binomial_prob. +Qed. + +Lemma table01 : execD (@table0 R) = execD (@table1 R). +Proof. +rewrite /table0 /table1. +apply: congr_normalize => y A. +apply: execP_letin_uniform => // p [] B p01. +apply: congr_letinr => a1 V0. +apply: congr_letinr => -[] V1. +exact: table01_subproof. +Qed. + +End from_table0_to_table1. + +Section from_table1_to_table2. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. +Local Open Scope lang_scope. +Context (R : realType). +Local Notation mu := lebesgue_measure. + +Let table12_subproof (y : @mctx R [::]) (V : set (@mtyp R Bool)) + (p : R) + (x : projT2 (existT measurableType default_measure_display unit)) + (U : set (mtyp Bool)) + (p0 : (0 <= p)%R) + (p1 : (p <= 1)%R) : + \int[binomial_prob 8 p]_y0 + execP [let "_" := {guard "x" 5} in {tail1}] + (y0, (p, x)) U = + \int[mscale (NngNum (normr_ge0 (56 * XMonemX 5 3 p)%R)) \d_tt]_y0 + execP tail2 (y0, (p, x)) U. +Proof. +rewrite integral_binomial//=. +rewrite (bigD1 (inord 5))//=. +rewrite big1 ?adde0; last first. + move=> i i5. + rewrite execP_letin/= execP_if/= execD_rel/=. + rewrite exp_var'E/= (execD_var_erefl "x")/=. + rewrite execD_nat/= execP_score/= execD_real/= execP_return/=. + rewrite letin'E iteE/=. + move: i => [[|[|[|[|[|[|[|[|[|//]]]]]]]]]]//= Hi in i5 *; + rewrite ?ge0_integral_mscale//= ?execD_real/= ?normr0 ?(mul0e,mule0)//. + by rewrite -val_eqE/= inordK in i5. +(* reduce lhs *) +rewrite -[(p ^+ _ * _ ^+ _)%R]/(XMonemX _ _ p). +rewrite execP_letin/= execP_if/= execD_rel/=. +rewrite exp_var'E/= (execD_var_erefl "x")/=. +rewrite execD_nat/= execP_score/= execD_real/= execP_return/=. +rewrite letin'E iteE/=. +rewrite inordK// eqxx. +rewrite integral_dirac//= execD_unit/= diracE mem_set// mul1e. +(* reduce rhs *) +rewrite ge0_integral_mscale//=. +rewrite integral_dirac//= diracE mem_set// mul1e. +rewrite ger0_norm ?mulr_ge0 ?subr_ge0//. +rewrite mulr_natl. +(* same score *) +congr *%E. +(* the tails are the same module the shape of the environment *) +rewrite /tail1 /tail2 !execP_sample/=. +rewrite !execD_bernoulli/=. +rewrite !(@execD_bin _ _ binop_minus)/=. +rewrite !execD_pow/=. +rewrite !execD_real/=. +rewrite !(@execD_bin _ _ binop_minus)/=. +by rewrite !execD_real/= !exp_var'E/= !(execD_var_erefl "p")/=. +Qed. + +Lemma table12 : execD (@table1 R) = execD (@table2 R). +Proof. +apply: congr_normalize => y V. +apply: execP_letin_uniform => // p x U /andP[p0 p1]. +(* reduce the lhs *) +rewrite execP_letin execP_sample execD_binomial/=. +rewrite letin'E/=. +rewrite [in LHS]exp_var'E/= (execD_var_erefl "p")/=. +(* reduce the rhs *) +rewrite [in RHS]execP_letin execP_score/=. +rewrite letin'E/=. +do 2 rewrite (@execD_bin _ _ binop_mult)/=/=. +rewrite [in RHS]exp_var'E/=. +rewrite execD_pow/=. +rewrite (execD_var_erefl "p")/=. +rewrite execD_pow/=. +rewrite (@execD_bin _ _ binop_minus)/=/=. +rewrite 2!execD_real/=. +rewrite (execD_var_erefl "p")/=. +rewrite -(mulrA 56%R). +exact: table12_subproof. +Qed. + +End from_table1_to_table2. + +Local Open Scope ereal_scope. + +Lemma measurable_bernoulli_onemXn {R : realType} U : + measurable_fun [set: g_sigma_algebraType R.-ocitv.-measurable] + (fun x => bernoulli_prob (1 - `1-x ^+ 3) U). +Proof. +apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. +apply: measurable_funB => //. +by apply: measurable_funX; exact: measurable_funB. +Qed. + +Lemma bounded_norm_XnonemXn {R : realType} : + [bounded normr (56 * XMonemX 5 3 x)%R : R^o | x in `[0%R, 1%R] : set R]. +Proof. exact/(bounded_norm _).1/boundedMl/bounded_XMonemX. Qed. + +Lemma integrable_bernoulli_XMonemX {R : realType} U : + (beta_prob 1 1).-integrable [set: R] + (fun x => bernoulli_prob (1 - `1-x ^+ 3) U * (normr (56 * XMonemX 5 3 x))%:E). +Proof. +apply/integrableP; split. + apply: emeasurable_funM; first exact: measurable_bernoulli_onemXn. + apply/measurable_EFinP => //; apply: measurableT_comp => //. + by apply: measurable_funM => //; exact: measurable_XMonemX. +rewrite beta_prob_uniform integral_uniform//=. + rewrite subr0 invr1 mul1e. + suff : lebesgue_measure.-integrable `[0%R, 1%R] + (fun y : R => bernoulli_prob (1 - `1-y ^+ 3) U * (normr (56 * XMonemX 5 3 y))%:E). + by move=> /integrableP[]. + apply: integrableMl => //=. + - apply/integrableP; split. + by apply: measurable_funTS; exact: measurable_bernoulli_onemXn. + have := @integral_beta_prob_bernoulli_prob_onem_lty R 3 1%N 1%N U. + rewrite beta_prob_uniform integral_uniform//=; last first. + by apply: measurableT_comp => //=; exact: measurable_bernoulli_onemXn. + by rewrite subr0 invr1 mul1e. + - apply: @measurableT_comp => //=; apply: measurable_funM => //. + exact: measurable_XMonemX. + exact: bounded_norm_XnonemXn. +apply: @measurableT_comp => //; apply: emeasurable_funM => //=. + exact: measurable_bernoulli_onemXn. +do 2 apply: @measurableT_comp => //=. +by apply: measurable_funM => //; exact: measurable_XMonemX. +Qed. + +Section from_table2_to_table3. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. +Local Open Scope lang_scope. +Context (R : realType). +Local Notation mu := lebesgue_measure. + +Lemma table22' : execD (@table2 R) = execD (@table2' R). +Proof. +apply: congr_normalize => // x U. +apply: congr_letinl => // y V. +rewrite !execP_sample execD_uniform/= execD_beta/=. +by rewrite beta_prob_uniform. +Qed. + +Lemma table23 : execD (@table2' R) = execD (@table3 R). +Proof. +apply: congr_normalize => x U. +(* reduce the LHS *) +rewrite 2![in LHS]execP_letin. +rewrite ![in LHS]execP_sample. +rewrite [in LHS]execP_score. +rewrite [in LHS]execD_beta/=. +rewrite [in LHS]execD_bernoulli. +rewrite 2![in LHS](@execD_bin _ _ binop_mult)/=. +rewrite 2![in LHS]execD_pow/=. +rewrite 2![in LHS](@execD_bin _ _ binop_minus)/=. +rewrite 3![in LHS]execD_real. +rewrite [in LHS]exp_var'E [in LHS](execD_var_erefl "p")/=. +rewrite [in LHS]execD_pow/=. +rewrite [in LHS](@execD_bin _ _ binop_minus)/=. +rewrite [in LHS]execD_real. +rewrite [in LHS]exp_var'E [in LHS](execD_var_erefl "p")/=. +(* reduce the RHS *) +rewrite [in RHS]execP_letin. +rewrite [in RHS]execP_score. +rewrite [in RHS]execP_letin/=. +rewrite [in RHS]execP_sample/=. +rewrite [in RHS]execD_beta/=. +rewrite [in RHS]execP_sample/=. +rewrite [in RHS]execD_bernoulli/=. +rewrite [in RHS]execD_real/=. +rewrite [in RHS](@execD_bin _ _ binop_minus)/=. +rewrite [in RHS]execD_real/=. +rewrite [in RHS]execD_pow/=. +rewrite [in RHS](@execD_bin _ _ binop_minus)/=. +rewrite [in RHS]exp_var'E [in RHS](execD_var_erefl "p")/=. +rewrite [in RHS]execD_real/=. +rewrite [LHS]letin'E/=. +under eq_integral => y _. + rewrite letin'E/=. + rewrite integral_cst//= /mscale/= diracT mule1 -mulrA -/(XMonemX _ _ _). + over. +rewrite [RHS]letin'E/=. +under [in RHS]eq_integral => y _. + rewrite letin'E/=. + over. +rewrite /=. +rewrite [RHS]ge0_integral_mscale//=; last first. + by move=> _ _; rewrite integral_ge0. +rewrite integral_beta_prob//=; last 2 first. + - apply: emeasurable_funM => //=. + exact: measurable_bernoulli_onemXn. + apply/measurable_EFinP; apply: measurableT_comp => //. + by apply: measurable_funM => //; exact: measurable_XMonemX. + - by have /integrableP[] := @integrable_bernoulli_XMonemX R U. +rewrite ger0_norm// integral_dirac// diracT mul1e. +rewrite integral_beta_prob/=; [|by []|exact: measurable_bernoulli_onemXn + |exact: integral_beta_prob_bernoulli_prob_onem_lty]. +rewrite -integralZl//=; last exact: integrable_bernoulli_beta_pdf. +apply: eq_integral => y _. +rewrite [in RHS]muleCA -[in LHS]muleA; congr *%E. +rewrite /beta_pdf 2!patchE; case: ifPn => [y01|_]; last first. + by rewrite !mul0r 2!mule0. +rewrite ger0_norm; last first. + by rewrite mulr_ge0// XMonemX_ge0//; rewrite inE in y01. +rewrite [X in _ = _ * X]EFinM [in RHS]muleCA. +rewrite /= XMonemX00 mul1r [in LHS](mulrC 56%R) [in LHS]EFinM -[in LHS]muleA; congr *%E. +by rewrite !beta_funE/=; repeat rewrite !factE/=; rewrite -EFinM; congr EFin; lra. +Qed. + +End from_table2_to_table3. + +Local Open Scope ereal_scope. +(* TODO: move? *) +Lemma int_beta_prob01 {R : realType} (f : R -> R) a b U : + measurable_fun [set: R] f -> + (forall x, x \in `[0%R, 1%R] -> 0 <= f x <= 1)%R -> + \int[beta_prob a b]_y bernoulli_prob (f y) U = + \int[beta_prob a b]_(y in `[0%R, 1%R] : set R) bernoulli_prob (f y) U. +Proof. +move=> mf f01. +rewrite [LHS]integral_beta_prob//=; last 2 first. + apply: measurable_funTS. + by apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. + exact: integral_beta_prob_bernoulli_prob_lty. +rewrite [RHS]integral_beta_prob//; last 2 first. + apply/measurable_funTS => //=. + by apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. + apply: (le_lt_trans _ (integral_beta_prob_bernoulli_prob_lty a b U mf f01)). + apply: ge0_subset_integral => //=; apply: measurableT_comp => //=. + by apply: (measurableT_comp (measurable_bernoulli_prob2 _)) => //=. +rewrite [RHS]integral_mkcond/=; apply: eq_integral => x _ /=. +rewrite !patchE; case: ifPn => // x01. +by rewrite /beta_pdf patchE (negbTE x01) mul0r mule0. +Qed. + +Lemma expr_onem_01 {R : realType} (x : R) : x \in `[0%R, 1%R] -> + (0 <= `1-x ^+ 3 <= 1)%R. +Proof. +rewrite in_itv/= => /andP[x0 x1]. +rewrite exprn_ge0 ?subr_ge0//= exprn_ile1// ?subr_ge0//. +by rewrite lerBlDl -lerBlDr subrr. +Qed. + +Lemma int_beta_prob_bernoulli {R : realType} (U : set (@mtyp R Bool)) : + \int[beta_prob 6 4]_y bernoulli_prob (`1-y ^+ 3) U = bernoulli_prob (1 / 11) U :> \bar R. +Proof. +rewrite int_beta_prob01//; last 2 first. + by apply: measurable_funX => //; exact: measurable_funB. + exact: expr_onem_01. +have := @beta_prob_bernoulli_probE R 6 4 0 3 U isT isT. +rewrite /beta_prob_bernoulli_prob. +under eq_integral. + move=> x x0. + rewrite patchE x0 XMonemX0. + over. +rewrite /= => ->; congr bernoulli_prob. +(*by rewrite /div_beta_fun addn0 !beta_funE/=; repeat rewrite !factE/=; field.*) +rewrite /div_beta_fun addn0 !beta_funE/=. +(* temporary measure to avoid stack overflow *) +rewrite mulrAC -mulrA mulrAC 2!invfM 3!mulrA mulfV ?gt_eqF// 2!div1r. +rewrite !addnS !addn0. +rewrite (factS 11) (factS 10) (factS 9). +by rewrite !factE; field. +Qed. + +Lemma dirac_bool {R : realType} (U : set bool) : + \d_false U + \d_true U = (\sum_(x \in U) (1%E : \bar R))%R. +Proof. +have [| | |] := set_bool U => /eqP ->; rewrite !diracE. +- by rewrite memNset// mem_set//= fsbig_set1 add0e. +- by rewrite mem_set// memNset//= fsbig_set1 adde0. +- by rewrite !in_set0 fsbig_set0 adde0. +- rewrite !in_setT setT_bool fsbigU0//=; last by move=> x [->]. + by rewrite !fsbig_set1. +Qed. + +Lemma int_beta_prob_bernoulli_onem {R : realType} (U : set (@mtyp R Bool)) : + \int[beta_prob 6 4]_y bernoulli_prob (`1-(`1-y ^+ 3)) U = bernoulli_prob (10 / 11) U :> \bar R. +Proof. +transitivity + (\d_false U + \d_true U - bernoulli_prob (1 / 11) U : \bar R)%E; last first. + rewrite /bernoulli_prob ifT; last lra. + rewrite ifT; last lra. + apply/eqP; rewrite sube_eq//; last first. + rewrite ge0_adde_def// inE. + by apply/sume_ge0 => //= b _; rewrite lee_fin bernoulli_pmf_ge0//; lra. + by apply/sume_ge0 => //= b _; rewrite lee_fin bernoulli_pmf_ge0//; lra. + rewrite -fsbig_split//=. + under eq_fsbigr. + move=> /= x _. + rewrite -EFinD /bernoulli_pmf [X in X%:E](_ : _ = 1%R); last first. + case: x => //; lra. + over. + by rewrite /= dirac_bool. +rewrite -int_beta_prob_bernoulli. +apply/esym/eqP; rewrite sube_eq//; last first. + by rewrite ge0_adde_def// inE; exact: integral_ge0. +rewrite int_beta_prob01; last 2 first. + apply: measurable_funB => //; apply: measurable_funX => //. + exact: measurable_funB. + move=> x x01. + by rewrite subr_ge0 andbC lerBlDr -lerBlDl subrr expr_onem_01. +rewrite [X in _ == _ + X]int_beta_prob01; last 2 first. + by apply: measurable_funX => //; exact: measurable_funB. + exact: expr_onem_01. +rewrite -ge0_integralD//=; last 2 first. + apply: (@measurableT_comp _ _ _ _ _ _ (bernoulli_prob ^~ U)) => /=. + exact: measurable_bernoulli_prob2. + apply: measurable_funB => //=; apply: measurable_funX => //=. + exact: measurable_funB. + apply: (@measurableT_comp _ _ _ _ _ _ (bernoulli_prob ^~ U)) => /=. + exact: measurable_bernoulli_prob2. + by apply: measurable_funX => //=; exact: measurable_funB. +apply/eqP; transitivity + (\int[beta_prob 6 4]_(x in `[0%R, 1%R]) (\d_false U + \d_true U) : \bar R). + by rewrite integral_cst//= beta_prob01 mule1 EFinD. +apply: eq_integral => /= x x01. +rewrite /bernoulli_prob subr_ge0 lerBlDr -lerBlDl subrr andbC. +rewrite (_ : (_ <= _ <= _)%R); last first. + by apply: expr_onem_01; rewrite inE in x01. +rewrite -fsbig_split//=. +under eq_fsbigr. + move=> /= y yU. + rewrite -EFinD /bernoulli_pmf. + rewrite [X in X%:E](_ : _ = 1%R); last first. + by case: ifPn => _; rewrite subrK. + over. +by rewrite /= dirac_bool. +Qed. + +Local Close Scope ereal_scope. + +Section from_table3_to_table4. +Local Open Scope ereal_scope. +Local Open Scope lang_scope. +Context (R : realType). + +(* NB: not used *) +Lemma table34' U : + @execP R [::] _ [let "p" := Sample Beta {6%R} {4%R} in + Sample Bernoulli {[{1}:R - #{"p"}]} ^+ {3%N}] tt U = + @execP R [::] _ [Sample Bernoulli {1 / 11}:R] tt U. +Proof. +(* reduce the lhs *) +rewrite execP_letin. +rewrite execP_sample execD_beta/=. +rewrite execP_sample/= execD_bernoulli/=. +rewrite execD_pow/= (@execD_bin _ _ binop_minus) execD_real/=. +rewrite exp_var'E (execD_var_erefl "p")/=. +(* reduce the rhs *) +rewrite execP_sample execD_bernoulli/= execD_real. +(* semantics of lhs *) +rewrite letin'E/=. +exact: int_beta_prob_bernoulli. +Qed. + +Lemma table34 l u U : + @execP R l _ [let "p" := Sample Beta {6%R} {4%R} in + Sample Bernoulli {1}:R - {[{1}:R - #{"p"}]} ^+ {3%N}] u U = + @execP R l _ [Sample Bernoulli {10 / 11}:R] u U. +Proof. +(* reduce the lhs *) +rewrite execP_letin. +rewrite execP_sample execD_beta/=. +rewrite execP_sample/= execD_bernoulli/=. +rewrite (@execD_bin _ _ binop_minus)/=. +rewrite execD_pow/= (@execD_bin _ _ binop_minus) execD_real/=. +rewrite exp_var'E (execD_var_erefl "p")/=. +(* reduce the rhs *) +rewrite execP_sample execD_bernoulli/= execD_real. +(* semantics of lhs *) +rewrite letin'E/=. +exact: int_beta_prob_bernoulli_onem. +Qed. + +End from_table3_to_table4. + +Section from_table4_to_table5. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. +Local Open Scope lang_scope. +Context (R : realType). +Local Notation mu := lebesgue_measure. + +Lemma normalize_score_bernoulli g p q (p0 : (0 < p)%R) (q01 : (0 <= q <= 1)%R) : + @execD R g _ [Normalize let "_" := Score {p}:R in + Sample Bernoulli {q}:R] = + execD [Normalize Sample Bernoulli {q}:R]. +Proof. +apply: eq_execD. +rewrite !execD_normalize_pt/= !execP_letin !execP_score. +rewrite !execP_sample !execD_bernoulli !execD_real/=. +apply: funext=> x. +apply: eq_probability=> /= U. +rewrite !normalizeE/=. +rewrite !bernoulli_probE//=; [|lra..]. +rewrite !diracT !mule1 -EFinD add_onemK onee_eq0/=. +rewrite !letin'E. +under eq_integral. + move=> A _ /=. + rewrite !bernoulli_probE//=; [|lra..]. + rewrite !diracT !mule1 -EFinD add_onemK. + over. +rewrite !ge0_integral_mscale//= (ger0_norm (ltW p0))//. +rewrite integral_dirac// !diracT !indicT /= !mule1 !mulr1. +rewrite add_onemK invr1 mule1. +rewrite gt_eqF ?lte_fin//=. +rewrite integral_dirac//= diracT mul1e. +by rewrite muleAC -EFinM divff ?gt_eqF// mul1e bernoulli_probE. +Qed. + +Lemma table45 : execD (@table4 R) = execD (@table5 R). +Proof. by rewrite normalize_score_bernoulli//; lra. Qed. + +End from_table4_to_table5. + +Lemma from_table0_to_table5 {R : realType} : execD (@table0 R) = execD (@table5 R). +Proof. +rewrite table01 table12 table22' table23. +rewrite -table45. +apply: congr_normalize => y V. +apply: congr_letinr => x U. +by rewrite -table34. +Qed. diff --git a/theories/lang_syntax_toy.v b/theories/lang_syntax_toy.v new file mode 100644 index 000000000..dd30e9e85 --- /dev/null +++ b/theories/lang_syntax_toy.v @@ -0,0 +1,553 @@ +(* mathcomp analysis (c) 2025 Inria and AIST. License: CeCILL-C. *) +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg interval_inference. +From mathcomp Require Import mathcomp_extra boolp. +From mathcomp Require Import reals topology normedtype. +From mathcomp Require Import lang_syntax_util. + +(**md**************************************************************************) +(* # Intrinsically-typed concrete syntax for a toy language *) +(* *) +(* The main module provided by this file is "lang_intrinsic_tysc" which *) +(* provides an example of intrinsically-typed concrete syntax for a toy *) +(* language (a simplification of the syntax/evaluation formalized in *) +(* lang_syntax.v). Other modules provide even more simplified language for *) +(* pedagogical purposes. *) +(* *) +(* ``` *) +(* lang_extrinsic == non-intrinsic definition of expression *) +(* lang_intrinsic_ty == intrinsically-typed syntax *) +(* lang_intrinsic_sc == intrinsically-scoped syntax *) +(* lang_intrinsic_tysc == intrinsically-typed/scoped syntax *) +(* ``` *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Set Printing Implicit Defensive. + +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope string_scope. + +Section type. +Variables (R : realType). + +Inductive typ := Real | Unit. + +HB.instance Definition _ := gen_eqMixin typ. + +Definition iter_pair (l : list Type) : Type := + foldr (fun x y => (x * y)%type) unit l. + +Definition Type_of_typ (t : typ) : Type := + match t with + | Real => R + | Unit => unit + end. + +Definition ctx := seq (string * typ). + +Definition Type_of_ctx (g : ctx) := iter_pair (map (Type_of_typ \o snd) g). + +Goal Type_of_ctx [:: ("x", Real); ("y", Real)] = (R * (R * unit))%type. +Proof. by []. Qed. + +End type. + +Module lang_extrinsic. +Section lang_extrinsic. +Variable R : realType. +Implicit Types str : string. + +Inductive exp : Type := +| exp_unit : exp +| exp_real : R -> exp +| exp_var (g : ctx) t str : t = lookup Unit g str -> exp +| exp_add : exp -> exp -> exp +| exp_letin str : exp -> exp -> exp. +Arguments exp_var {g t}. + +Fail Example letin_once : exp := + exp_letin "x" (exp_real 1) (exp_var "x" erefl). +Example letin_once : exp := + exp_letin "x" (exp_real 1) (@exp_var [:: ("x", Real)] Real "x" erefl). + +End lang_extrinsic. +End lang_extrinsic. + +Module lang_intrinsic_ty. +Section lang_intrinsic_ty. +Variable R : realType. +Implicit Types str : string. + +Inductive exp : typ -> Type := +| exp_unit : exp Unit +| exp_real : R -> exp Real +| exp_var g t str : t = lookup Unit g str -> exp t +| exp_add : exp Real -> exp Real -> exp Real +| exp_letin t u : string -> exp t -> exp u -> exp u. +Arguments exp_var {g t}. + +Fail Example letin_once : exp Real := + exp_letin "x" (exp_real 1) (exp_var "x" erefl). +Example letin_once : exp Real := + exp_letin "x" (exp_real 1) (@exp_var [:: ("x", Real)] _ "x" erefl). + +End lang_intrinsic_ty. +End lang_intrinsic_ty. + +Module lang_intrinsic_sc. +Section lang_intrinsic_sc. +Variable R : realType. +Implicit Types str : string. + +Inductive exp : ctx -> Type := +| exp_unit g : exp g +| exp_real g : R -> exp g +| exp_var g t str : t = lookup Unit g str -> exp g +| exp_add g : exp g -> exp g -> exp g +| exp_letin g t str : exp g -> exp ((str, t) :: g) -> exp g. +Arguments exp_real {g}. +Arguments exp_var {g t}. +Arguments exp_letin {g t}. + +Declare Custom Entry expr. + +Notation "[ e ]" := e (e custom expr at level 5). +Notation "{ x }" := x (in custom expr, x constr). +Notation "x ':R'" := (exp_real x) (in custom expr at level 1). +Notation "x" := x (in custom expr at level 0, x ident). +Notation "$ x" := (exp_var x erefl) (in custom expr at level 1). +Notation "x + y" := (exp_add x y) + (in custom expr at level 2, left associativity). +Notation "'let' x ':=' e1 'in' e2" := (exp_letin x e1 e2) + (in custom expr at level 3, x constr, + e1 custom expr at level 2, e2 custom expr at level 3, + left associativity). + +Fail Example letin_once : exp [::] := + [let "x" := {1%R}:R in ${"x"}]. +Example letin_once : exp [::] := + [let "x" := {1%R}:R in {@exp_var [:: ("x", Real)] _ "x" erefl}]. + +Fixpoint acc (g : ctx) (i : nat) : + Type_of_ctx R g -> @Type_of_typ R (nth Unit (map snd g) i) := + match g return Type_of_ctx R g -> Type_of_typ R (nth Unit (map snd g) i) with + | [::] => match i with | O => id | j.+1 => id end + | _ :: _ => match i with + | O => fst + | j.+1 => fun H => acc j H.2 + end + end. +Arguments acc : clear implicits. + +Inductive eval : forall g (t : typ), exp g -> (Type_of_ctx R g -> Type_of_typ R t) -> Prop := +| eval_real g c : @eval g Real [c:R] (fun=> c) +| eval_plus g (e1 e2 : exp g) (v1 v2 : R) : + @eval g Real e1 (fun=> v1) -> + @eval g Real e2 (fun=> v2) -> + @eval g Real [e1 + e2] (fun=> v1 + v2) +| eval_var (g : ctx) str i : + i = index str (map fst g) -> eval [$ str] (acc g i). + +Goal @eval [::] Real [{1}:R] (fun=> 1). +Proof. exact: eval_real. Qed. +Goal @eval [::] Real [{1}:R + {2}:R] (fun=> 3). +Proof. exact/eval_plus/eval_real/eval_real. Qed. +Goal @eval [:: ("x", Real)] _ [$ {"x"}] (acc [:: ("x", Real)] 0). +Proof. exact: eval_var. Qed. + +End lang_intrinsic_sc. +End lang_intrinsic_sc. + +Module lang_intrinsic_tysc. +Section lang_intrinsic_tysc. +Variable R : realType. +Implicit Types str : string. + +Inductive typ := Real | Unit | Pair : typ -> typ -> typ. + +HB.instance Definition _ := gen_eqMixin typ. + +Fixpoint mtyp (t : typ) : Type := + match t with + | Real => R + | Unit => unit + | Pair t1 t2 => (mtyp t1 * mtyp t2) + end. + +Definition ctx := seq (string * typ). + +Definition Type_of_ctx (g : ctx) := iter_pair (map (mtyp \o snd) g). + +Goal Type_of_ctx [:: ("x", Real); ("y", Real)] = (R * (R * unit))%type. +Proof. by []. Qed. + +Inductive exp : ctx -> typ -> Type := +| exp_unit g : exp g Unit +| exp_real g : R -> exp g Real +| exp_var g t str : t = lookup Unit g str -> exp g t +| exp_add g : exp g Real -> exp g Real -> exp g Real +| exp_pair g t1 t2 : exp g t1 -> exp g t2 -> exp g (Pair t1 t2) +| exp_letin g t1 t2 x : exp g t1 -> exp ((x, t1) :: g) t2 -> exp g t2. + +Definition exp_var' str (t : typ) (g : find str t) := + @exp_var (untag (ctx_of g)) t str (ctx_prf g). + +Section no_bidirectional_hints. + +Arguments exp_unit {g}. +Arguments exp_real {g}. +Arguments exp_var {g t}. +Arguments exp_add {g}. +Arguments exp_pair {g t1 t2}. +Arguments exp_letin {g t1 t2}. +Arguments exp_var' str {t} g. + +Fail Example letin_add : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_add (exp_var "x" erefl) + (exp_var "y" erefl))). +Example letin_add : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_add (@exp_var [:: ("y", Real); ("x", Real)] _ "x" erefl) + (exp_var "y" erefl))). +Reset letin_add. + +Declare Custom Entry expr. + +Notation "[ e ]" := e (e custom expr at level 5). +Notation "{ x }" := x (in custom expr, x constr). +Notation "x ':R'" := (exp_real x) (in custom expr at level 1). +Notation "x" := x (in custom expr at level 0, x ident). +Notation "$ x" := (exp_var x erefl) (in custom expr at level 1). +Notation "# x" := (exp_var' x%string _) (in custom expr at level 1). +Notation "e1 + e2" := (exp_add e1 e2) + (in custom expr at level 2, + (* e1 custom expr at level 1, e2 custom expr at level 2, *) + left associativity). +Notation "( e1 , e2 )" := (exp_pair e1 e2) + (in custom expr at level 1). +Notation "'let' x ':=' e1 'in' e2" := (exp_letin x e1 e2) + (in custom expr at level 3, x constr, + e1 custom expr at level 2, e2 custom expr at level 3, + left associativity). + +Fail Definition let3_add_erefl (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + $a + $b]. +(* The term "[$ a]" has type "exp ?g2 (lookup Unit ?g2 a)" while it is expected to have type "exp ?g2 Real". *) + +Definition let3_pair_erefl (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + ($a, $b)]. + +Fail Definition let3_add (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + #a + #b]. +(* The term "[# a + # b]" has type + "exp (untag (ctx_of (recurse (str':=b) Real ?f))) Real" +while it is expected to have type "exp ((c, Real) :: ?g1) ?u1" +(cannot unify "(b, Real)" and "(c, Real)"). *) + +Fail Definition let3_pair (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + (#a, #b)]. +(* The term "[# a + # b]" has type "exp (untag (ctx_of (recurse (str':=b) Real ?f))) Real" while it is expected to have type + "exp ((c, Real) :: ?g1) ?u1" (cannot unify "(b, Real)" and "(c, Real)"). *) + +End no_bidirectional_hints. + +Section with_bidirectional_hints. + +Arguments exp_unit {g}. +Arguments exp_real {g}. +Arguments exp_var {g t}. +Arguments exp_add {g} &. +Arguments exp_pair {g} & {t1 t2}. +Arguments exp_letin {g} & {t1 t2}. +Arguments exp_var' str {t} g. + +Declare Custom Entry expr. + +Notation "[ e ]" := e (e custom expr at level 5). +Notation "{ x }" := x (in custom expr, x constr). +Notation "x ':R'" := (exp_real x) (in custom expr at level 1). +Notation "x" := x (in custom expr at level 0, x ident). +Notation "$ x" := (exp_var x%string erefl) (in custom expr at level 1). +Notation "# x" := (exp_var' x%string _) (in custom expr at level 1). +Notation "e1 + e2" := (exp_add e1 e2) + (in custom expr at level 2, + left associativity). +Notation "( e1 , e2 )" := (exp_pair e1 e2) + (in custom expr at level 1). +Notation "'let' x ':=' e1 'in' e2" := (exp_letin x e1 e2) + (in custom expr at level 3, x constr, + e1 custom expr at level 2, e2 custom expr at level 3, + left associativity). + +Fail Definition let2_add_erefl_bidi (a b : string) + (ba : infer (b != a)) (ab : infer (a != b)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + $a + $b]. + +Definition let2_add_erefl_bidi (a b : string) + (ba : infer (b != a)) (ab : infer (a != b)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + #a + #b]. + +Fail Definition let3_add_erefl_bidi (a b c d : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + $a + $b]. +(* The term "[$ a]" has type "exp [:: (c, Real); (b, Real); (a, Real)] (lookup Unit [:: (c, Real); (b, Real); (a, Real)] a)" +while it is expected to have type "exp [:: (c, Real); (b, Real); (a, Real)] Real" +(cannot unify "lookup Unit [:: (c, Real); (b, Real); (a, Real)] a" and "Real"). *) + +Definition let3_pair_erefl_bidi (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + ($a, $b)]. + +Check let3_pair_erefl_bidi. +(* exp [::] (Pair (lookup Unit [:: (c, Real); (b, Real); (a, Real)] a) (lookup Unit [:: (c, Real); (b, Real); (a, Real)] b)) *) + +Definition let3_add_bidi (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + #a + #b]. + +Definition let3_pair_bidi (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + (#a , #b)]. + +Check let3_pair_bidi. +(* exp [::] (Pair Real Real) *) + +Example e0 : exp [::] _ := exp_real 1. +Example letin1 : exp [::] _ := + exp_letin "x" (exp_real 1) (exp_var "x" erefl). +Example letin2 : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_var "x" erefl)). + +Example letin_add : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_add (exp_var "x" erefl) + (exp_var "y" erefl))). +Reset letin_add. +Fail Example letin_add (x y : string) + (xy : infer (x != y)) (yx : infer (y != x)) : exp [::] _ := + exp_letin x (exp_real 1) + (exp_letin y (exp_real 2) + (exp_add (exp_var x erefl) (exp_var y erefl))). +Example letin_add (x y : string) + (xy : infer (x != y)) (yx : infer (y != x)) : exp [::] _ := + exp_letin x (exp_real 1) + (exp_letin y (exp_real 2) + (exp_add (exp_var' x _) (exp_var' y _))). +Reset letin_add. + +Example letin_add_custom : exp [::] _ := + [let "x" := {1}:R in + let "y" := {2}:R in + #{"x"} + #{"y"}]. + +Section eval. + +Fixpoint acc (g : ctx) (i : nat) : + Type_of_ctx g -> mtyp (nth Unit (map snd g) i) := + match g return Type_of_ctx g -> mtyp (nth Unit (map snd g) i) with + | [::] => match i with | O => id | j.+1 => id end + | _ :: _ => match i with + | O => fst + | j.+1 => fun H => acc j H.2 + end + end. +Arguments acc : clear implicits. + +Reserved Notation "e '-e->' v" (at level 40). + +Inductive eval : forall g t, exp g t -> (Type_of_ctx g -> mtyp t) -> Prop := +| eval_tt g : (exp_unit : exp g _) -e-> (fun=> tt) +| eval_real g c : (exp_real c : exp g _) -e-> (fun=> c) +| eval_plus g (e1 e2 : exp g Real) v1 v2 : + e1 -e-> v1 -> + e2 -e-> v2 -> + [e1 + e2] -e-> fun x => v1 x + v2 x +| eval_var g str : + let i := index str (map fst g) in + exp_var str erefl -e-> acc g i +| eval_pair g t1 t2 e1 e2 v1 v2 : + e1 -e-> v1 -> + e2 -e-> v2 -> + @exp_pair g t1 t2 e1 e2 -e-> fun x => (v1 x, v2 x) +| eval_letin g t t' str (e1 : exp g t) (e2 : exp ((str, t) :: g) t') v1 v2 : + e1 -e-> v1 -> + e2 -e-> v2 -> + exp_letin str e1 e2 -e-> (fun a => v2 (v1 a, a)) +where "e '-e->' v" := (@eval _ _ e v). + +Lemma eval_uniq g t (e : exp g t) u v : + e -e-> u -> e -e-> v -> u = v. +Proof. +move=> hu. +apply: (@eval_ind + (fun g t (e : exp g t) (u : Type_of_ctx g -> mtyp t) => + forall v, e -e-> v -> u = v)); last exact: hu. +all: (rewrite {g t e u v hu}). +- move=> g v. + inversion 1. + by inj_ex H3. +- move=> g c v. + inversion 1. + by inj_ex H3. +- move=> g e1 e2 v1 v2 ev1 IH1 ev2 IH2 v. + inversion 1. + inj_ex H0; inj_ex H1; subst. + inj_ex H5; subst. + by rewrite (IH1 _ H3) (IH2 _ H4). +- move=> g x i v. + inversion 1. + by inj_ex H6; subst. +- move=> g t1 t2 e1 e2 v1 v2 ev1 IH1 ev2 IH2 v. + inversion 1. + inj_ex H3; inj_ex H4; subst. + inj_ex H5; subst. + by rewrite (IH1 _ H6) (IH2 _ H7). +- move=> g t t' x0 e0 e1 v1 v2 ev1 IH1 ev2 IH2 v. + inversion 1. + inj_ex H5; subst. + inj_ex H6; subst. + inj_ex H7; subst. + by rewrite (IH1 _ H4) (IH2 _ H8). +Qed. + +Lemma eval_total g t (e : exp g t) : exists v, e -e-> v. +Proof. +elim: e. +- by eexists; exact: eval_tt. +- by eexists; exact: eval_real. +- move=> {}g {}t x e; subst t. + by eexists; exact: eval_var. +- move=> {}g e1 [v1] IH1 e2 [v2] IH2. + by eexists; exact: (eval_plus IH1 IH2). +- move=> {}g t1 t2 e1 [v1] IH1 e2 [v2] IH2. + by eexists; exact: (eval_pair IH1 IH2). +- move=> {}g {}t u x e1 [v1] IH1 e2 [v2] IH2. + by eexists; exact: (eval_letin IH1 IH2). +Qed. + +Definition exec g t (e : exp g t) : Type_of_ctx g -> mtyp t := + proj1_sig (cid (@eval_total g t e)). + +Lemma exec_eval g t (e : exp g t) v : exec e = v <-> e -e-> v. +Proof. +split. + by move=> <-; rewrite /exec; case: cid. +move=> ev. +by rewrite /exec; case: cid => f H/=; apply: eval_uniq; eauto. +Qed. + +Lemma eval_exec g t (e : exp g t) : e -e-> exec e. +Proof. by rewrite /exec; case: cid. Qed. + +Lemma exec_real g r : @exec g Real (exp_real r) = (fun=> r). +Proof. exact/exec_eval/eval_real. Qed. + +Lemma exec_var g str t H : + exec (@exp_var _ t str H) = + eq_rect _ (fun a => Type_of_ctx g -> mtyp a) + (acc g (index str (map fst g))) + _ (esym H). +Proof. +subst t. +rewrite {1}/exec. +case: cid => f H. +inversion H; subst g0 str0. +by inj_ex H6; subst f. +Qed. + +Lemma exp_var'E str t (f : find str t) H : exp_var' str f = exp_var str H. +Proof. by rewrite /exp_var'; congr exp_var. Qed. + +Lemma exec_letin g x t1 t2 (e1 : exp g t1) (e2 : exp ((x, t1) :: g) t2) : + exec [let x := e1 in e2] = (fun a => (exec e2) ((exec e1) a, a)). +Proof. by apply/exec_eval/eval_letin; exact: eval_exec. Qed. + +Goal ([{1}:R] : exp [::] _) -e-> (fun=> 1). +Proof. exact: eval_real. Qed. +Goal @eval [::] _ [{1}:R + {2}:R] (fun=> 3). +Proof. exact/eval_plus/eval_real/eval_real. Qed. +Goal @eval [:: ("x", Real)] _ (exp_var "x" erefl) (@acc [:: ("x", Real)] 0). +Proof. exact: eval_var. Qed. +Goal @eval [::] _ [let "x" := {1}:R in #{"x"}] (fun=> 1). +Proof. +apply/exec_eval; rewrite exec_letin/=. +apply/funext => t/=. +by rewrite exp_var'E exec_real/= exec_var/=. +Qed. + +Goal exec (g := [::]) [let "x" := {1}:R in #{"x"}] = (fun=> 1). +Proof. +rewrite exec_letin//=. +apply/funext => x. +by rewrite exp_var'E exec_var/= exec_real. +Qed. + +End eval. + +End with_bidirectional_hints. + +End lang_intrinsic_tysc. +End lang_intrinsic_tysc. diff --git a/theories/lang_syntax_util.v b/theories/lang_syntax_util.v new file mode 100644 index 000000000..a7b19a603 --- /dev/null +++ b/theories/lang_syntax_util.v @@ -0,0 +1,84 @@ +(* mathcomp analysis (c) 2025 Inria and AIST. License: CeCILL-C. *) +From Coq Require Import String. +From HB Require Import structures. +Require Import Classical_Prop. (* NB: to compile with Coq 8.17 *) +From mathcomp Require Import all_ssreflect. +From mathcomp Require Import interval_inference. + +(**md**************************************************************************) +(* Shared by lang_syntax_*.v files *) +(******************************************************************************) + +HB.instance Definition _ := hasDecEq.Build string eqb_spec. + +Ltac inj_ex H := revert H; + match goal with + | |- existT ?P ?l (existT ?Q ?t (existT ?R ?u (existT ?T ?v ?v1))) = + existT ?P ?l (existT ?Q ?t (existT ?R ?u (existT ?T ?v ?v2))) -> _ => + (intro H; do 4 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l (existT ?Q ?t (existT ?R ?u ?v1)) = + existT ?P ?l (existT ?Q ?t (existT ?R ?u ?v2)) -> _ => + (intro H; do 3 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l (existT ?Q ?t ?v1) = + existT ?P ?l (existT ?Q ?t ?v2) -> _ => + (intro H; do 2 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l (existT ?Q ?t ?v1) = + existT ?P ?l (existT ?Q ?t' ?v2) -> _ => + (intro H; do 2 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l ?v1 = + existT ?P ?l ?v2 -> _ => + (intro H; apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l ?v1 = + existT ?P ?l' ?v2 -> _ => + (intro H; apply Classical_Prop.EqdepTheory.inj_pair2 in H) +end. + +Set Implicit Arguments. +Unset Strict Implicit. +Set Printing Implicit Defensive. + +Class infer (P : Prop) := Infer : P. +#[global] Hint Mode infer ! : typeclass_instances. +#[global] Hint Extern 0 (infer _) => (exact) : typeclass_instances. + +Section tagged_context. +Context {T : eqType} {t0 : T}. +Let ctx := seq (string * T). +Implicit Types (str : string) (g : ctx) (t : T). + +Definition dom g := map fst g. + +Definition lookup g str := nth t0 (map snd g) (index str (dom g)). + +Structure tagged_ctx := Tag {untag : ctx}. + +Structure find str t := Find { + ctx_of : tagged_ctx ; + #[canonical=no] ctx_prf : t = lookup (untag ctx_of) str}. + +Lemma ctx_prf_head str t g : t = lookup ((str, t) :: g) str. +Proof. by rewrite /lookup /= !eqxx. Qed. + +Lemma ctx_prf_tail str t g str' t' : + str' != str -> + t = lookup g str -> + t = lookup ((str', t') :: g) str. +Proof. +move=> str'str tg /=; rewrite /lookup/=. +by case: ifPn => //=; rewrite (negbTE str'str). +Qed. + +Definition recurse_tag g := Tag g. +Canonical found_tag g := recurse_tag g. + +Canonical found str t g : find str t := + @Find str t (found_tag ((str, t) :: g)) + (@ctx_prf_head str t g). + +Canonical recurse str t str' t' {H : infer (str' != str)} + (g : find str t) : find str t := + @Find str t (recurse_tag ((str', t') :: untag (ctx_of g))) + (@ctx_prf_tail str t (untag (ctx_of g)) str' t' H (ctx_prf g)). + +End tagged_context. +Arguments lookup {T} t0 g str. diff --git a/theories/normedtype_theory/ereal_normedtype.v b/theories/normedtype_theory/ereal_normedtype.v index 2225c83b0..19ad908fb 100644 --- a/theories/normedtype_theory/ereal_normedtype.v +++ b/theories/normedtype_theory/ereal_normedtype.v @@ -5,9 +5,9 @@ From mathcomp Require Import rat interval zmodp vector fieldext falgebra. From mathcomp Require Import archimedean. From mathcomp Require Import mathcomp_extra unstable boolp classical_sets. From mathcomp Require Import functions cardinality set_interval. -From mathcomp Require Import interval_inference ereal reals topology. -From mathcomp Require Import separation_axioms function_spaces real_interval. -From mathcomp Require Import prodnormedzmodule tvs. +From mathcomp Require Import interval_inference ereal reals constructive_ereal. +From mathcomp Require Import topology separation_axioms function_spaces. +From mathcomp Require Import real_interval prodnormedzmodule tvs. From mathcomp Require Import num_normedtype. (**md**************************************************************************) diff --git a/theories/prob_lang.v b/theories/prob_lang.v new file mode 100644 index 000000000..e0ee01552 --- /dev/null +++ b/theories/prob_lang.v @@ -0,0 +1,2285 @@ +(* mathcomp analysis (c) 2025 Inria and AIST. License: CeCILL-C. *) +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval finmap. +From mathcomp Require Import rat archimedean ring lra. +From mathcomp Require Import unstable mathcomp_extra boolp classical_sets. +From mathcomp Require Import functions cardinality fsbigop interval_inference. +From mathcomp Require Import reals ereal topology normedtype sequences. +From mathcomp Require Import esum measure lebesgue_measure numfun exp. +From mathcomp Require Import lebesgue_integral trigo probability kernel charge. + +(**md**************************************************************************) +(* # Semantics of a probabilistic programming language using s-finite kernels *) +(* *) +(* Reference: *) +(* - R. Affeldt, C. Cohen, A. Saito. Semantics of probabilistic programs *) +(* using s-finite kernels in Coq. CPP 2023 *) +(* - S. Staton. Commutative Semantics for Probabilistic Programming. *) +(* ESOP 2017 *) +(* *) +(* ``` *) +(* measurable_sum X Y == the type X + Y, as a measurable type *) +(* ``` *) +(* *) +(* ``` *) +(* mscore f t := mscale `|f t| \d_tt *) +(* kscore f := fun=> mscore f *) +(* This is an s-finite kernel. *) +(* kite k1 k2 mf := kdirac mf \; kadd (kiteT k1) (kiteF k2). *) +(* k1 has type R.-sfker T ~> T'. *) +(* k2 has type R.-sfker T ~> T'. *) +(* mf is a proof that f : T -> bool is measurable. *) +(* KITE.kiteT k1 is k1 \o fst if f returne true *) +(* and zero otherwise. *) +(* KITE.kiteF k2 is k2 \o fst if f returne false *) +(* and zero otherwise. *) +(* *) +(* ret mf == access the context with f and return the result *) +(* mf is a proof that f is measurable. *) +(* This is a probability kernel. *) +(* sample mP == sample according to the probability measure P *) +(* mP is a proof that P is a measurable function. *) +(* sample_cst P == same as sample with a constant probability measure *) +(* normalize k P == normalize the kernel k into a probability kernel *) +(* P is a default probability in case normalization *) +(* is not possible. *) +(* normalize_pt k := normalize k point *) +(* ite mf k1 k2 == access the context with the boolean function f and *) +(* behaves as k1 or k2 according to the result *) +(* letin l k == execute l, augment the context, and execute k *) +(* fail := let _ := score 0 in ret point *) +(* score mf == observe t from d, where f is the density of d and *) +(* t occurs in f *) +(* e.g., score (r e^(-r * t)) = observe t from exp(r) *) +(* acc0of2, acc1of2, etc. == accessor function *) +(* case_nat t u_ == case analysis on natural numbers *) +(* t has type R.-sfker T ~> nat *) +(* u_ has type nat -> R.-sfker T ~> T' *) +(* CASE_SUM.case_sum g k1 k2 == case analysis on a sum type *) +(* g has type R.-sfker X ~> (A + B). *) +(* k1 has type A -> R.-sfker X ~> Y. *) +(* k2 has type B -> R.-sfker X ~> Y. *) +(* kcounting == the counting measure as a kernel *) +(* iterate k mu == iteration *) +(* k has type R.-sfker G * A ~> (A + B). *) +(* mu is a proof that u : G -> A is measurable. *) +(* flift_neq == an s-finite kernel to test that two expressions *) +(* are different *) +(* ``` *) +(* *) +(* Examples: Staton's bus, von Neumann's trick, etc. *) +(* *) +(* ``` *) +(* mkswap k == given a kernel k : (Y * X) ~> Z, *) +(* returns a kernel of type (X * Y) ~> Z *) +(* letin' := mkcomp \o mkswap *) +(* ``` *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. +Import Order.TTheory GRing.Theory Num.Def Num.ExtraDef Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +Lemma eq_probability R d (Y : measurableType d) (m1 m2 : probability Y R) : + (m1 =1 m2 :> (set Y -> \bar R)) -> m1 = m2. +Proof. +move: m1 m2 => [m1 +] [m2 +] /= m1m2. +move/funext : m1m2 => <- -[[c11 c12] [m01] [sf1] [sig1] [fin1] [sub1] [p1]] + [[c21 c22] [m02] [sf2] [sig2] [fin2] [sub2] [p2]]. +have ? : c11 = c21 by []. +subst c21. +have ? : c12 = c22 by []. +subst c22. +have ? : m01 = m02 by []. +subst m02. +have ? : sf1 = sf2 by []. +subst sf2. +have ? : sig1 = sig2 by []. +subst sig2. +have ? : fin1 = fin2 by []. +subst fin2. +have ? : sub1 = sub2 by []. +subst sub2. +have ? : p1 = p2 by []. +subst p2. +by f_equal. +Qed. + +Definition poisson3 {R : realType} := @poisson_pmf R 3%:R 4. (* 0.168 *) +Definition poisson10 {R : realType} := @poisson_pmf R 10%:R 4. (* 0.019 *) + +Definition dep_uncurry (A : Type) (B : A -> Type) (C : Type) : + (forall a : A, B a -> C) -> {a : A & B a} -> C := + fun f p => let (a, Ba) := p in f a Ba. + +(* TODO: move *) +Lemma poisson_pmf_gt0 {R : realType} k (r : R) : + (0 < r -> 0 < poisson_pmf r k.+1)%R. +Proof. +move=> r0; rewrite /poisson_pmf r0 mulr_gt0 ?expR_gt0//. +by rewrite divr_gt0// ?exprn_gt0// invr_gt0 ltr0n fact_gt0. +Qed. + +Lemma exponential_pdf_gt0 {R : realType} (r : R) x : + (0 < r -> 0 < x -> 0 < exponential_pdf r x)%R. +Proof. +move=> r0 x0; rewrite /exponential_pdf/=. +rewrite patchE/= ifT; last first. + by rewrite inE/= in_itv/= (ltW x0). +by rewrite mulr_gt0// expR_gt0. +Qed. + +(* X + Y is a measurableType if X and Y are *) +HB.instance Definition _ (X Y : pointedType) := + isPointed.Build (X + Y)%type (@inl X Y point). + +Section measurable_sum. +Context d d' (X : measurableType d) (Y : measurableType d'). + +Definition measurable_sum : set (set (X + Y)) := setT. + +Let sum0 : measurable_sum set0. Proof. by []. Qed. + +Let sumC A : measurable_sum A -> measurable_sum (~` A). Proof. by []. Qed. + +Let sumU (F : (set (X + Y))^nat) : (forall i, measurable_sum (F i)) -> + measurable_sum (\bigcup_i F i). +Proof. by []. Qed. + +HB.instance Definition _ := @isMeasurable.Build default_measure_display + (X + Y)%type measurable_sum sum0 sumC sumU. + +End measurable_sum. + +Lemma measurable_fun_sum dA dB d' (A : measurableType dA) (B : measurableType dB) + (Y : measurableType d') (f : A -> Y) (g : B -> Y) : + measurable_fun setT f -> measurable_fun setT g -> + measurable_fun setT (fun tb : A + B => + match tb with inl a => f a | inr b => g b end). +Proof. +move=> mx my/= _ Z mZ /=; rewrite setTI /=. +rewrite (_ : _ @^-1` Z = inl @` (f @^-1` Z) `|` inr @` (g @^-1` Z)). + exact: measurableU. +apply/seteqP; split. + by move=> [a Zxa|b Zxb]/=; [left; exists a|right; exists b]. +by move=> z [/= [a Zxa <-//=]|]/= [b Zyb <-//=]. +Qed. + +(* TODO: measurable_fun_if_pair -> measurable_fun_if_pair_bool? *) +Lemma measurable_fun_if_pair_nat d d' (X : measurableType d) + (Y : measurableType d') (f g : X -> Y) (n : nat) : + measurable_fun setT f -> measurable_fun setT g -> + measurable_fun setT (fun xn => if xn.2 == n then f xn.1 else g xn.1). +Proof. +move=> mx my; apply: measurable_fun_ifT => //=. +- have h : measurable_fun [set: nat] (fun t => t == n) by []. + exact: (measurableT_comp h). +- exact: measurableT_comp. +- exact: measurableT_comp. +Qed. + +Module Notations. +Notation munit := (unit : measurableType _). +Notation mbool := (bool : measurableType _). +Notation mnat := (nat : measurableType _). +End Notations. + +Lemma invr_nonneg_proof (R : numDomainType) (p : {nonneg R}) : + (0 <= (p%:num)^-1)%R. +Proof. by rewrite invr_ge0. Qed. + +Definition invr_nonneg (R : numDomainType) (p : {nonneg R}) := + NngNum (invr_nonneg_proof p). + +Section constants. +Variable R : realType. +Local Open Scope ring_scope. + +Lemma onem1S n : `1- (1 / n.+1%:R) = (n%:R / n.+1%:R)%:nng%:num :> R. +Proof. +by rewrite /onem/= -{1}(@divrr _ n.+1%:R) ?unitfE// -mulrBl -natr1 addrK. +Qed. + +Lemma p1S n : (1 / n.+1%:R)%:nng%:num <= 1 :> R. +Proof. by rewrite ler_pdivrMr//= mul1r ler1n. Qed. + +Lemma p12 : (1 / 2%:R)%:nng%:num <= 1 :> R. +Proof. by rewrite ler_pdivrMr//= mul1r ler1n. Qed. + +Lemma p14 : (1 / 4%:R)%:nng%:num <= 1 :> R. +Proof. by rewrite ler_pdivrMr//= mul1r ler1n. Qed. + +Lemma onem27 : `1- (2 / 7%:R) = (5%:R / 7%:R)%:nng%:num :> R. +Proof. by apply/eqP; rewrite subr_eq/= -mulrDl -natrD divrr// unitfE. Qed. + +(*Lemma p27 : (2 / 7%:R)%:nng%:num <= 1 :> R. +Proof. by rewrite /= lter_pdivrMr// mul1r ler_nat. Qed.*) + +End constants. +Arguments p12 {R}. +Arguments p14 {R}. +(*Arguments p27 {R}.*) +Arguments p1S {R}. + +Section mscore. +Context d (T : measurableType d) (R : realType). +Variable f : T -> R. + +Definition mscore t : {measure set unit -> \bar R} := + let p := NngNum (normr_ge0 (f t)) in mscale p \d_tt. + +Lemma mscoreE t U : mscore t U = if U == set0 then 0 else `| (f t)%:E |. +Proof. +rewrite /mscore/= /mscale/=; have [->|->] := set_unit U. + by rewrite eqxx dirac0 mule0. +by rewrite diracT mule1 (negbTE setT0). +Qed. + +Lemma measurable_fun_mscore U : measurable_fun setT f -> + measurable_fun setT (mscore ^~ U). +Proof. +move=> mr; under eq_fun do rewrite mscoreE/=. +have [U0|U0] := eqVneq U set0; first exact: measurable_cst. +by apply: measurableT_comp => //; exact: measurableT_comp. +Qed. + +End mscore. + +(* decomposition of score into finite kernels [Section 3.2, Staton ESOP 2017] *) +Module SCORE. +Section score. +Context d (T : measurableType d) (R : realType). +Variable f : T -> R. + +Definition k (mf : measurable_fun [set: T] f) i t U := + if i%:R%:E <= mscore f t U < i.+1%:R%:E then + mscore f t U + else + 0. + +Hypothesis mf : measurable_fun setT f. + +Lemma k0 i t : k mf i t (set0 : set unit) = 0 :> \bar R. +Proof. by rewrite /k measure0; case: ifP. Qed. + +Lemma k_ge0 i t B : 0 <= k mf i t B. +Proof. by rewrite /k; case: ifP. Qed. + +Lemma k_sigma_additive i t : semi_sigma_additive (k mf i t). +Proof. +move=> /= F mF tF mUF; rewrite /k /=. +have [F0|UF0] := eqVneq (\bigcup_n F n) set0. + rewrite F0 measure0 (_ : (fun _ => _) = cst 0). + by case: ifPn => _; exact: cvg_cst. + apply/funext => k; rewrite big1// => n _. + by move: F0 => /bigcup0P -> //; rewrite measure0; case: ifPn. +move: (UF0) => /eqP/bigcup0P/existsNP[m /not_implyP[_ /eqP Fm0]]. +rewrite [in X in _ --> X]mscoreE (negbTE UF0). +rewrite -(cvg_shiftn m.+1)/=. +case: ifPn => ir. + rewrite (_ : (fun _ => _) = cst `|(f t)%:E|); first exact: cvg_cst. + apply/funext => n. + rewrite big_mkord (bigD1 (widen_ord (leq_addl n _) (Ordinal (ltnSn m))))//=. + rewrite [in X in X + _]mscoreE (negbTE Fm0) ir big1 ?adde0// => /= j jk. + rewrite mscoreE; have /eqP -> : F j == set0. + have [/eqP//|Fjtt] := set_unit (F j). + move/trivIsetP : tF => /(_ j m Logic.I Logic.I jk). + by rewrite Fjtt setTI => /eqP; rewrite (negbTE Fm0). + by rewrite eqxx; case: ifP. +rewrite (_ : (fun _ => _) = cst 0); first exact: cvg_cst. +apply/funext => n. +rewrite big_mkord (bigD1 (widen_ord (leq_addl n _) (Ordinal (ltnSn m))))//=. +rewrite [in X in if X then _ else _]mscoreE (negbTE Fm0) (negbTE ir) add0e. +rewrite big1//= => j jm; rewrite mscoreE; have /eqP -> : F j == set0. + have [/eqP//|Fjtt] := set_unit (F j). + move/trivIsetP : tF => /(_ j m Logic.I Logic.I jm). + by rewrite Fjtt setTI => /eqP; rewrite (negbTE Fm0). +by rewrite eqxx; case: ifP. +Qed. + +HB.instance Definition _ i t := isMeasure.Build _ _ _ + (k mf i t) (k0 i t) (k_ge0 i t) (@k_sigma_additive i t). + +Lemma measurable_fun_k i U : measurable U -> measurable_fun setT (k mf i ^~ U). +Proof. +move=> /= mU; rewrite /k /= (_ : (fun x => _) = + (fun x => if i%:R%:E <= x < i.+1%:R%:E then x else 0) \o (mscore f ^~ U)) //. +apply: measurableT_comp => /=; last exact/measurable_fun_mscore. +rewrite (_ : (fun x => _) = (fun x => x * + (\1_(`[i%:R%:E, i.+1%:R%:E [%classic : set _) x)%:E)); last first. + apply/funext => x; case: ifPn => ix; first by rewrite indicE/= mem_set ?mule1. + by rewrite indicE/= memNset ?mule0// /= in_itv/=; exact/negP. +apply: emeasurable_funM => //=; apply/measurable_EFinP. +by rewrite (_ : \1__ = mindic R (emeasurable_itv `[i%:R%:E, i.+1%:R%:E[)). +Qed. + +Definition mk i t := [the measure _ _ of k mf i t]. + +HB.instance Definition _ i := + isKernel.Build _ _ _ _ _ (mk i) (measurable_fun_k i). + +Lemma mk_uub i : measure_fam_uub (mk i). +Proof. +exists i.+1%:R => /= t; rewrite /k mscoreE setT_unit. +by case: ifPn => //; case: ifPn => // _ /andP[]. +Qed. + +HB.instance Definition _ i := + Kernel_isFinite.Build _ _ _ _ _ (mk i) (mk_uub i). + +End score. +End SCORE. + +Section kscore. +Context d (T : measurableType d) (R : realType). +Variable f : T -> R. + +Definition kscore (mf : measurable_fun setT f) + : T -> {measure set _ -> \bar R} := + mscore f. + +Variable mf : measurable_fun setT f. + +Let measurable_fun_kscore U : measurable U -> + measurable_fun setT (kscore mf ^~ U). +Proof. by move=> /= _; exact: measurable_fun_mscore. Qed. + +HB.instance Definition _ := isKernel.Build _ _ T _ R + (kscore mf) measurable_fun_kscore. + +Import SCORE. + +Let sfinite_kscore : exists k : (R.-fker T ~> _)^nat, + forall x U, measurable U -> + kscore mf x U = mseries (k ^~ x) 0 U. +Proof. +rewrite /=; exists (fun i => [the R.-fker _ ~> _ of mk mf i]) => /= t U mU. +rewrite /mseries /kscore/= mscoreE; case: ifPn => [/eqP U0|U0]. + by apply/esym/eseries0 => i _; rewrite U0 measure0. +rewrite /mk /= /k /= mscoreE (negbTE U0). +apply/esym/cvg_lim => //. +rewrite -(cvg_shiftn `|floor (fine `|(f t)%:E|)|%N.+1)/=. +rewrite (_ : (fun _ => _) = cst `|(f t)%:E|); first exact: cvg_cst. +apply/funext => n. +pose floor_f := widen_ord (leq_addl n `|floor `|f t| |.+1) + (Ordinal (ltnSn `|floor `|f t| |)). +rewrite big_mkord (bigD1 floor_f)//= ifT; last first. + rewrite lee_fin lte_fin; apply/andP; split. + by rewrite natr_absz (@ger0_norm _ (floor `|f t|)) ?floor_ge0// floor_le_tmp. + rewrite -addn1 natrD natr_absz. + by rewrite (@ger0_norm _ (floor `|f t|)) ?floor_ge0// intrD1 floorD1_gt. +rewrite big1 ?adde0//= => j jk. +rewrite ifF// lte_fin lee_fin. +move: jk; rewrite neq_ltn/= => /orP[|] jr. +- suff : (j.+1%:R <= `|f t|)%R by rewrite leNgt => /negbTE ->; rewrite andbF. + rewrite (_ : j.+1%:R = j.+1%:~R)// floor_ge_int//. + move: jr; rewrite -lez_nat => /le_trans; apply. + by rewrite -[leRHS](@ger0_norm _ (floor `|f t|)) ?floor_ge0. +- suff : (`|f t| < j%:R)%R by rewrite ltNge => /negbTE ->. + move: jr; rewrite -ltz_nat -(@ltr_int R) (@gez0_abs (floor `|f t|)) ?floor_ge0//. + by rewrite ltr_int floor_lt_int. +Qed. + +HB.instance Definition _ := + @Kernel_isSFinite.Build _ _ _ _ _ (kscore mf) sfinite_kscore. + +End kscore. + +(* decomposition of ite into s-finite kernels [Section 3.2, Staton ESOP 2017] *) +Module ITE. +Section ite. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Section kiteT. +Variable k : R.-ker X ~> Y. + +Definition kiteT : X * bool -> {measure set Y -> \bar R} := + fun xb => if xb.2 then k xb.1 else mzero. + +Let measurable_fun_kiteT U : measurable U -> measurable_fun setT (kiteT ^~ U). +Proof. +move=> /= mcU; rewrite /kiteT. +rewrite (_ : (fun _ => _) = + (fun x => if x.2 then k x.1 U else mzero U)); last first. + by apply/funext => -[t b]/=; case: ifPn. +apply: (@measurable_fun_if_pair _ _ _ _ (k ^~ U) (fun=> mzero U)) => //. +exact/measurable_kernel. +Qed. + +#[export] +HB.instance Definition _ := isKernel.Build _ _ _ _ _ + kiteT measurable_fun_kiteT. +End kiteT. + +Section sfkiteT. +Variable k : R.-sfker X ~> Y. + +Let sfinite_kiteT : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> kiteT k x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of kiteT (k_ n)]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + by exists r%:num => /= -[x []]; rewrite /kiteT//= /mzero//. +move=> [x b] U mU; rewrite /kiteT; case: ifPn => hb; first by rewrite hk. +by rewrite /mseries eseries0. +Qed. + +#[export] +HB.instance Definition _ := @isSFiniteKernel_subdef.Build _ _ _ _ _ + (kiteT k) sfinite_kiteT. +End sfkiteT. + +Section fkiteT. +Variable k : R.-fker X ~> Y. + +Let kiteT_uub : measure_fam_uub (kiteT k). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +exists M%:num => /= -[]; rewrite /kiteT => t [|]/=; first exact: hM. +by rewrite /= /mzero. +Qed. + +#[export] +HB.instance Definition _ := Kernel_isFinite.Build _ _ _ _ _ + (kiteT k) kiteT_uub. +End fkiteT. + +Section kiteF. +Variable k : R.-ker X ~> Y. + +Definition kiteF : X * bool -> {measure set Y -> \bar R} := + fun xb => if ~~ xb.2 then k xb.1 else mzero. + +Let measurable_fun_kiteF U : measurable U -> measurable_fun setT (kiteF ^~ U). +Proof. +move=> /= mcU; rewrite /kiteF. +rewrite (_ : (fun x => _) = + (fun x => if x.2 then mzero U else k x.1 U)); last first. + by apply/funext => -[t b]/=; rewrite if_neg//; case: ifPn. +apply: (@measurable_fun_if_pair _ _ _ _ (fun=> mzero U) (k ^~ U)) => //. +exact/measurable_kernel. +Qed. + +#[export] +HB.instance Definition _ := isKernel.Build _ _ _ _ _ + kiteF measurable_fun_kiteF. + +End kiteF. + +Section sfkiteF. +Variable k : R.-sfker X ~> Y. + +Let sfinite_kiteF : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> kiteF k x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of kiteF (k_ n)]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + by exists r%:num => /= -[x []]; rewrite /kiteF//= /mzero//. +move=> [x b] U mU; rewrite /kiteF; case: ifPn => hb; first by rewrite hk. +by rewrite /mseries eseries0. +Qed. + +#[export] +HB.instance Definition _ := @isSFiniteKernel_subdef.Build _ _ _ _ _ + (kiteF k) sfinite_kiteF. + +End sfkiteF. + +Section fkiteF. +Variable k : R.-fker X ~> Y. + +Let kiteF_uub : measure_fam_uub (kiteF k). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +by exists M%:num => /= -[]; rewrite /kiteF/= => t; case => //=; rewrite /mzero. +Qed. + +#[export] +HB.instance Definition _ := Kernel_isFinite.Build _ _ _ _ _ + (kiteF k) kiteF_uub. + +End fkiteF. +End ite. +End ITE. + +Section ite. +Context d d' (T : measurableType d) (T' : measurableType d') (R : realType). +Variables (f : T -> bool) (u1 u2 : R.-sfker T ~> T'). + +(* NB: not used? *) +Definition mite (mf : measurable_fun setT f) : T -> set T' -> \bar R := + fun t => if f t then u1 t else u2 t. + +Hypothesis mf : measurable_fun [set: T] f. + +Let mite0 t : mite mf t set0 = 0. +Proof. by rewrite /mite; case: ifPn. Qed. + +Let mite_ge0 t U : 0 <= mite mf t U. +Proof. by rewrite /mite; case: ifPn. Qed. + +Let mite_sigma_additive t : semi_sigma_additive (mite mf t). +Proof. +by rewrite /mite; case: ifPn => ft; exact: measure_semi_sigma_additive. +Qed. + +HB.instance Definition _ t := isMeasure.Build _ _ _ (mite mf t) + (mite0 t) (mite_ge0 t) (@mite_sigma_additive t). + +Import ITE. + +Definition kite : R.-sfker T ~> T' := + kdirac mf \; kadd (kiteT u1) (kiteF u2). + +End ite. + +Section insn2. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Definition ret (f : X -> Y) (mf : measurable_fun [set: X] f) + : R.-pker X ~> Y := kdirac mf. + +Definition sample (P : X -> pprobability Y R) (mP : measurable_fun [set: X] P) + : R.-pker X ~> Y := + kprobability mP. + +Definition sample_cst (P : pprobability Y R) : R.-pker X ~> Y := + sample (measurable_cst P). + +Definition normalize (k : R.-ker X ~> Y) P : X -> probability Y R := + knormalize k P. + +Definition normalize_pt (k : R.-ker X ~> Y) : X -> probability Y R := + normalize k point. + +Lemma measurable_normalize_pt (f : R.-ker X ~> Y) : + measurable_fun [set: X] (normalize_pt f : X -> pprobability Y R). +Proof. +apply: (@measurability _ _ _ _ _ _ + (@pset _ _ _ : set (set (pprobability Y R)))) => //. +move=> _ -[_ [r r01] [Ys mYs <-]] <-. +apply: emeasurable_fun_infty_o => //. +exact: (measurable_kernel (knormalize f point) Ys). +Qed. + +Definition ite (f : X -> bool) (mf : measurable_fun setT f) + (k1 k2 : R.-sfker X ~> Y) : R.-sfker X ~> Y := + locked [the R.-sfker X ~> Y of kite k1 k2 mf]. + +End insn2. +Arguments ret {d d' X Y R f} mf. +Arguments sample_cst {d d' X Y R}. +Arguments sample {d d' X Y R}. + +Section insn2_lemmas. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Lemma retE (f : X -> Y) (mf : measurable_fun setT f) x : + ret mf x = \d_(f x) :> (_ -> \bar R). +Proof. by []. Qed. + +Lemma sample_cstE (P : probability Y R) (x : X) : sample_cst P x = P. +Proof. by []. Qed. + +Lemma sampleE (P : X -> pprobability Y R) (mP : measurable_fun setT P) (x : X) : sample P mP x = P x. +Proof. by []. Qed. + +Lemma normalizeE (f : R.-sfker X ~> Y) P x U : + normalize f P x U = + if (f x [set: Y] == 0) || (f x [set: Y] == +oo) then P U + else f x U * ((fine (f x [set: Y]))^-1)%:E. +Proof. by rewrite /normalize /= /mnormalize; case: ifPn. Qed. + +Lemma iteE (f : X -> bool) (mf : measurable_fun setT f) + (k1 k2 : R.-sfker X ~> Y) x : + ite mf k1 k2 x = if f x then k1 x else k2 x. +Proof. +apply/eq_measure/funext => U. +rewrite /ite; unlock => /=. +rewrite /kcomp/= integral_dirac//=. +rewrite diracT mul1e. +rewrite -/(measure_add (ITE.kiteT k1 (x, f x)) (ITE.kiteF k2 (x, f x))). +rewrite measure_addE. +rewrite /ITE.kiteT /ITE.kiteF/=. +by case: ifPn => fx /=; rewrite /mzero ?(adde0,add0e). +Qed. + +End insn2_lemmas. + +Lemma normalize_kdirac (R : realType) + d (T : measurableType d) d' (T' : measurableType d') (x : T) (r : T') P : + normalize (kdirac (measurable_cst r)) P x = \d_r :> probability T' R. +Proof. +apply: eq_probability => U. +rewrite normalizeE /= diracE in_setT/=. +by rewrite onee_eq0/= indicE in_setT/= -div1r divr1 mule1. +Qed. + +Section insn3. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Definition letin (l : R.-sfker X ~> Y) (k : R.-sfker (X * Y) ~> Z) + : R.-sfker X ~> Z := + [the R.-sfker X ~> Z of l \; k]. + +End insn3. + +Section insn3_lemmas. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Lemma letinE (l : R.-sfker X ~> Y) (k : R.-sfker [the measurableType _ of (X * Y)%type] ~> Z) x U : + letin l k x U = \int[l x]_y k (x, y) U. +Proof. by []. Qed. + +End insn3_lemmas. + +(* rewriting laws *) +Section letin_return. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Lemma letin_kret (k : R.-sfker X ~> Y) + (f : X * Y -> Z) (mf : measurable_fun [set: X * Y] f) x U : + measurable U -> + letin k (ret mf) x U = k x (curry f x @^-1` U). +Proof. +move=> mU; rewrite letinE. +under eq_integral do rewrite retE. +rewrite integral_indic ?setIT// -[X in measurable X]setTI. +exact: (measurableT_comp mf). +Qed. + +Lemma letin_retk (f : X -> Y) + (mf : measurable_fun [set: X] f) (k : R.-sfker X * Y ~> Z) x U : + measurable U -> + letin (ret mf) k x U = k (x, f x) U. +Proof. +move=> mU; rewrite letinE retE integral_dirac ?diracT ?mul1e//. +exact: (measurableT_comp (measurable_kernel k _ mU)). +Qed. + +End letin_return. + +Section insn1. +Context d (X : measurableType d) (R : realType). + +Definition score (f : X -> R) (mf : measurable_fun setT f) : R.-sfker X ~> _ := + [the R.-sfker X ~> _ of kscore mf]. + +End insn1. + +Section hard_constraint. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Definition fail : R.-sfker X ~> Y := + letin (score (measurable_cst (0%R : R))) + (ret (measurable_cst point)). + +Lemma failE x U : fail x U = 0. +Proof. by rewrite /fail letinE ge0_integral_mscale//= normr0 mul0e. Qed. + +End hard_constraint. +Arguments fail {d d' X Y R}. + +Section cst_fun. +Context d (T : measurableType d) (R : realType). + +Definition kr (r : R) := @measurable_cst _ _ T _ setT r. +Definition k3 : measurable_fun _ _ := kr 3%:R. +Definition k10 : measurable_fun _ _ := kr 10%:R. +Definition ktt := @measurable_cst _ _ T _ setT tt. +Definition kb (b : bool) := @measurable_cst _ _ T _ setT b. +Definition kn (n : nat) := @measurable_cst _ _ T _ setT n. + +End cst_fun. +Arguments kr {d T R}. +Arguments k3 {d T R}. +Arguments k10 {d T R}. +Arguments ktt {d T}. +Arguments kb {d T}. +Arguments kn {d T}. + +Section iter_mprod. +Local Open Scope type_scope. + +Fixpoint iter_mprod (l : seq {d & measurableType d}) : {d & measurableType d} := + match l with + | [::] => existT measurableType _ unit + | h :: t => let t' := iter_mprod t in + existT _ _ [the measurableType _ of projT2 h * projT2 t'] + end. + +End iter_mprod. + +Section acc. +Import Notations. +Context {R : realType}. + +Fixpoint acc (l : seq {d & measurableType d}) k : + projT2 (iter_mprod l) -> projT2 (nth (existT _ _ munit) l k) := + match l with + | [::] => match k with O => id | _ => id end + | _ :: _ => match k with + | O => fst + | m.+1 => fun x => acc m x.2 + end + end. + +Lemma measurable_acc (l : seq {d & measurableType d}) n : + measurable_fun setT (@acc l n). +Proof. +by elim: l n => //= h t ih [|m] //; exact: (measurableT_comp (ih _)). +Qed. +End acc. +Arguments acc : clear implicits. +Arguments measurable_acc : clear implicits. + +Section rpair_pairA. +Context d0 d1 d2 (T0 : measurableType d0) (T1 : measurableType d1) + (T2 : measurableType d2). + +Definition rpair d (T : measurableType d) t : T0 -> T0 * T := + fun x => (x, t). + +Lemma mrpair d (T : measurableType d) t : measurable_fun setT (@rpair _ T t). +Proof. exact: measurable_fun_pair. Qed. + +Definition pairA : T0 * T1 * T2 -> T0 * (T1 * T2) := + fun x => (x.1.1, (x.1.2, x.2)). + +Definition mpairA : measurable_fun [set: (T0 * T1) * T2] pairA. +Proof. +apply: measurable_fun_pair => /=; first exact: measurableT_comp. +by apply: measurable_fun_pair => //=; exact: measurableT_comp. +Qed. + +Definition pairAi : T0 * (T1 * T2) -> T0 * T1 * T2 := + fun x => (x.1, x.2.1, x.2.2). + +Definition mpairAi : measurable_fun setT pairAi. +Proof. +apply: measurable_fun_pair => //=; last exact: measurableT_comp. +by apply: measurable_fun_pair => //=; exact: measurableT_comp. +Qed. + +End rpair_pairA. +Arguments rpair {d0 T0 d} T. +#[global] Hint Extern 0 (measurable_fun _ (rpair _ _)) => + solve [apply: mrpair] : core. +Arguments pairA {d0 d1 d2 T0 T1 T2}. +#[global] Hint Extern 0 (measurable_fun _ pairA) => + solve [apply: mpairA] : core. +Arguments pairAi {d0 d1 d2 T0 T1 T2}. +#[global] Hint Extern 0 (measurable_fun _ pairAi) => + solve [apply: mpairAi] : core. + +Section rpair_pairA_comp. +Import Notations. +Context d0 d1 d2 d3 (T0 : measurableType d0) (T1 : measurableType d1) + (T2 : measurableType d2) (T3 : measurableType d3) (R : realType). + +Definition pairAr d (T : measurableType d) t : T0 * T1 -> T0 * (T1 * T) := + pairA \o rpair T t. +Arguments pairAr {d} T. + +Lemma mpairAr d (T : measurableType d) t : measurable_fun setT (pairAr T t). +Proof. exact: measurableT_comp. Qed. + +Definition pairAAr : T0 * T1 * T2 -> T0 * (T1 * (T2 * unit)) := + pairA \o pairA \o rpair unit tt. + +Lemma mpairAAr : measurable_fun setT pairAAr. +Proof. by do 2 apply: measurableT_comp => //. Qed. + +Definition pairAAAr : T0 * T1 * T2 * T3 -> T0 * (T1 * (T2 * (T3 * unit))) := + pairA \o pairA \o pairA \o rpair unit tt. + +Lemma mpairAAAr : measurable_fun setT pairAAAr. +Proof. by do 3 apply: measurableT_comp => //. Qed. + +Definition pairAArAi : T0 * (T1 * T2) -> T0 * (T1 * (T2 * unit)) := + pairAAr \o pairAi. + +Lemma mpairAArAi : measurable_fun setT pairAArAi. +Proof. by apply: measurableT_comp => //=; exact: mpairAAr. Qed. + +Definition pairAAArAAi : T3 * (T0 * (T1 * T2)) -> T3 * (T0 * (T1 * (T2 * unit))) := + pairA \o pairA \o pairA \o rpair unit tt \o pairAi \o pairAi. + +Lemma mpairAAARAAAi : measurable_fun setT pairAAArAAi. +Proof. by do 5 apply: measurableT_comp => //=. Qed. + +End rpair_pairA_comp. +Arguments pairAr {d0 d1 T0 T1 d} T. +Arguments pairAAr {d0 d1 d2 T0 T1 T2}. +Arguments pairAAAr {d0 d1 d2 d3 T0 T1 T2 T3}. +Arguments pairAArAi {d0 d1 d2 T0 T1 T2}. +Arguments pairAAArAAi {d0 d1 d2 d3 T0 T1 T2 T3}. + +Section accessor_functions. +Import Notations. +Context d0 d1 d2 d3 (T0 : measurableType d0) (T1 : measurableType d1) + (T2 : measurableType d2) (T3 : measurableType d3) (R : realType). + +Let T01 : seq {d & measurableType d} := [:: existT _ _ T0; existT _ _ T1]. + +Definition acc0of2 : T0 * T1 -> T0 := + acc T01 0 \o pairAr unit tt. + +Lemma macc0of2 : measurable_fun setT acc0of2. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T01 0)|exact: mpairAr]. +Qed. + +Definition acc1of2 : T0 * T1 -> T1 := + acc T01 1 \o pairAr unit tt. + +Lemma macc1of2 : measurable_fun setT acc1of2. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T01 1)|exact: mpairAr]. +Qed. + +Let T02 := [:: existT _ _ T0; existT _ _ T1; existT _ _ T2]. + +Definition acc1of3 : T0 * T1 * T2 -> T1 := + acc T02 1 \o pairAAr. + +Lemma macc1of3 : measurable_fun setT acc1of3. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T02 1)|exact: mpairAAr]. +Qed. + +Definition acc2of3 : T0 * T1 * T2 -> T2 := + acc T02 2 \o pairAAr. + +Lemma macc2of3 : measurable_fun setT acc2of3. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T02 2)|exact: mpairAAr]. +Qed. + +Definition acc0of3' : T0 * (T1 * T2) -> T0 := + acc T02 0 \o pairAArAi. + +Lemma macc0of3' : measurable_fun setT acc0of3'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T02 0)|exact: mpairAArAi]. +Qed. + +Definition acc1of3' : T0 * (T1 * T2) -> T1 := + acc T02 1 \o pairAArAi. + +Lemma macc1of3' : measurable_fun setT acc1of3'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T02 1)|exact: mpairAArAi]. +Qed. + +Definition acc2of3' : T0 * (T1 * T2) -> T2 := + acc T02 2 \o pairAArAi. + +Lemma macc2of3' : measurable_fun setT acc2of3'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T02 2)|exact: mpairAArAi]. +Qed. + +Let T03 := [:: existT _ _ T0; existT _ _ T1; existT _ d2 T2; existT _ d3 T3]. + +Definition acc1of4 : T0 * T1 * T2 * T3 -> T1 := + acc T03 1 \o pairAAAr. + +Lemma macc1of4 : measurable_fun setT acc1of4. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T03 1)|exact: mpairAAAr]. +Qed. + +Definition acc2of4' : T0 * (T1 * (T2 * T3)) -> T2 := + acc T03 2 \o pairAAArAAi. + +Lemma macc2of4' : measurable_fun setT acc2of4'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T03 2)|exact: mpairAAARAAAi]. +Qed. + +Definition acc3of4 : T0 * T1 * T2 * T3 -> T3 := + acc T03 3 \o pairAAAr. + +Lemma macc3of4 : measurable_fun setT acc3of4. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc T03 3)|exact: mpairAAAr]. +Qed. + +End accessor_functions. +Arguments macc0of2 {d0 d1 _ _}. +Arguments macc1of2 {d0 d1 _ _}. +Arguments macc0of3' {d0 d1 d2 _ _ _}. +Arguments macc1of3 {d0 d1 d2 _ _ _}. +Arguments macc1of3' {d0 d1 d2 _ _ _}. +Arguments macc2of3 {d0 d1 d2 _ _ _}. +Arguments macc2of3' {d0 d1 d2 _ _ _}. +Arguments macc1of4 {d0 d1 d2 d3 _ _ _ _}. +Arguments macc2of4' {d0 d1 d2 d3 _ _ _ _}. +Arguments macc3of4 {d0 d1 d2 d3 _ _ _ _}. + +Module CASE_NAT. +Section case_nat. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Section case_nat_ker. +Variable k : R.-ker X ~> Y. + +Definition case_nat_ m (xn : X * nat) : {measure set Y -> \bar R} := + if xn.2 == m then k xn.1 else mzero. + +Let measurable_fun_case_nat_ m U : measurable U -> + measurable_fun setT (case_nat_ m ^~ U). +Proof. +move=> mU/=; rewrite /case_nat_ (_ : (fun _ => _) = + (fun x => if x.2 == m then k x.1 U else mzero U)) /=; last first. + by apply/funext => -[t b]/=; case: ifPn. +apply: (@measurable_fun_if_pair_nat _ _ _ _ (k ^~ U) (fun=> mzero U)) => //. +exact/measurable_kernel. +Qed. + +#[export] +HB.instance Definition _ m := isKernel.Build _ _ _ _ _ + (case_nat_ m) (measurable_fun_case_nat_ m). +End case_nat_ker. + +Section sfcase_nat. +Variable k : R.-sfker X ~> Y. + +Let sfcase_nat_ m : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> case_nat_ k m x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => case_nat_ (k_ n) m) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + exists r%:num => /= -[x [|n']]; rewrite /case_nat_//= /mzero//. + by case: ifPn => //= ?; rewrite /mzero. + by case: ifPn => // ?; rewrite /= /mzero. +move=> [x b] U mU; rewrite /case_nat_; case: ifPn => hb; first by rewrite hk. +by rewrite /mseries eseries0. +Qed. + +#[export] +HB.instance Definition _ m := @isSFiniteKernel_subdef.Build _ _ _ _ _ + (case_nat_ k m) (sfcase_nat_ m). +End sfcase_nat. + +Section fkcase_nat. +Variable k : R.-fker X ~> Y. + +Let case_nat_uub n : measure_fam_uub (case_nat_ k n). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +exists M%:num => /= -[]; rewrite /case_nat_ => t [|n']/=. + by case: ifPn => //= ?; rewrite /mzero. +by case: ifPn => //= ?; rewrite /mzero. +Qed. + +#[export] +HB.instance Definition _ n := Kernel_isFinite.Build _ _ _ _ _ + (case_nat_ k n) (case_nat_uub n). +End fkcase_nat. + +End case_nat. +End CASE_NAT. + +Import CASE_NAT. + +Section case_nat. +Context d d' (T : measurableType d) (T' : measurableType d') (R : realType). + +Import CASE_NAT. + +Definition case_nat (t : R.-sfker T ~> nat) (u_ : (R.-sfker T ~> T')^nat) + : R.-sfker T ~> T' := + t \; kseries (fun n => case_nat_ (u_ n) n). + +End case_nat. + +Definition measure_sum_display : + measure_display * measure_display -> measure_display. +Proof. exact. Qed. + +Definition g_sigma_imageU d1 d2 + (T1 : measurableType d1) (T2 : measurableType d2) (T : Type) + (f1 : T1 -> T) (f2 : T2 -> T) := + <>. + +Section sum_salgebra_instance. +Context d1 d2 (T1 : measurableType d1) (T2 : measurableType d2). +Let f1 : T1 -> T1 + T2 := @inl T1 T2. +Let f2 : T2 -> T1 + T2 := @inr T1 T2. + +Lemma sum_salgebra_set0 : g_sigma_imageU f1 f2 (set0 : set (T1 + T2)). +Proof. exact: sigma_algebra0. Qed. + +Lemma sum_salgebra_setC A : g_sigma_imageU f1 f2 A -> + g_sigma_imageU f1 f2 (~` A). +Proof. exact: sigma_algebraC. Qed. + +Lemma sum_salgebra_bigcup (F : _^nat) : (forall i, g_sigma_imageU f1 f2 (F i)) -> + g_sigma_imageU f1 f2 (\bigcup_i (F i)). +Proof. exact: sigma_algebra_bigcup. Qed. + +HB.instance Definition sum_salgebra_mixin := + @isMeasurable.Build (measure_sum_display (d1, d2)) + (T1 + T2)%type (g_sigma_imageU f1 f2) + sum_salgebra_set0 sum_salgebra_setC sum_salgebra_bigcup. + +End sum_salgebra_instance. +Reserved Notation "p .-sum" (at level 1, format "p .-sum"). +Reserved Notation "p .-sum.-measurable" + (at level 2, format "p .-sum.-measurable"). +Notation "p .-sum" := (measure_sum_display p) : measure_display_scope. +Notation "p .-sum.-measurable" := + ((p.-sum).-measurable : set (set (_ + _))) : + classical_set_scope. + +#[short(type="measurableCountType")] +HB.structure Definition MeasurableCountable d := + {T of Measurable d T & Countable T }. + +#[short(type="measurableFinType")] +HB.structure Definition MeasurableFinite d := + {T of Measurable d T & Finite T }. + +Definition measurableTypeUnit := unit. + +HB.instance Definition _ := Pointed.on measurableTypeUnit. +HB.instance Definition _ := Finite.on measurableTypeUnit. +HB.instance Definition _ := Measurable.on measurableTypeUnit. +HB.instance Definition _ := MeasurableFinite.on measurableTypeUnit. + +Definition measurableTypeBool := bool. + +HB.instance Definition _ := Pointed.on measurableTypeBool. +HB.instance Definition _ := Finite.on measurableTypeBool. +HB.instance Definition _ := Measurable.on measurableTypeBool. + +Module CASE_SUM. + +Section case_sum'. + +Section kcase_sum'. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Context dA (A : measurableCountType dA) dB (B : measurableCountType dB). +Variables (k1 : A -> R.-sfker X ~> Y) (k2 : B -> R.-sfker X ~> Y). + +Definition case_sum' : X * (A + B) -> {measure set Y -> \bar R} := + fun xab => match xab with + | (x, inl a) => k1 a x + | (x, inr b) => k2 b x + end. + +Let measurable_fun_case_sum' U : measurable U -> + measurable_fun setT (case_sum' ^~ U). +Proof. +rewrite /= => mU. +apply: (measurability _ (ErealGenInftyO.measurableE R)) => //. +move=> /= _ [_ [x ->] <-]; apply: measurableI => //. +rewrite /case_sum'/= (_ : _ @^-1` _ = + (\bigcup_a ([set x1 | k1 a x1 U < x%:E] `*` inl @` [set a])) `|` + (\bigcup_b ([set x1 | k2 b x1 U < x%:E] `*` inr @` [set b]))); last first. + apply/seteqP; split. + - move=> z/=; rewrite in_itv/=. + move: z => [z [a|b]]/= ?. + + by left; exists a => //; split => //=; exists a. + + by right; exists b => //; split => //=; exists b. + - move=> z/=; rewrite in_itv/=. + move: z => [z [a|b]]/= [|]. + + by case => a' _ /= [] /[swap] [] [_ ->] [->]. + + by case => b' _ /= [] b'x [_ ->]. + + by case => b' _ /= [] b'x [_ ->]. + + by case => b' _ /= [] /[swap] [] [_ ->] [->]. +apply: measurableU. +- pose h1 a := [set xub : X * (A + B) | k1 a xub.1 U < x%:E]. + apply: countable_bigcupT_measurable; first exact: countableP. + move=> a; apply: measurableX => //. + rewrite [X in measurable X](_ : _ = ysection (h1 a) (inl a)). + + apply: measurable_ysection. + rewrite -[X in measurable X]setTI. + apply: emeasurable_fun_infty_o => //= => _ /= C mC; rewrite setTI. + have : measurable_fun setT (fun x => k1 a x U) by exact/measurable_kernel. + move=> /(_ measurableT _ mC); rewrite setTI => H. + rewrite [X in measurable X](_ : _ = ((fun x => k1 a x U) @^-1` C) `*` setT)//. + exact: measurableX. + by apply/seteqP; split => [z//=| z/= []]. + + by rewrite ysectionE. +- pose h2 a := [set xub : X * (A + B)| k2 a xub.1 U < x%:E]. + apply: countable_bigcupT_measurable; first exact: countableP. + move=> b; apply: measurableX => //. + rewrite [X in measurable X](_ : _ = ysection (h2 b) (inr b))//. + + apply: measurable_ysection. + rewrite -[X in measurable X]setTI. + apply: emeasurable_fun_infty_o => //= _ /= C mC; rewrite setTI. + have : measurable_fun setT (fun x => k2 b x U) by exact/measurable_kernel. + move=> /(_ measurableT _ mC); rewrite setTI => H. + rewrite [X in measurable X](_ : _ = ((fun x => k2 b x U) @^-1` C) `*` setT)//. + exact: measurableX. + by apply/seteqP; split => [z //=|z/= []]. + + by rewrite ysectionE. +Qed. + +#[export] +HB.instance Definition _ := isKernel.Build _ _ _ _ _ + case_sum' measurable_fun_case_sum'. +End kcase_sum'. + +Section sfkcase_sum'. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Context dA (A : measurableFinType dA) dB (B : measurableFinType dB). +Variables (k1 : A -> R.-sfker X ~> Y) (k2 : B-> R.-sfker X ~> Y). + +Let sfinite_case_sum' : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> case_sum' k1 k2 x U = mseries (k_ ^~ x) 0 U. +Proof. +rewrite /=. +set f1 : A -> (R.-fker _ ~> _)^nat := + fun ab : A => sval (cid (sfinite_kernel (k1 ab))). +set Hf1 := fun ab : A => svalP (cid (sfinite_kernel (k1 ab))). +rewrite /= in Hf1. +set f2 : B -> (R.-fker _ ~> _)^nat := + fun ab : B => sval (cid (sfinite_kernel (k2 ab))). +set Hf2 := fun ab : B => svalP (cid (sfinite_kernel (k2 ab))). +rewrite /= in Hf2. +exists (fun n => case_sum' (f1 ^~ n) (f2 ^~ n)). + move=> n /=. + pose f1' a := sval (cid (measure_uub (f1 a n))). + pose f2' b := sval (cid (measure_uub (f2 b n))). + red. + exists (maxr (\big[Order.max/0%R]_a f1' a) (\big[Order.max/0%R]_b (f2' b)))%R. + move=> /= [x [a|b]]. + - have [bnd Hbnd] := measure_uub (f1 a n). + rewrite EFin_max lt_max; apply/orP; left. + rewrite /case_sum' -EFin_bigmax. + apply: lt_le_trans; last exact: le_bigmax_cond. + by rewrite /f1'; case: cid => /=. + - have [bnd Hbnd] := measure_uub (f2 b n). + rewrite EFin_max lt_max; apply/orP; right. + rewrite /case_sum' -EFin_bigmax. + apply: lt_le_trans; last exact: le_bigmax_cond. + by rewrite /f2'; case: cid => /=C. +move=> [x [a|b]] U mU/=-. +- by rewrite (Hf1 a x _ mU). +- by rewrite (Hf2 b x _ mU). +Qed. + +#[export] +HB.instance Definition _ := @isSFiniteKernel_subdef.Build _ _ _ _ _ + (case_sum' k1 k2) (sfinite_case_sum'). +End sfkcase_sum'. + +End case_sum'. + +Section case_sum. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Context dA (A : measurableFinType dA) dB (B : measurableFinType dB). + +Definition case_sum (f : R.-sfker X ~> (A + B)%type) + (k1 : A -> R.-sfker X ~> Y) (k2 : B -> R.-sfker X ~> Y) : R.-sfker X ~> Y := + f \; case_sum' k1 k2. + +End case_sum. + +End CASE_SUM. + +(* counting measure as a kernel *) +Section kcounting. +Context d (G : measurableType d) (R : realType). + +Definition kcounting : G -> {measure set nat -> \bar R} := fun=> counting. + +Let mkcounting U : measurable U -> measurable_fun setT (kcounting ^~ U). +Proof. by []. Qed. + +HB.instance Definition _ := isKernel.Build _ _ _ _ _ kcounting mkcounting. + +Let sfkcounting : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> kcounting x U = mseries (k_ ^~ x) 0 U. +Proof. +exists (fun n => [the R.-fker _ ~> _ of + @kdirac _ _ G nat R _ (@measurable_cst _ _ _ _ setT n)]). + by move=> n /=; exact: measure_uub. +by move=> g U mU; rewrite /kcounting/= counting_dirac. +Qed. + +HB.instance Definition _ := + isSFiniteKernel_subdef.Build _ _ _ _ R kcounting sfkcounting. + +End kcounting. + +(* formalization of the iterate construct [Section 4.2, Staton ESOP 2017] *) +Section iterate. +Context d {G : measurableType d} {R : realType}. +Context dA (A : measurableFinType dA) dB (B : measurableFinType dB). + +Import CASE_SUM. + +(* formalization of iterate^n + Gamma |-p iterate^n t from x = u : B *) +Variables (t : R.-sfker (G * A) ~> (A + B)%type) + (u : G -> A) (mu : measurable_fun setT u). + +Fixpoint iterate_ n : R.-sfker G ~> B := + match n with + | 0%N => case_sum (letin (ret mu) t) + (fun u' => fail) + (fun v => ret (measurable_cst v)) + | m.+1 => case_sum (letin (ret mu) t) + (fun u' => iterate_ m) + (fun v => fail) + end. + +(* formalization of iterate + Gamma, x : A |-p t : A + B Gamma |-d u : A +----------------------------------------------- + Gamma |-p iterate t from x = u : B *) +Definition iterate : R.-sfker G ~> B := case_nat (kcounting R) iterate_. + +End iterate. + +Section iterate_unit. + +Let unit := measurableTypeUnit. +Let bool := measurableTypeBool. +Context d {G : measurableType d} {R : realType}. +Context dB (B : measurableFinType dB). + +Section iterate_elim. +Variables (t : R.-sfker (G * unit) ~> (unit + B)%type) + (u : G -> unit) (mu : measurable_fun setT u). +Variables (r : R) (tlE : forall gamma, t (gamma, tt) [set inl tt] = r%:E). + +Variables (gamma : G) (X : set B) (q : R). +Hypothesis trE : t (gamma, tt) [set inr x | x in X] = q%:E. + +Let q_ge0 : (0 <= q)%R. Proof. by rewrite -lee_fin -trE measure_ge0. Qed. +Let r_ge0 : (0 <= r)%R. +Proof. by rewrite -lee_fin -(tlE gamma) measure_ge0. Qed. + +Lemma iterate_E n : iterate_ t mu n gamma X = (geometric q r n)%:E. +Proof. +elim: n => [|n IHn] //=; + rewrite /kcomp; rewrite integral_kcomp//=; + rewrite /= integral_dirac//= ?diracT ?mul1e ?expr0 ?exprS ?mulr1. + rewrite (eq_integral (EFin \o \1_[set inr x | x in X]))//=; last first. + move=> [a' _|b _]//=; last first. + by rewrite diracE indicE/= (mem_image inr_inj). + rewrite /kcomp/= indicE /= ge0_integral_mscale//= normr0 mul0e. + by rewrite [_ \in _](introF idP)// inE /= => -[]. + by rewrite ?unitE integral_indic//= setIT. +pose g : unit + B -> R^o := (geometric q r n \o* \1_[set inl tt])%R. +rewrite (eq_integral (EFin \o g))//=; last first. + move=> [[] _|b _]//=. + by rewrite /g/= indicE//= in_set1 inE eqxx mul1r. + rewrite /kcomp/= ge0_integral_mscale//= normr0 mul0e. + by rewrite /g /= indicE//= in_set1 inE mul0r. +rewrite /g /=; under eq_integral do rewrite EFinM. +rewrite integralZr//=; last first. + apply/integrableP; split=> //. + under eq_integral => x. + rewrite gee0_abs//=. + over. + by rewrite /= integral_indic// setIT [u gamma]unitE tlE ltey. +by rewrite integral_indic//= [u gamma]unitE setIT tlE -EFinM mulrCA. +Qed. + +Hypothesis r_lt1 : (r < 1)%R. + +Lemma iterateE : iterate t mu gamma X = (q / (1 - r))%:E. +Proof. +rewrite /= /kcomp/= /case_nat_/= /mseries. +under eq_integral => n _. + under (@congr_lim _ _ _ \o @eq_fun _ _ _ _) => k. + under eq_bigr do rewrite fun_if/= (fun_if (@^~ _))/mzero eq_sym. + rewrite -big_mkcond/= big_nat1_eq. + over. +over. +rewrite /= (eq_integral (EFin \o geometric q r))//=; last first. + move=> k _; apply/lim_near_cst => //; rewrite iterate_E ?r_ge0 ?r_lt1//. + by near do rewrite ifT//. +have cvgg: series (geometric q r) x @[x --> \oo] --> (q / (1 - r))%R. + by apply/cvg_geometric_series; rewrite ger0_norm ?r_lt1//. +have limgg := cvg_lim (@Rhausdorff R) cvgg. +have sumgE : \big[+%R/0%R]_(0 <= k x; rewrite inE trueE. +rewrite -(@nneseries_esum _ _ predT)//=. +under eq_eseriesr do rewrite ger0_norm// ?geometric_ge0//. +by rewrite sumgE ltey. +Unshelve. all: end_near. Qed. + +End iterate_elim. + +Import CASE_SUM. + +Variables (t : R.-pker (G * unit) ~> (unit + B)%type) + (u : G -> unit) (mu : measurable_fun setT u). +Variables (r : R) (r_lt1 : (r < 1)%R). +Hypothesis (tlE : forall gamma, t (gamma, tt) [set inl tt] = r%:E). + +Let trE gamma X : t (gamma, tt) [set inr x | x in X] \in fin_num. +Proof. +apply/fin_numPlt; rewrite (@lt_le_trans _ _ 0)//=. +rewrite (@le_lt_trans _ _ 1)//= ?ltey//. +rewrite -( @prob_kernel _ _ _ _ _ t (gamma, tt) ). +by apply/le_measure => //=; rewrite inE//=. +Qed. + +Lemma iterate_normalize p : + iterate t mu = knormalize (case_sum (letin (ret mu) t) + (fun u' => fail) + (fun v => ret (measurable_cst v))) p. +Proof. +apply/eq_sfkernel => gamma U. +have /EFin_fin_numP[q trE'] := trE gamma U. +rewrite (iterateE mu tlE trE')//; symmetry. +rewrite /= /mnormalize/= (fun_if (@^~ U))/=. +set m := kcomp _ _ _. +have mE V : m V = t (gamma, tt) [set inr x | x in V]. + rewrite /m/= /kcomp/= integral_kcomp//= integral_dirac//= diracT mul1e. + rewrite (eq_integral (EFin \o \1_[set inr x | x in V])). + by rewrite integral_indic ?setIT ?unitE. + move=> [x|x] xV /=; rewrite indicE. + rewrite ?inl_in_set_inr /kcomp/=. + by rewrite ge0_integral_mscale//= ?normr0 mul0e. + by rewrite inr_in_set_inr// indicE. +rewrite !mE trE'. +suff -> : t (gamma, tt) (range inr) = 1 - t (gamma, tt) [set inl tt]. + by rewrite tlE -EFinB/= orbF eqe subr_eq0 eq_sym lt_eqF. +rewrite -( @prob_kernel _ _ _ _ _ t (gamma, tt) ). +have -> : [set: unit + B] = [set inl tt] `|` (range inr). + symmetry; apply/eq_set => -[[]|b]//=; apply/propT; first by left. + by right; exists b. +rewrite measureU//=; first by rewrite addeAC subee ?add0e// ?tlE//. +by apply/eq_set => -[[]|b]//=; apply/propF; case=> []// _ []. +Qed. + +End iterate_unit. + +Section lift_neq. +Context {R : realType} d (G : measurableType d). +Variables (f : G -> bool) (g : G -> bool). + +Definition flift_neq : G -> bool := fun x' => f x' != g x'. + +Hypotheses (mf : measurable_fun setT f) (mg : measurable_fun setT g). + +(* see also emeasurable_fun_neq *) +Lemma measurable_fun_flift_neq : measurable_fun setT flift_neq. +Proof. +apply: (@measurable_fun_bool _ _ _ _ true). +rewrite setTI. +rewrite /flift_neq /= (_ : _ @^-1` _ = ([set x | f x] `&` [set x | ~~ g x]) `|` + ([set x | ~~ f x] `&` [set x | g x])). + apply: measurableU; apply: measurableI. + - by rewrite -[X in measurable X]setTI; exact: mf. + - rewrite [X in measurable X](_ : _ = ~` [set x | g x]); last first. + by apply/seteqP; split => x /= /negP. + by apply: measurableC; rewrite -[X in measurable X]setTI; exact: mg. + - rewrite [X in measurable X](_ : _ = ~` [set x | f x]); last first. + by apply/seteqP; split => x /= /negP. + by apply: measurableC; rewrite -[X in measurable X]setTI; exact: mf. + - by rewrite -[X in measurable X]setTI; exact: mg. +by apply/seteqP; split => x /=; move: (f x) (g x) => [|] [|]//=; intuition. +Qed. + +Definition lift_neq : R.-sfker G ~> bool := ret measurable_fun_flift_neq. + +End lift_neq. + +Section insn1_lemmas. +Import Notations. +Context d (T : measurableType d) (R : realType). + +Let kcomp_scoreE d1 d2 (T1 : measurableType d1) (T2 : measurableType d2) + (g : R.-sfker [the measurableType _ of (T1 * unit)%type] ~> T2) + f (mf : measurable_fun setT f) r U : + (score mf \; g) r U = `|f r|%:E * g (r, tt) U. +Proof. +rewrite /= /kcomp /kscore /= ge0_integral_mscale//=. +by rewrite integral_dirac// diracT mul1e. +Qed. + +Lemma scoreE d' (T' : measurableType d') (x : T * T') (U : set T') (f : R -> R) + (r : R) (r0 : (0 <= r)%R) + (f0 : (forall r, 0 <= r -> 0 <= f r)%R) (mf : measurable_fun setT f) : + score (measurableT_comp mf (@macc1of2 _ _ _ _)) + (x, r) (curry (snd \o fst) x @^-1` U) = + (f r)%:E * \d_x.2 U. +Proof. +by rewrite /score/= /mscale/= ger0_norm//= f0. +Qed. + +Lemma score_score (f : R -> R) (g : R * unit -> R) + (mf : measurable_fun [set: R] f) + (mg : measurable_fun [set: R * unit] g) : + letin (score mf) (score mg) = + score (measurable_funM mf (measurableT_comp mg (pair2_measurable tt))). +Proof. +apply/eq_sfkernel => x U. +rewrite {1}/letin; unlock. +by rewrite kcomp_scoreE/= /mscale/= diracE normrM muleA EFinM. +Qed. + +(* hard constraints to express score below 1 *) +Lemma score_fail (r : R) : (0 <= r <= 1)%R -> + score (kr r) = + letin (sample_cst (bernoulli_prob r) : R.-pker T ~> _) + (ite (@macc1of2 _ _ _ _) (ret ktt) fail). +Proof. +move=> /andP[r0 r1]; apply/eq_sfkernel => x U. +rewrite letinE/= /sample; unlock. +rewrite /mscale/= ger0_norm//. +by rewrite integral_bernoulli_prob ?r0//= 2!iteE//= failE mule0 adde0. +Qed. + +End insn1_lemmas. + +Section letin_ite. +Context d d2 d3 (T : measurableType d) (T2 : measurableType d2) + (Z : measurableType d3) (R : realType). +Variables (k1 k2 : R.-sfker T ~> Z) + (u : R.-sfker [the measurableType _ of (T * Z)%type] ~> T2) + (f : T -> bool) (mf : measurable_fun setT f) + (t : T) (U : set T2). + +Lemma letin_iteT : f t -> letin (ite mf k1 k2) u t U = letin k1 u t U. +Proof. +move=> ftT; rewrite !letinE/=; apply: eq_measure_integral => V mV _. +by rewrite iteE ftT. +Qed. + +Lemma letin_iteF : ~~ f t -> letin (ite mf k1 k2) u t U = letin k2 u t U. +Proof. +move=> ftF; rewrite !letinE/=; apply: eq_measure_integral => V mV _. +by rewrite iteE (negbTE ftF). +Qed. + +End letin_ite. + +(* associativity of let [Section 4.2, Staton ESOP 2017] *) +Section letinA. +Context d d' d1 d2 d3 (X : measurableType d) (Y : measurableType d') + (T1 : measurableType d1) (T2 : measurableType d2) (T3 : measurableType d3) + (R : realType). +Import Notations. +Variables (t : R.-sfker X ~> T1) + (u : R.-sfker (X * T1) ~> T2) + (v : R.-sfker (X * T2) ~> Y) + (v' : R.-sfker (X * T1 * T2) ~> Y) + (vv' : forall y, v =1 fun xz => v' (xz.1, y, xz.2)). + +Lemma letinA x A : measurable A -> + letin t (letin u v') x A + = + (letin (letin t u) v) x A. +Proof. +move=> mA. +rewrite !letinE. +under eq_integral do rewrite letinE. +rewrite integral_kcomp; [|by []|]. +- apply: eq_integral => y _. + apply: eq_integral => z _. + by rewrite (vv' y). +- exact: (measurableT_comp (measurable_kernel v _ mA)). +Qed. + +End letinA. + +(* commutativity of let [Section 4.2, Staton ESOP 2017] *) +Section letinC. +Context d d1 d' (X : measurableType d) (Y : measurableType d1) + (Z : measurableType d') (R : realType). + +Import Notations. + +Variables (t : R.-sfker Z ~> X) + (t' : R.-sfker [the measurableType _ of (Z * Y)%type] ~> X) + (tt' : forall y, t =1 fun z => t' (z, y)) + (u : R.-sfker Z ~> Y) + (u' : R.-sfker [the measurableType _ of (Z * X)%type] ~> Y) + (uu' : forall x, u =1 fun z => u' (z, x)). + +Definition T z : set X -> \bar R := t z. +Let T0 z : (T z) set0 = 0. Proof. by []. Qed. +Let T_ge0 z x : 0 <= (T z) x. Proof. by []. Qed. +Let T_semi_sigma_additive z : semi_sigma_additive (T z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ X R (T z) (T0 z) (T_ge0 z) + (@T_semi_sigma_additive z). + +Let sfinT z : sfinite_measure (T z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @isSFinite.Build _ X R (T z) (sfinT z). + +Definition U z : set Y -> \bar R := u z. +Let U0 z : (U z) set0 = 0. Proof. by []. Qed. +Let U_ge0 z x : 0 <= (U z) x. Proof. by []. Qed. +Let U_semi_sigma_additive z : semi_sigma_additive (U z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ Y R (U z) (U0 z) (U_ge0 z) + (@U_semi_sigma_additive z). + +Let sfinU z : sfinite_measure (U z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @isSFinite.Build _ Y R (U z) (sfinU z). + +Lemma letinC z A : measurable A -> + letin t + (letin u' + (ret (measurable_fun_pair macc1of3 macc2of3))) z A = + letin u + (letin t' + (ret (measurable_fun_pair macc2of3 macc1of3))) z A. +Proof. +move=> mA. +rewrite !letinE. +under eq_integral. + move=> x _. + rewrite letinE -uu'. + under eq_integral do rewrite retE /=. + over. +rewrite (sfinite_Fubini + [the {sfinite_measure set X -> \bar R} of T z] + [the {sfinite_measure set Y -> \bar R} of U z] + (fun x => \d_(x.1, x.2) A ))//; last first. + apply/measurable_EFinP => /=; rewrite (_ : (fun x => _) = mindic R mA)//. + by apply/funext => -[]. +rewrite /=. +apply: eq_integral => y _. +by rewrite letinE/= -tt'; apply: eq_integral => // x _; rewrite retE. +Qed. + +End letinC. + +(* examples *) + +Lemma letin_sample_bernoulli_prob d d' (T : measurableType d) + (T' : measurableType d') (R : realType) (r : R) + (u : R.-sfker [the measurableType _ of (T * bool)%type] ~> T') x y : + (0 <= r <= 1)%R -> + letin (sample_cst (bernoulli_prob r)) u x y = + r%:E * u (x, true) y + (`1- r)%:E * u (x, false) y. +Proof. by move=> r01; rewrite letinE/= integral_bernoulli_prob. Qed. + +Section sample_and_return. +Import Notations. +Context d (T : measurableType d) (R : realType). + +Definition sample_and_return : R.-sfker T ~> _ := + letin + (sample_cst (bernoulli_prob (2 / 7))) (* T -> B *) + (ret macc1of2) (* T * B -> B *). + +Lemma sample_and_returnE t U : sample_and_return t U = + (2 / 7%:R)%:E * \d_true U + (5%:R / 7%:R)%:E * \d_false U. +Proof. +rewrite /sample_and_return letin_sample_bernoulli_prob; last lra. +by rewrite !retE onem27. +Qed. + +End sample_and_return. + +Section sample_and_branch. +Import Notations. +Context d (T : measurableType d) (R : realType). + +(* let x = sample (bernoulli (2/7)) in + let r = case x of {(1, _) => return (k3()), (2, _) => return (k10())} in + return r *) +Definition sample_and_branch : R.-sfker T ~> _ := + letin + (sample_cst (bernoulli_prob (2 / 7))) (* T -> B *) + (ite macc1of2 (ret (@k3 _ _ R)) (ret k10)). + +Lemma sample_and_branchE t U : sample_and_branch t U = + (2 / 7)%:E * \d_(3%R : R) U + (5 / 7)%:E * \d_(10%R : R) U. +Proof. +rewrite /sample_and_branch letin_sample_bernoulli_prob/=; last lra. +by rewrite !iteE/= onem27. +Qed. + +End sample_and_branch. + +Section bernoulli_and. +Context d (T : measurableType d) (R : realType). +Import Notations. + +Definition bernoulli_prob_and : R.-sfker T ~> mbool := + (letin (sample_cst (bernoulli_prob (1 / 2))) + (letin (sample_cst (bernoulli_prob (1 / 2))) + (ret (measurable_and macc1of3 macc2of3)))). + +Lemma bernoulli_prob_andE t U : + bernoulli_prob_and t U = sample_cst (bernoulli_prob (1 / 4)) t U. +Proof. +rewrite /bernoulli_prob_and. +rewrite letin_sample_bernoulli_prob; last lra. +rewrite (letin_sample_bernoulli_prob (r := 1 / 2)); last lra. +rewrite (letin_sample_bernoulli_prob (r := 1 / 2)); last lra. +rewrite muleDr//= -muleDl//. +rewrite !muleA -addeA -muleDl// -!EFinM !onem1S/= -splitr mulr1. +have -> : (1 / 2 * (1 / 2) = 1 / 4%:R :> R)%R by rewrite mulf_div mulr1// -natrM. +rewrite [in RHS](_ : 1 / 4 = (1 / 4)%:nng%:num)%R//. +rewrite bernoulli_probE/=; last lra. +rewrite -!EFinM; congr( _ + (_ * _)%:E). +by rewrite /onem; lra. +Qed. + +End bernoulli_and. + +Section staton_bus. +Import Notations. +Context d (T : measurableType d) (R : realType) (h : R -> R). +Hypothesis mh : measurable_fun setT h. +Definition kstaton_bus : R.-sfker T ~> mbool := + letin (sample_cst (bernoulli_prob (2 / 7))) + (letin + (letin (ite macc1of2 (ret k3) (ret k10)) + (score (measurableT_comp mh macc2of3))) + (ret macc1of3)). + +Definition staton_bus := normalize kstaton_bus. + +End staton_bus. + +(* let x = sample (bernoulli (2/7)) in + let r = case x of {(1, _) => return (k3()), (2, _) => return (k10())} in + let _ = score (1/4! r^4 e^-r) in + return x *) +Section staton_bus_poisson. +Import Notations. +Context d (T : measurableType d) (R : realType). +Let poisson4 r := @poisson_pmf R r 4%N. +Let mpoisson4 := @measurable_poisson_pmf R setT 4%N measurableT. + +Definition kstaton_bus_poisson : R.-sfker R ~> mbool := + kstaton_bus _ mpoisson4. + +Let kstaton_bus_poissonE t U : kstaton_bus_poisson t U = + (2 / 7)%:E * (poisson4 3)%:E * \d_true U + + (5 / 7)%:E * (poisson4 10)%:E * \d_false U. +Proof. +rewrite /kstaton_bus_poisson /kstaton_bus. +rewrite letin_sample_bernoulli_prob; last lra. +rewrite -!muleA; congr (_ * _ + _ * _). +- rewrite letin_kret//. + rewrite letin_iteT//. + rewrite letin_retk//. + by rewrite scoreE//= => r r0; exact: poisson_pmf_ge0. +- by rewrite onem27. + rewrite letin_kret//. + rewrite letin_iteF//. + rewrite letin_retk//. + by rewrite scoreE//= => r r0; exact: poisson_pmf_ge0. +Qed. + +(* true -> 2/7 * 0.168 = 2/7 * 3^4 e^-3 / 4! *) +(* false -> 5/7 * 0.019 = 5/7 * 10^4 e^-10 / 4! *) + +Lemma staton_busE P (t : R) U : + let N := ((2 / 7) * poisson4 3 + + (5 / 7) * poisson4 10)%R in + staton_bus mpoisson4 P t U = + ((2 / 7)%:E * (poisson4 3)%:E * \d_true U + + (5 / 7)%:E * (poisson4 10)%:E * \d_false U) * N^-1%:E. +Proof. +rewrite /staton_bus normalizeE !kstaton_bus_poissonE !diracT !mule1 ifF //. +apply/negbTE; rewrite gt_eqF// lte_fin. +by rewrite addr_gt0// mulr_gt0//= ?divr_gt0// ?ltr0n// poisson_pmf_gt0// ltr0n. +Qed. + +End staton_bus_poisson. + +(* let x = sample (bernoulli (2/7)) in + let r = case x of {(1, _) => return (k3()), (2, _) => return (k10())} in + let _ = score (r e^-(15/60 r)) in + return x *) +Section staton_bus_exponential. +Import Notations. +Context d (T : measurableType d) (R : realType). +Let exp1560 := @exponential_pdf R (ratr (15%:Q / 60%:Q)). +Let mexp1560 := @measurable_exponential_pdf R (ratr (15%:Q / 60%:Q)). + +(* 15/60 = 0.25 *) + +Definition kstaton_bus_exponential : R.-sfker R ~> mbool := + kstaton_bus _ mexp1560. + +Let kstaton_bus_exponentialE t U : kstaton_bus_exponential t U = + (2 / 7)%:E * (exp1560 3)%:E * \d_true U + + (5 / 7)%:E * (exp1560 10)%:E * \d_false U. +Proof. +rewrite /kstaton_bus. +rewrite letin_sample_bernoulli_prob; last lra. +rewrite -!muleA; congr (_ * _ + _ * _). +- rewrite letin_kret//. + rewrite letin_iteT//. + rewrite letin_retk//. + rewrite scoreE//= => r r0; apply: exponential_pdf_ge0 => //. + by rewrite ler0q; lra. +- by rewrite onem27. + rewrite letin_kret//. + rewrite letin_iteF//. + rewrite letin_retk//. + rewrite scoreE//= => r r0; apply: exponential_pdf_ge0. + by rewrite ler0q; lra. +Qed. + +(* true -> 5/7 * 0.019 = 5/7 * 10^4 e^-10 / 4! *) +(* false -> 2/7 * 0.168 = 2/7 * 3^4 e^-3 / 4! *) + +Lemma staton_bus_exponentialE P (t : R) U : + let N := ((2 / 7) * exp1560 3 + + (5 / 7) * exp1560 10)%R in + staton_bus mexp1560 P t U = + ((2 / 7)%:E * (exp1560 3)%:E * \d_true U + + (5 / 7)%:E * (exp1560 10)%:E * \d_false U) * N^-1%:E. +Proof. +rewrite /staton_bus. +rewrite normalizeE /= !kstaton_bus_exponentialE !diracT !mule1 ifF //. +apply/negbTE; rewrite gt_eqF// lte_fin. +by rewrite addr_gt0// mulr_gt0//= ?divr_gt0// ?ltr0n//; + rewrite exponential_pdf_gt0 ?ltr0n// ltr0q; lra. +Qed. + +End staton_bus_exponential. + +Section von_neumann_trick. +Context d {T : measurableType d} {R : realType}. + +Definition minltt {d1 d2} {T1 : measurableType d1} {T2 : measurableType d2} := + @measurable_cst _ _ T1 _ setT (@inl _ T2 tt). + +Definition finrb d1 d2 (T1 : measurableType d1) (T2 : measurableType d2) : + T1 * bool -> T2 + bool := fun t1b => inr t1b.2. + +Lemma minrb {d1 d2} {T1 : measurableType d1} {T2 : measurableType d2} : + measurable_fun setT (@finrb _ _ T1 T2). +Proof. exact: measurableT_comp. Qed. + +Variable (D : pprobability bool R). (* biased coin *) +Let unit := measurableTypeUnit. +Let bool := measurableTypeBool. + +Definition trick : R.-sfker (T * unit) ~> (unit + bool)%type := + letin (sample_cst D) + (letin (sample_cst D) + (letin (lift_neq macc1of3 macc2of3) + (ite macc3of4 + (letin (ret macc1of4) (ret minrb)) + (ret minltt)))). + +HB.instance Definition _ := SFiniteKernel.on trick. +HB.instance Definition _ x := Measure.on (trick x). + +Definition kvon_neumann_trick : _ -> _ := + (@iterate _ _ R _ unit _ bool trick _ ktt). +Definition von_neumann_trick x : _ -> _ := kvon_neumann_trick x. + +HB.instance Definition _ := SFiniteKernel.on kvon_neumann_trick. +HB.instance Definition _ x := Measure.on (von_neumann_trick x). + +Section von_neumann_trick_proof. + +Let p : R := fine (D [set true]). +Let q : R := p * (1 - p). +Let r : R := p ^+ 2 + (1 - p) ^+ 2. + +Let Dtrue : D [set true] = p%:E. +Proof. by rewrite fineK//= fin_num_measure. Qed. + +Lemma trickE gamma X : trick gamma X = + (r *+ (inl tt \in X) + + q *+ ((inr true \in X) + (inr false \in X)))%:E. +Proof. +have Dbernoulli : D =1 bernoulli_prob p by exact/eq_bernoulli_prob/Dtrue. +have p_itv01 : (0 <= p <= 1)%R. + by rewrite -2!lee_fin -Dtrue?measure_ge0 ?probability_le1. +pose eqbern := eq_measure_integral _ (fun x _ _ => Dbernoulli x). +rewrite /trick/= /kcomp. +do 2?rewrite ?eqbern ?integral_bernoulli_prob//= /kcomp/=. +rewrite !integral_dirac ?diracT//= ?mul1e. +rewrite !iteE//= ?diracE/= /kcomp/=. +rewrite !integral_dirac /acc1of4/= ?diracT ?diracE ?mul1e//. +rewrite /finrb /acc1of4/= -?(EFinB, EFinN, EFinM, EFinD) /q /r /onem. +by congr (_)%:E; do 3!move: (_ \in _) => ? /=; ring. +Qed. + +Lemma trick_prob_kernelT gamma : trick gamma setT = 1. +Proof. +by rewrite trickE !mem_setT mulr2n mulr1n /r /q; congr (_)%:E; ring. +Qed. + +HB.instance Definition _ gamma := Measure_isProbability.Build _ _ _ + (trick gamma) (trick_prob_kernelT gamma). +HB.instance Definition _ := Kernel_isProbability.Build _ _ _ _ _ + trick trick_prob_kernelT. + +Hypothesis D_nontrivial : 0 < D [set true] < 1. + +Let p_gt0 : (0 < p)%R. +Proof. by rewrite -lte_fin -Dtrue; case/andP : D_nontrivial. Qed. + +Let p_lt1 : (p < 1)%R. +Proof. by rewrite -lte_fin -Dtrue; case/andP : D_nontrivial. Qed. + +Let p'_gt0 : (0 < 1 - p)%R. Proof. by rewrite subr_gt0. Qed. + +Let r_lt1 : (r < 1)%R. +Proof. +rewrite /r -subr_gt0 [ltRHS](_ : _ = 2 * p * (1 - p))%R; last by ring. +by rewrite !mulr_gt0. +Qed. + +Lemma von_neumann_trick_prob_kernel gamma b : + kvon_neumann_trick gamma [set b] = 2^-1%:E. +Proof. +rewrite [LHS](@iterateE _ _ _ _ _ _ _ _ r _ _ _ q)//=. +- rewrite /r /q; congr (_)%:E. + suff: (1 - ((p ^+ 2)%R + ((1 - p) ^+ 2)%R)%E)%R != 0%R by move=> *; field. + rewrite [X in X != _](_ : _ = 2 * (p * (1 - p)))%R; last by ring. + by rewrite mulf_eq0 ?pnatr_eq0/= mulf_neq0// gt_eqF ?p_gt0 ?p'_gt0. +- by move=> gamma'; rewrite trickE//= ?in_set1 ?inE//= addr0. +- rewrite trickE/= ?inl_in_set_inr ?inr_in_set_inr// add0r !in_set1 !inE. + by case: b. +Qed. + +Lemma von_neumann_trick_prob_kernelT gamma : + von_neumann_trick gamma [set: bool] = 1. +Proof. +rewrite setT_bool measureU//=; last by rewrite disjoints_subset => -[]. +rewrite !von_neumann_trick_prob_kernel -EFinD. +by have := splitr (1 : R); rewrite mul1r => <-. +Qed. + +HB.instance Definition _ gamma := Measure.on (von_neumann_trick gamma). +HB.instance Definition _ gamma := Measure_isProbability.Build _ _ _ + (von_neumann_trick gamma) (von_neumann_trick_prob_kernelT gamma). +HB.instance Definition _ := Kernel_isProbability.Build _ _ _ _ _ + kvon_neumann_trick von_neumann_trick_prob_kernelT. + +Theorem von_neumann_trickP gamma : von_neumann_trick gamma =1 bernoulli_prob 2^-1. +Proof. by apply: eq_bernoulli_prob; rewrite von_neumann_trick_prob_kernel. Qed. + +End von_neumann_trick_proof. + +End von_neumann_trick. + +(**md + letin' variants +*) + +Section mswap. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). +Variable k : R.-ker Y * X ~> Z. + +Definition mswap xy U := k (swap xy) U. + +Let mswap0 xy : mswap xy set0 = 0. +Proof. done. Qed. + +Let mswap_ge0 x U : 0 <= mswap x U. +Proof. done. Qed. + +Let mswap_sigma_additive x : semi_sigma_additive (mswap x). +Proof. exact: measure_semi_sigma_additive. Qed. + +HB.instance Definition _ x := isMeasure.Build _ _ R + (mswap x) (mswap0 x) (mswap_ge0 x) (@mswap_sigma_additive x). + +Definition mkswap : _ -> {measure set Z -> \bar R} := + fun x => mswap x. + +Let measurable_fun_kswap U : + measurable U -> measurable_fun setT (mkswap ^~ U). +Proof. +move=> mU. +rewrite [X in measurable_fun _ X](_ : _ = k ^~ U \o @swap _ _)//. +apply measurableT_comp => //=; first exact: measurable_kernel. +exact: measurable_swap. +Qed. + +HB.instance Definition _ := isKernel.Build _ _ + (X * Y)%type Z R mkswap measurable_fun_kswap. + +End mswap. + +Section mswap_sfinite_kernel. +Variables (d d' d3 : _) (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). +Variable k : R.-sfker Y * X ~> Z. + +Let mkswap_sfinite : + exists2 k_ : (R.-ker X * Y ~> Z)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> mkswap k x U = kseries k_ x U. +Proof. +have [k_ /= kE] := sfinite_kernel k. +exists (fun n => mkswap (k_ n)). + move=> n. + have /measure_fam_uubP[M hM] := measure_uub (k_ n). + by exists M%:num => x/=; exact: hM. +move=> xy U mU. +by rewrite /mswap/= kE. +Qed. + +HB.instance Definition _ := + isSFiniteKernel_subdef.Build _ _ _ Z R (mkswap k) mkswap_sfinite. + +End mswap_sfinite_kernel. + +Section kswap_finite_kernel_finite. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType) + (k : R.-fker Y * X ~> Z). + +Let mkswap_finite : measure_fam_uub (mkswap k). +Proof. +have /measure_fam_uubP[r hr] := measure_uub k. +apply/measure_fam_uubP; exists (PosNum [gt0 of r%:num%R]) => x /=. +exact: hr. +Qed. + +HB.instance Definition _ := + Kernel_isFinite.Build _ _ _ Z R (mkswap k) mkswap_finite. + +End kswap_finite_kernel_finite. + +Reserved Notation "f .; g" (at level 60, right associativity, + format "f .; '/ ' g"). + +Notation "l .; k" := (mkcomp l (mkswap k)) : ereal_scope. + +Section letin'. +Variables (d d' d3 : _) (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Definition letin' (l : R.-sfker X ~> Y) (k : R.-sfker Y * X ~> Z) := + locked [the R.-sfker X ~> Z of l .; k]. + +Lemma letin'E (l : R.-sfker X ~> Y) (k : R.-sfker Y * X ~> Z) x U : + letin' l k x U = \int[l x]_y k (y, x) U. +Proof. by rewrite /letin'; unlock. Qed. + +Lemma letin'_letin (l : R.-sfker X ~> Y) (k : R.-sfker Y * X ~> Z) : + letin' l k = letin l (mkswap k). +Proof. by rewrite /letin'; unlock. Qed. + +End letin'. + +Section letin'C. +Import Notations. +Context d d1 d' (X : measurableType d) (Y : measurableType d1) + (Z : measurableType d') (R : realType). +Variables (t : R.-sfker Z ~> X) + (u' : R.-sfker X * Z ~> Y) + (u : R.-sfker Z ~> Y) + (t' : R.-sfker Y * Z ~> X) + (tt' : forall y, t =1 fun z => t' (y, z)) + (uu' : forall x, u =1 fun z => u' (x, z)). + +Definition T' z : set X -> \bar R := t z. +Let T0 z : (T' z) set0 = 0. Proof. by []. Qed. +Let T_ge0 z x : 0 <= (T' z) x. Proof. by []. Qed. +Let T_semi_sigma_additive z : semi_sigma_additive (T' z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ X R (T' z) (T0 z) (T_ge0 z) + (@T_semi_sigma_additive z). + +Let sfinT z : sfinite_measure (T' z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @isSFinite.Build _ X R (T' z) (sfinT z). + +Definition U' z : set Y -> \bar R := u z. +Let U0 z : (U' z) set0 = 0. Proof. by []. Qed. +Let U_ge0 z x : 0 <= (U' z) x. Proof. by []. Qed. +Let U_semi_sigma_additive z : semi_sigma_additive (U' z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ Y R (U' z) (U0 z) (U_ge0 z) + (@U_semi_sigma_additive z). + +Let sfinU z : sfinite_measure (U' z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @isSFinite.Build _ Y R + (U' z) (sfinU z). + +Lemma letin'C z A : measurable A -> + letin' t + (letin' u' + (ret (measurable_fun_pair macc1of3' macc0of3'))) z A = + letin' u + (letin' t' + (ret (measurable_fun_pair macc0of3' macc1of3'))) z A. +Proof. +move=> mA. +rewrite !letin'E. +under eq_integral. + move=> x _. + rewrite letin'E -uu'. + under eq_integral do rewrite retE /=. + over. +rewrite (sfinite_Fubini (T' z) (U' z) (fun x => \d_(x.1, x.2) A ))//; last first. + apply/measurable_EFinP => /=; rewrite (_ : (fun x => _) = mindic R mA)//. + by apply/funext => -[]. +rewrite /=. +apply: eq_integral => y _. +by rewrite letin'E/= -tt'; apply: eq_integral => // x _; rewrite retE. +Qed. + +End letin'C. +Arguments letin'C {d d1 d' X Y Z R} _ _ _ _. + +Section letin'A. +Context d d' d1 d2 d3 (X : measurableType d) (Y : measurableType d') + (T1 : measurableType d1) (T2 : measurableType d2) (T3 : measurableType d3) + (R : realType). +Import Notations. +Variables (t : R.-sfker X ~> T1) + (u : R.-sfker T1 * X ~> T2) + (v : R.-sfker T2 * X ~> Y) + (v' : R.-sfker T2 * (T1 * X) ~> Y) + (vv' : forall y, v =1 fun xz => v' (xz.1, (y, xz.2))). + +Lemma letin'A x A : measurable A -> + letin' t (letin' u v') x A + = + (letin' (letin' t u) v) x A. +Proof. +move=> mA. +rewrite !letin'E. +under eq_integral do rewrite letin'E. +rewrite letin'_letin/=. +rewrite integral_kcomp; [|by []|]. + apply: eq_integral => z _. + apply: eq_integral => y _. + by rewrite (vv' z). +exact: measurableT_comp (@measurable_kernel _ _ _ _ _ v _ mA) _. +Qed. + +End letin'A. + +Lemma letin'_sample_bernoulli_prob d d' (T : measurableType d) + (T' : measurableType d') (R : realType) (r : R) (r01 : (0 <= r <= 1)%R) + (u : R.-sfker bool * T ~> T') x y : + letin' (sample_cst (bernoulli_prob r)) u x y = + r%:E * u (true, x) y + (`1- r)%:E * u (false, x) y. +Proof. by rewrite letin'_letin letin_sample_bernoulli_prob. Qed. + +Section letin'_return. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Lemma letin'_kret (k : R.-sfker X ~> Y) + (f : Y * X -> Z) (mf : measurable_fun setT f) x U : + measurable U -> + letin' k (ret mf) x U = k x (curry f ^~ x @^-1` U). +Proof. +move=> mU. +rewrite letin'E. +under eq_integral do rewrite retE. +rewrite integral_indic ?setIT// -[X in measurable X]setTI. +exact: (measurableT_comp mf). +Qed. + +Lemma letin'_retk (f : X -> Y) (mf : measurable_fun setT f) + (k : R.-sfker Y * X ~> Z) x U : + measurable U -> letin' (ret mf) k x U = k (f x, x) U. +Proof. +move=> mU; rewrite letin'E retE integral_dirac ?diracT ?mul1e//. +exact: (measurableT_comp (measurable_kernel k _ mU)). +Qed. + +End letin'_return. + +Section letin'_ite. +Context d d2 d3 (T : measurableType d) (T2 : measurableType d2) + (Z : measurableType d3) (R : realType). +Variables (k1 k2 : R.-sfker T ~> Z) + (u : R.-sfker Z * T ~> T2) + (f : T -> bool) (mf : measurable_fun setT f) + (t : T) (U : set T2). + +Lemma letin'_iteT : f t -> letin' (ite mf k1 k2) u t U = letin' k1 u t U. +Proof. by move=> ftT; rewrite !letin'_letin letin_iteT. Qed. + +Lemma letin'_iteF : ~~ f t -> letin' (ite mf k1 k2) u t U = letin' k2 u t U. +Proof. by move=> ftF; rewrite !letin'_letin letin_iteF. Qed. + +End letin'_ite. + +Section hard_constraint'. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Definition fail' : R.-sfker X ~> Y := + letin' (score (measurable_cst (0%R : R))) + (ret (measurable_cst point)). + +Lemma fail'E x U : fail' x U = 0. +Proof. by rewrite /fail' letin'_letin failE. Qed. + +End hard_constraint'. +Arguments fail' {d d' X Y R}. + +Lemma score_fail' d (X : measurableType d) {R : realType} + (r : R) (r01 : (0 <= r <= 1)%R) : + score (kr r) = + letin' (sample_cst (bernoulli_prob r) : R.-pker X ~> _) + (ite macc0of2 (ret ktt) fail'). +Proof. +move: r01 => /andP[r0 r1]; apply/eq_sfkernel => x U. +rewrite letin'E/= /sample; unlock. +rewrite integral_bernoulli_prob ?r0//=. +by rewrite /mscale/= iteE//= iteE//= fail'E mule0 adde0 ger0_norm. +Qed. + +(* TODO: move to probability.v? *) +Section gauss. +Variable R : realType. +Local Open Scope ring_scope. + +Definition gauss_pdf := @normal_pdf R 0 1. + +Lemma normal_pdf_gt0 m s x : 0 < s -> 0 < normal_pdf m s x :> R. +Proof. +move=> s0; rewrite /normal_pdf gt_eqF// mulr_gt0 ?expR_gt0// invr_gt0. +by rewrite sqrtr_gt0 pmulrn_rgt0// mulr_gt0 ?pi_gt0 ?exprn_gt0. +Qed. + +Lemma gauss_pdf_gt0 x : 0 < gauss_pdf x. +Proof. exact: normal_pdf_gt0. Qed. + +Definition gauss_prob := @normal_prob R 0 1. + +HB.instance Definition _ := Probability.on gauss_prob. + +Lemma gauss_prob_dominates : gauss_prob `<< lebesgue_measure. +Proof. exact: normal_prob_dominates. Qed. + +Lemma continuous_gauss_pdf x : {for x, continuous gauss_pdf}. +Proof. exact: continuous_normal_pdf. Qed. + +End gauss. + +(* the Lebesgue measure is definable in Staton's language + [equation (10), Section 4.1, Staton ESOP 2017] *) +Section gauss_lebesgue. +Context d (T : measurableType d) (R : realType). +Notation mu := (@lebesgue_measure R). + +Let f1 (x : measurableTypeR R) := (gauss_pdf x)^-1%R. + +Let f1E (x : R) : f1 x = (Num.sqrt (pi *+ 2) * expR (- (- x ^+ 2 / 2)))%R. +Proof. +rewrite /f1 /gauss_pdf /normal_pdf oner_eq0. +rewrite /normal_peak expr1n mul1r. +by rewrite /normal_fun subr0 expr1n invfM invrK expRN. +Qed. + +Let f1_gt0 (x : R) : (0 < f1 x)%R. +Proof. by rewrite f1E mulr_gt0 ?expR_gt0// sqrtr_gt0 mulrn_wgt0// pi_gt0. Qed. + +Lemma measurable_fun_f1 : measurable_fun setT f1. +Proof. +apply: continuous_measurable_fun => x. +apply: (@continuousV _ _ (@gauss_pdf R)). + by rewrite gt_eqF// gauss_pdf_gt0. +exact: continuous_gauss_pdf. +Qed. + +Lemma integral_mgauss01 : forall U, measurable U -> + \int[(@gauss_prob R)]_(y in U) (f1 y)%:E = + \int[mu]_(x0 in U) (gauss_pdf x0 * f1 x0)%:E. +Proof. +move=> U mU. +under [in RHS]eq_integral do rewrite EFinM/= muleC. +rewrite /=. +rewrite -(@Radon_Nikodym_SigmaFinite.change_of_variables + _ _ _ _ (@lebesgue_measure R))//=; last 3 first. + exact: gauss_prob_dominates. + by move=> /= x; rewrite lee_fin ltW. + apply/measurable_EFinP. + apply: measurable_funTS. + exact: measurable_fun_f1. +apply: ae_eq_integral => //=. +- apply: emeasurable_funM => //. + apply/measurable_funTS/measurableT_comp => //. + exact: measurable_fun_f1. + apply: (measurable_int mu). + apply: (integrableS _ _ (@subsetT _ _)) => //=. + apply: Radon_Nikodym_SigmaFinite.f_integrable => /=. + exact: gauss_prob_dominates. +- apply: emeasurable_funM => //. + apply/measurable_funTS/measurableT_comp => //. + exact: measurable_fun_f1. + apply/measurable_funTS/measurableT_comp => //. + exact: measurable_normal_pdf. +- apply: ae_eqe_mul2l => /=. + rewrite /Radon_Nikodym_SigmaFinite.f/=. + case: pselect => [gauss_prob_dom|]; last first. + by move=> /(_ (@gauss_prob_dominates R)). + case: cid => //= h [h1 h2 h3] gauss_probE. + apply: integral_ae_eq => //=. + + exact: integrableS h3. + + apply/measurable_funTS/measurableT_comp => //. + exact: measurable_normal_pdf. + + move=> E EU mE. + by rewrite -gauss_probE. +Qed. + +Let mf1 : measurable_fun setT f1. +Proof. +apply: (measurable_comp (F := [set r : R | r != 0%R])) => //. +- exact: open_measurable. +- by move=> /= r [t _ <-]; rewrite gt_eqF// gauss_pdf_gt0. +- apply: open_continuous_measurable_fun => //. + by apply/in_setP => x /= x0; exact: inv_continuous. +- exact: measurable_normal_pdf. +Qed. + +Definition staton_lebesgue : R.-sfker T ~> _ := + letin (sample_cst (@gauss_prob R : pprobability _ _)) + (letin + (score (measurableT_comp mf1 macc1of2)) + (ret macc1of3)). + +Lemma staton_lebesgueE x U : measurable U -> + staton_lebesgue x U = lebesgue_measure U. +Proof. +move=> mU; rewrite [in LHS]/staton_lebesgue/=. +rewrite [in LHS]letinE /=. +transitivity (\int[(@gauss_prob R)]_(y in U) (f1 y)%:E). + rewrite -[in RHS](setTI U) integral_mkcondr/=. + apply: eq_integral => //= r _. + rewrite letinE/= ge0_integral_mscale//= ger0_norm//; last first. + by rewrite invr_ge0// normal_pdf_ge0. + rewrite integral_dirac// diracT mul1e/= diracE epatch_indic/=. + by rewrite indicE. +rewrite integral_mgauss01//. +transitivity (\int[lebesgue_measure]_(x in U) (\1_U x)%:E). + apply: eq_integral => /= y yU. + by rewrite /f1 divrr ?indicE ?yU// unitfE gt_eqF// gauss_pdf_gt0. +by rewrite integral_indic//= setIid. +Qed. + +End gauss_lebesgue. diff --git a/theories/prob_lang_wip.v b/theories/prob_lang_wip.v new file mode 100644 index 000000000..d080a274c --- /dev/null +++ b/theories/prob_lang_wip.v @@ -0,0 +1,57 @@ +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval finmap. +From mathcomp Require Import rat interval_inference. +From mathcomp Require Import mathcomp_extra boolp classical_sets. +From mathcomp Require Import functions cardinality fsbigop. +From mathcomp Require Import interval_inference reals ereal topology normedtype. +From mathcomp Require Import sequences esum measure lebesgue_measure numfun. +From mathcomp Require Import lebesgue_integral exp kernel trigo prob_lang. +From mathcomp Require Import realfun charge probability derive ftc. +From mathcomp Require Import gauss_integral. + +(**md**************************************************************************) +(* wip waiting for the Poisson distribution *) +(* *) +(* Another example from Section 4.2 in [Equation (13), Staton, ESOP 2017]. *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. +Import Order.TTheory GRing.Theory Num.Def Num.ExtraDef Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +(* Staton's definition of the counting measure + [equation (13), Sect. 4.2, Staton ESOP 2017] *) +Section staton_counting. +Context d (X : measurableType d). +Variable R : realType. +Notation mu := (@lebesgue_measure R). +Import Notations. +Hypothesis integral_poisson_density : forall k, + (\int[mu]_x (@poisson_pmf R x k)%:E = 1%E)%E. + +Let f1 n := (@poisson_pmf R 1%R n)^-1%R. + +Let mf1 : measurable_fun setT f1. +Proof. +rewrite /f1 /poisson_pmf. +apply: (measurable_comp (F := [set r : R | r != 0%R])) => //. +- exact: open_measurable. +- move=> /= r [t ? <-]. + by case: ifPn => // t0; rewrite gt_eqF ?mulr_gt0 ?expR_gt0//= invrK ltr0n. +- apply: open_continuous_measurable_fun => //. + by apply/in_setP => x /= x0; exact: inv_continuous. +Qed. + +Definition staton_counting (r : R) : R.-sfker X ~> _ := + letin (sample_cst (@poisson_prob R r 1%N)) + (letin + (score (measurableT_comp mf1 macc1of2)) + (ret macc1of3)). + +End staton_counting. diff --git a/theories/probability.v b/theories/probability.v index 92b50640b..99b9e7c84 100644 --- a/theories/probability.v +++ b/theories/probability.v @@ -8,7 +8,7 @@ From mathcomp Require Import exp numfun lebesgue_measure lebesgue_integral. From mathcomp Require Import reals interval_inference ereal topology normedtype. From mathcomp Require Import sequences derive esum measure exp trigo realfun. From mathcomp Require Import numfun lebesgue_measure lebesgue_integral kernel. -From mathcomp Require Import ftc gauss_integral hoelder. +From mathcomp Require Import charge ftc gauss_integral hoelder. (**md**************************************************************************) (* # Probability *) @@ -49,7 +49,7 @@ From mathcomp Require Import ftc gauss_integral hoelder. (* *) (* ``` *) (* bernoulli_pmf p == Bernoulli pmf with parameter p : R *) -(* bernoulli p == Bernoulli probability measure when 0 <= p <= 1 *) +(* bernoulli_prob p == Bernoulli probability measure when 0 <= p <= 1 *) (* and \d_false otherwise *) (* binomial_pmf n p == binomial pmf with parameters n : nat and p : R *) (* binomial_prob n p == binomial probability measure when 0 <= p <= 1 *) @@ -70,6 +70,11 @@ From mathcomp Require Import ftc gauss_integral hoelder. (* exponential_prob r == exponential probability measure *) (* poisson_pmf r k == pmf of the Poisson distribution with parameter r *) (* poisson_prob r == Poisson probability measure *) +(* XMonemX a b := x ^+ a * `1-x ^+ b *) +(* beta_fun a b := \int[mu]_x (XMonemX a.-1 b.-1 \_`[0,1] x) *) +(* beta_pdf == probability density function for beta *) +(* beta_prob == beta probability measure *) +(* div_beta_fun a b c d := beta_fun (a + c) (b + d) / beta_fun a b *) (* ``` *) (* *) (******************************************************************************) @@ -1094,13 +1099,13 @@ Proof. by apply/measurable_funTS/measurable_fun_if => //=; exact: measurable_funB. Qed. -Definition bernoulli {R : realType} (p : R) : set bool -> \bar R := fun A => +Definition bernoulli_prob {R : realType} (p : R) : set bool -> \bar R := fun A => if (0 <= p <= 1)%R then \sum_(b \in A) (bernoulli_pmf p b)%:E else \d_false A. -Section bernoulli. +Section bernoulli_prob. Context {R : realType} (p : R). -Local Notation bernoulli := (bernoulli p). +Local Notation bernoulli := (bernoulli_prob p). Let bernoulli0 : bernoulli set0 = 0. Proof. @@ -1143,7 +1148,7 @@ Qed. HB.instance Definition _ := @Measure_isProbability.Build _ _ R bernoulli bernoulli_setT. -Lemma eq_bernoulli (P : probability bool R) : +Lemma eq_bernoulli_prob (P : probability bool R) : P [set true] = p%:E -> P =1 bernoulli. Proof. move=> Ptrue sb; rewrite /bernoulli /bernoulli_pmf. @@ -1160,16 +1165,16 @@ rewrite -[in LHS](eq_sb sb)/= measure_fin_bigcup//; last 2 first. - by apply: eq_fsbigr => /= -[]. Qed. -End bernoulli. +End bernoulli_prob. Section bernoulli_measure. Context {R : realType}. Variables (p : R) (p0 : (0 <= p)%R) (p1 : ((NngNum p0)%:num <= 1)%R). -Lemma bernoulli_dirac : bernoulli p = measure_add +Lemma bernoulli_prob_dirac : bernoulli_prob p = measure_add (mscale (NngNum p0) \d_true) (mscale (1 - (Itv01 p0 p1)%:num)%:nng \d_false). Proof. -apply/funext => U; rewrite /bernoulli; case: ifPn => [p01|]; last first. +apply/funext => U; rewrite /bernoulli_prob; case: ifPn => [p01|]; last first. by rewrite p0/= p1. rewrite measure_addE/= /mscale/=. have := @subsetT _ U; rewrite setT_bool => UT. @@ -1187,44 +1192,47 @@ have [->|->|->|->] /= := subset_set2 UT. Qed. End bernoulli_measure. -Arguments bernoulli {R}. +Arguments bernoulli_prob {R}. -Lemma eq_bernoulliV2 {R : realType} (P : probability bool R) : - P [set true] = P [set false] -> P =1 bernoulli 2^-1. +Lemma eq_bernoulli_probV2 {R : realType} (P : probability bool R) : + P [set true] = P [set false] -> P =1 bernoulli_prob 2^-1. Proof. -move=> Ptrue_eq_false; apply/eq_bernoulli. -have : P [set: bool] = 1%E := probability_setT P. +move=> Ptrue_eq_false; apply/eq_bernoulli_prob. +have : P [set: bool] = 1%E := probability_setT _. rewrite setT_bool measureU//=; last first. by rewrite disjoints_subset => -[]//. rewrite Ptrue_eq_false -mule2n; move/esym/eqP. by rewrite -mule_natl -eqe_pdivrMl// mule1 => /eqP<-. Qed. -Section integral_bernoulli. +Section integral_bernoulli_prob. Context {R : realType}. Variables (p : R) (p01 : (0 <= p <= 1)%R). Local Open Scope ereal_scope. -Lemma bernoulliE A : bernoulli p A = p%:E * \d_true A + (`1-p)%:E * \d_false A. -Proof. by case/andP : p01 => p0 p1; rewrite bernoulli_dirac// measure_addE. Qed. +Lemma bernoulli_probE A : + bernoulli_prob p A = p%:E * \d_true A + (`1-p)%:E * \d_false A. +Proof. +by case/andP : p01 => p0 p1; rewrite bernoulli_prob_dirac// measure_addE. +Qed. -Lemma integral_bernoulli (f : bool -> \bar R) : (forall x, 0 <= f x) -> - \int[bernoulli p]_y (f y) = p%:E * f true + (`1-p)%:E * f false. +Lemma integral_bernoulli_prob (f : bool -> \bar R) : (forall x, 0 <= f x) -> + \int[bernoulli_prob p]_y (f y) = p%:E * f true + (`1-p)%:E * f false. Proof. -move=> f0; case/andP : p01 => p0 p1; rewrite bernoulli_dirac/=. +move=> f0; case/andP : p01 => p0 p1; rewrite bernoulli_prob_dirac/=. rewrite ge0_integral_measure_sum// 2!big_ord_recl/= big_ord0 adde0/=. by rewrite !ge0_integral_mscale//= !integral_dirac//= !diracT !mul1e. Qed. -End integral_bernoulli. +End integral_bernoulli_prob. -Section measurable_bernoulli. +Section measurable_bernoulli_prob. Local Open Scope ring_scope. Variable R : realType. Implicit Type p : R. -Lemma measurable_bernoulli : - measurable_fun setT (bernoulli : R -> pprobability bool R). +Lemma measurable_bernoulli_prob : + measurable_fun setT (bernoulli_prob : R -> pprobability bool R). Proof. apply: (measurability (@pset _ _ _ : set (set (pprobability _ R)))) => //. move=> _ -[_ [r r01] [Ys mYs <-]] <-; apply: emeasurable_fun_infty_o => //=. @@ -1238,14 +1246,14 @@ apply: emeasurable_sum => n; move=> k Ysk; apply/measurableT_comp => //. exact: measurable_bernoulli_pmf. Qed. -Lemma measurable_bernoulli2 U : measurable U -> - measurable_fun setT (bernoulli ^~ U : R -> \bar R). +Lemma measurable_bernoulli_prob2 U : measurable U -> + measurable_fun setT (bernoulli_prob ^~ U : R -> \bar R). Proof. -by move=> ?; exact: (measurable_kernel (kprobability measurable_bernoulli)). +by move=> ?; exact: (measurable_kernel (kprobability measurable_bernoulli_prob)). Qed. -End measurable_bernoulli. -Arguments measurable_bernoulli {R}. +End measurable_bernoulli_prob. +Arguments measurable_bernoulli_prob {R}. Section binomial_pmf. Local Open Scope ring_scope. @@ -1374,9 +1382,9 @@ End binomial_probability. Lemma integral_binomial_prob (R : realType) n p U : (0 <= p <= 1)%R -> (\int[binomial_prob n p]_y \d_(0 < y)%N U = - bernoulli (1 - `1-p ^+ n) U :> \bar R)%E. + bernoulli_prob (1 - `1-p ^+ n) U :> \bar R)%E. Proof. -move=> /andP[p0 p1]; rewrite bernoulliE//=; last first. +move=> /andP[p0 p1]; rewrite bernoulli_probE//=; last first. rewrite subr_ge0 exprn_ile1//=; [|exact/onem_ge0|exact/onem_le1]. by rewrite lerBlDr addrC -lerBlDr subrr; exact/exprn_ge0/onem_ge0. rewrite (@integral_binomial _ n p _ _ (fun y => \d_(1 <= y)%N U))//. @@ -1827,14 +1835,14 @@ Section exponential_pdf. Context {R : realType}. Notation mu := lebesgue_measure. Variable rate : R. -Hypothesis rate_gt0 : 0 < rate. +Hypothesis rate_gt0 : 0 <= rate. Let exponential_pdfT x := rate * expR (- rate * x). Definition exponential_pdf := exponential_pdfT \_ `[0%R, +oo[. Lemma exponential_pdf_ge0 x : 0 <= exponential_pdf x. Proof. -by apply: restrict_ge0 => {}x _; apply: mulr_ge0; [exact: ltW|exact: expR_ge0]. +by apply: restrict_ge0 => {}x _; apply: mulr_ge0 => //; exact: expR_ge0. Qed. Lemma lt0_exponential_pdf x : x < 0 -> exponential_pdf x = 0. @@ -1894,7 +1902,6 @@ Context {R : realType}. Local Open Scope ring_scope. Notation mu := lebesgue_measure. Variable rate : R. -Hypothesis rate_gt0 : 0 < rate. Lemma derive1_exponential_pdf : {in `]0, +oo[%R, (fun x => - (expR : R^o -> R^o) (- rate * x))^`()%classic @@ -1931,21 +1938,22 @@ apply: (@continuous_FTC2 _ _ (fun x => - expR (- rate * x))) => //. by apply: derive1_exponential_pdf; rewrite in_itv/= andbT. Qed. -Lemma integral_exponential_pdf : (\int[mu]_x (exponential_pdf rate x)%:E = 1)%E. +Lemma integral_exponential_pdf (rate_gt0 : 0 < rate) : + (\int[mu]_x (exponential_pdf rate x)%:E = 1)%E. Proof. have mEex : measurable_fun setT (EFin \o exponential_pdf rate). by apply/measurable_EFinP; exact: measurable_exponential_pdf. rewrite -(setUv `[0, +oo[%classic) ge0_integral_setU//=; last 4 first. exact: measurableC. by rewrite setUv. - by move=> x _; rewrite lee_fin exponential_pdf_ge0. + by move=> x _; rewrite lee_fin exponential_pdf_ge0// ltW. exact/disj_setPCl. rewrite [X in _ + X]integral0_eq ?adde0; last first. by move=> x x0; rewrite /exponential_pdf patchE ifF// memNset. rewrite (@ge0_continuous_FTC2y _ _ (fun x => - (expR (- rate * x))) _ 0)//. - by rewrite mulr0 expR0 EFinN oppeK add0e. -- by move=> x _; apply: exponential_pdf_ge0. +- by move=> x _; apply: exponential_pdf_ge0; exact: ltW. - exact: within_continuous_exponential_pdf. - rewrite -oppr0; apply: (@cvgN _ R^o). rewrite (_ : (fun x => expR (- rate * x)) = @@ -1957,16 +1965,18 @@ rewrite (@ge0_continuous_FTC2y _ _ - exact: derive1_exponential_pdf. Qed. -Lemma integrable_exponential_pdf : +Lemma integrable_exponential_pdf (rate_gt0 : 0 < rate) : mu.-integrable setT (EFin \o (exponential_pdf rate)). Proof. have mEex : measurable_fun setT (EFin \o exponential_pdf rate). by apply/measurable_EFinP; exact: measurable_exponential_pdf. apply/integrableP; split => //. -under eq_integral do rewrite /= ger0_norm ?exponential_pdf_ge0//. -by rewrite /= integral_exponential_pdf ltry. +under eq_integral do rewrite /= ger0_norm ?(exponential_pdf_ge0 (ltW rate_gt0))//. +by rewrite /= integral_exponential_pdf// ltry. Qed. +Hypothesis rate_gt0 : 0 < rate. + Local Notation exponential := (exponential_prob rate). Let exponential0 : exponential set0 = 0%E. @@ -1975,7 +1985,7 @@ Proof. by rewrite /exponential integral_set0. Qed. Let exponential_ge0 A : (0 <= exponential A)%E. Proof. rewrite /exponential integral_ge0//= => x _. -by rewrite lee_fin exponential_pdf_ge0. +by rewrite lee_fin exponential_pdf_ge0// ltW. Qed. Let exponential_sigma_additive : semi_sigma_additive exponential. @@ -1983,11 +1993,11 @@ Proof. move=> /= F mF tF mUF; rewrite /exponential; apply: cvg_toP. apply: ereal_nondecreasing_is_cvgn => m n mn. apply: lee_sum_nneg_natr => // k _ _; apply: integral_ge0 => /= x Fkx. - by rewrite lee_fin; apply: exponential_pdf_ge0. + by rewrite lee_fin exponential_pdf_ge0// ltW. rewrite ge0_integral_bigcup//=. - apply/measurable_funTS/measurableT_comp => //. exact: measurable_exponential_pdf. -- by move=> x _; rewrite lee_fin exponential_pdf_ge0. +- by move=> x _; rewrite lee_fin exponential_pdf_ge0// ltW. Qed. HB.instance Definition _ := isMeasure.Build _ _ _ @@ -2007,23 +2017,28 @@ Context {R : realType}. Implicit Types (rate : R) (k : nat). Definition poisson_pmf rate k : R := - (rate ^+ k) * k`!%:R^-1 * expR (- rate). + if rate > 0 then (rate ^+ k) * k`!%:R^-1 * expR (- rate) else 1. -Lemma poisson_pmf_ge0 rate k : 0 <= rate -> 0 <= poisson_pmf rate k. -Proof. by move=> r0; rewrite /poisson_pmf 2?mulr_ge0// exprn_ge0. Qed. +Lemma poisson_pmf_ge0 rate k : 0 <= poisson_pmf rate k. +Proof. +rewrite /poisson_pmf; case: ifPn => // rate0. +by rewrite 2?mulr_ge0// exprn_ge0// ltW. +Qed. End poisson_pmf. -Lemma measurable_poisson_pmf {R : realType} D (rate : R) k : +Lemma measurable_poisson_pmf {R : realType} D k : measurable D -> measurable_fun D (@poisson_pmf R ^~ k). Proof. +move=> mD; rewrite /poisson_pmf; apply: measurable_fun_if => //. + exact: measurable_fun_ltr. apply: measurable_funM; first exact: measurable_funM. by apply: measurable_funTS; exact: measurableT_comp. Qed. Definition poisson_prob {R : realType} (rate : R) (k : nat) : set nat -> \bar R := - fun U => if 0 <= rate then + fun U => if 0 < rate then \esum_(k in U) (poisson_pmf rate k)%:E else \d_0%N U. Section poisson. @@ -2061,8 +2076,13 @@ rewrite /poisson; case: ifPn => [rate0|_]; last by rewrite probability_setT. rewrite [RHS](_ : _ = (expR (- rate))%:E * (expR rate)%:E); last first. by rewrite -EFinM expRN mulVf ?gt_eqF ?expR_gt0. rewrite -nneseries_esumT; last by move=> *; rewrite lee_fin poisson_pmf_ge0. -under eq_eseriesr do rewrite EFinM muleC. -rewrite nneseriesZl/=; last by move=> *; rewrite lee_fin divr_ge0// exprn_ge0. +under eq_eseriesr. + move=> n _. + rewrite /poisson_pmf rate0 EFinM muleC. + over. +rewrite /= nneseriesZl/=; last first. + move=> n _. + by rewrite lee_fin divr_ge0// exprn_ge0// ltW. congr *%E; rewrite expRE -EFin_lim; last first. rewrite /pseries/=; under eq_fun do rewrite mulrC. exact: is_cvg_series_exp_coeff. @@ -2080,12 +2100,1734 @@ Lemma measurable_poisson_prob {R : realType} n : Proof. apply: (measurability (@pset _ _ _ : set (set (pprobability _ R)))) => //. move=> _ -[_ [r r01] [Ys mYs <-]] <-; apply: emeasurable_fun_infty_o => //=. -apply: measurable_fun_if => //=; first exact: measurable_fun_ler. +apply: measurable_fun_if => //=; first exact: measurable_fun_ltr. apply: (eq_measurable_fun (fun t => \sum_(k x /set_mem[_/= x01]. by rewrite nneseries_esum ?set_mem_set// =>*; rewrite lee_fin poisson_pmf_ge0. apply: ge0_emeasurable_sum. by move=> k x/= [_ x01] _; rewrite lee_fin poisson_pmf_ge0. -by move=> k Ysk; apply/measurableT_comp => //; exact: measurable_poisson_pmf. +move=> k Ysk; apply/measurableT_comp => //. +apply: measurable_poisson_pmf => //. +rewrite setTI. +rewrite (_ : _ @^-1` _ = `]0, +oo[%classic)//. +by apply/seteqP; split => /= x /=; rewrite in_itv/= andbT. +Qed. + +Section near_lt_lim. +Variable R : realFieldType. +Implicit Types u : R ^nat. + +Lemma near_lt_lim u (M : R) : + (\forall N \near \oo, {in [set n | (N <= n)%N] &, nondecreasing_seq u}) -> + cvgn u -> M < limn u -> \forall n \near \oo, M <= u n. +Proof. +move=> [] N _ Hnear. +move=> cu Ml; have [[n Mun]|/forallNP Mu] := pselect (exists n, M <= u n). + exists (maxn N n) => //. + move=> k/=. + rewrite geq_max => /andP. +(* + near=> m; suff : u n <= u m by exact: le_trans. + apply/(Hnear m). + near: m; exists n.+1 => // p q; apply/(Hnear n)/ltnW => //. + + +have {}Mu : forall x, M > u x by move=> x; rewrite ltNge; apply/negP. +have : limn u <= M by apply: limr_le => //; near=> m; apply/ltW/Mu. +by move/(lt_le_trans Ml); rewrite ltxx. +Unshelve. all: by end_near. Qed. +*) +Abort. + +End near_lt_lim. + +Section near_ereal_nondecreasing_is_cvgn. + +Let G N := ([set n | (N <= n)%N]). + +Lemma ereal_shiftn_nondecreasing_cvgn (R : realType) (u_ : (\bar R)^nat) + (N : nat) : +(* \forall N \near \oo, {in G N &, nondecreasing_seq u_ } + -> u_ @ \oo --> ereal_sup (range (fun n => u_ (n + N))). +*) +{in G N &, nondecreasing_seq u_ } + -> u_ @ \oo --> ereal_sup (range (fun n => u_ (n + N))). +Proof. +move=> H. +rewrite -(cvg_shiftn N). +apply: ereal_nondecreasing_cvgn. +move=> k m km. +apply: H; rewrite /G ?inE//=. +- exact: leq_addl. +- exact: leq_addl. +- exact: leq_add. +Qed. + +Lemma near_ereal_nondecreasing_is_cvgn (R : realType) (u_ : (\bar R) ^nat) : + (\forall N \near \oo, {in G N &, nondecreasing_seq u_ }) -> cvgn u_. +Proof. +move=> [] N _ H. +apply/cvg_ex. +exists (ereal_sup (range (fun n => u_ (n + N)))). +apply: ereal_shiftn_nondecreasing_cvgn. +by apply: (H N); rewrite /G ?inE/=. +Qed. + +End near_ereal_nondecreasing_is_cvgn. + +(* TODO: move as another PR *) +Section near_monotone_convergence. +Local Open Scope ereal_scope. + +Context d (T : measurableType d) (R : realType). +Variable mu : {measure set T -> \bar R}. +Variables (D : set T) (mD : measurable D) (g' : (T -> \bar R)^nat). +Hypothesis mg' : forall n, measurable_fun D (g' n). +Hypothesis near_g'0 : \forall n \near \oo, forall x, D x -> 0 <= g' n x. +Hypothesis near_nd_g' : \forall N \near \oo, (forall x : T, D x -> + {in [set k| (N <= k)%N]&, {homo g'^~ x : n m / (n <= m)%N >-> (n <= m)%E}}). +Let f' := fun x => limn (g'^~ x). + +Lemma near_monotone_convergence : +(\int[mu]_(x in D) (fun x0 : T => limn (g'^~ x0)) x)%E = +limn (fun n : nat => (\int[mu]_(x in D) g' n x)%E). +Proof. +have [N0 _ H0] := near_g'0. +have [N1 _ H1] := near_nd_g'. +pose N := maxn N0 N1. +under eq_integral. + move=> x; rewrite inE/= => Dx. + have <- : limn (fun n : nat => g' (n + N) x) = limn (g'^~ x). + apply/cvg_lim => //. + rewrite (cvg_shiftn _ (g'^~ x) _). + apply: (@near_ereal_nondecreasing_is_cvgn _ (g'^~ x)). + exists N1 => //. + move=> n /= N1n. + exact: H1. + over. +apply/esym/cvg_lim => //. +rewrite -(cvg_shiftn N). +apply: cvg_monotone_convergence => //. + move=> n x Dx. + apply: H0 => //=. + apply: (leq_trans (leq_maxl N0 N1)). + exact: leq_addl. +move=> x Dx n m nm. +apply: (H1 N) => //; rewrite ?inE/=. +- exact: leq_maxr. +- exact: leq_addl. +- exact: leq_addl. +- exact: leq_add. +Qed. + +Lemma cvg_near_monotone_convergence : + \int[mu]_(x in D) g' n x @[n \oo] --> \int[mu]_(x in D) f' x. +Proof. +have [N0 _ Hg'0] := near_g'0. +have [N1 _ Hndg'] := near_nd_g'. +pose N := maxn N0 N1. +have N0N : (N0 <= N)%N by apply: (leq_maxl N0 N1). +have N1N : (N1 <= N)%N by apply: (leq_maxr N0 N1). +have g'_ge0 n x : D x -> (N <= n)%N -> 0 <= g' n x. + move=> + Nn. + apply: Hg'0 => /=. + exact: (leq_trans N0N). +have ndg' n m x : D x -> (N <= n)%N -> (n <= m)%N -> g' n x <= g' m x. + move=> Dx Nn nm. + apply: (Hndg' N); rewrite ?inE//=. + exact: leq_trans nm. +rewrite near_monotone_convergence. +apply: near_ereal_nondecreasing_is_cvgn. +exists N => //. +move=> k/= Nk n m; rewrite !inE/= => kn km nm. +apply: ge0_le_integral => // t Dt; [| |]. +- apply: g'_ge0 => //. + exact: leq_trans kn. +- apply: g'_ge0 => //. + exact: leq_trans km. +- apply: ndg' => //. + exact: leq_trans kn. +Qed. + +End near_monotone_convergence. + +Section exp_coeff_properties. +Context {R : realType}. + +(* not used, TODO:PR *) +Lemma exp_coeff_gt0 (x : R) n : 0 < x -> 0 < exp_coeff x n. +Proof. +move=> x0. +rewrite /exp_coeff/=. +apply: divr_gt0. + exact: exprn_gt0. +rewrite (_:0%R = 0%:R)// ltr_nat. +exact: fact_gt0. +Abort. + +Lemma series_exp_coeff_near_ge0 (x : R) : + \forall n \near \oo, 0 <= (series (exp_coeff x)) n. +Proof. +apply: (cvgr_ge (expR x)); last exact: expR_gt0. +exact: is_cvg_series_exp_coeff. +Abort. + +Lemma normr_exp_coeff_near_nonincreasing (x : R) : + \forall n \near \oo, + `|exp_coeff x n.+1| <= `|exp_coeff x n|. +Proof. +exists `|archimedean.Num.Def.ceil x |%N => //. +move=> n/= H. +rewrite exp_coeffE. +rewrite exprS mulrA normrM [leRHS]normrM ler_pM//. +rewrite factS mulnC natrM invfM -mulrA normrM ger_pMr; last first. + rewrite normr_gt0. + by rewrite invr_neq0//. +rewrite normrM normfV. +rewrite ler_pdivrMl; last first. + rewrite normr_gt0. + by rewrite lt0r_neq0. +rewrite mulr1. +apply: (le_trans (abs_ceil_ge _)). +rewrite gtr0_norm//. +by rewrite ler_nat ltnS. +Qed. + +Lemma exp_coeff2_near_nondecreasing (x : R) : + \forall N \near \oo, nondecreasing_seq (fun n => (series (exp_coeff x) (2 * (n + N))%N)). +Proof. +have := normr_exp_coeff_near_nonincreasing x. +move=> [N _] Hnear. +exists N => //n/= Nn. +apply/nondecreasing_seqP => k. +rewrite /series/=. +have N0 : (0 <= N)%N by []. +rewrite addSn mulnS add2n. +rewrite !big_nat_recr//=. +rewrite -addrA lerDl. +rewrite -[X in _ <= _ + X]opprK subr_ge0. +rewrite (le_trans (ler_norm _))// normrN. +have : (N <= (2 * (k + n)))%N. + rewrite mulnDr -(add0n N) leq_add//. + by rewrite mulSn mul1n -(add0n N) leq_add. +move/Hnear => H. +apply: (le_trans H). +rewrite ler_norml lexx andbT. +suff Hsuff : 0 <= exp_coeff x (2 * (k + n))%N. + by apply: (le_trans _ Hsuff); rewrite lerNl oppr0. +rewrite /exp_coeff/=. +apply: mulr_ge0 => //. +apply: exprn_even_ge0. +by rewrite mul2n odd_double. +Qed. + +Lemma exp_coeff2_near_in_increasing (x : R) : + \forall N \near \oo, {in [set k | (N <= k)%N] &, +nondecreasing_seq (fun n => (series (exp_coeff x) (2 * n)%N))}. +Proof. +have := normr_exp_coeff_near_nonincreasing x. +move=> [N _] Hnear. +exists N => //k/= Nk. +move=> n m; rewrite !inE/= => kn km nm. +have kn2 : (2 * k <= 2 * n)%N by rewrite leq_pmul2l. +have km2 : (2 * k <= 2 * m)%N by rewrite leq_pmul2l. +rewrite /series/=. +rewrite (big_cat_nat _ kn2)//=. +rewrite (big_cat_nat _ km2)//=. +rewrite lerD2. +have nm2 : (2 * n <= 2 * m)%N by rewrite leq_pmul2l. +rewrite (big_cat_nat _ nm2)//=. +rewrite lerDl. +rewrite -(add0n (2 * n)%N). +rewrite big_addn. +rewrite -mulnBr. +elim: (m - n)%N. + rewrite muln0. + rewrite big_mkord. + by rewrite big_ord0. +move=> {km nm km2 nm2} {}m IH. +rewrite mul2n. +rewrite doubleS. +rewrite big_nat_recr//=. +rewrite big_nat_recr//=. +rewrite -addrA. +rewrite addr_ge0//. + by rewrite -mul2n. +rewrite -[X in _ <= _ + X]opprK subr_ge0. +rewrite (le_trans (ler_norm _))// normrN. +rewrite -mul2n addSn -mulnDr. +have : (N <= (2 * (m + n)))%N. + rewrite mulnDr -(add0n N) leq_add//. + by rewrite (leq_trans _ kn2)// (leq_trans Nk)// leq_pmull. +move/Hnear => H. +apply: (le_trans H). +rewrite ler_norml lexx andbT. +suff Hsuff : 0 <= exp_coeff x (2 * (m + n))%N. + by apply: (le_trans _ Hsuff); rewrite lerNl oppr0. +rewrite /exp_coeff/=. +apply: mulr_ge0 => //. +apply: exprn_even_ge0. +by rewrite mul2n odd_double. +Qed. + +End exp_coeff_properties. + +(* TODO: move *) +Section shift_properties. +Variable R : realType. + +Local Open Scope ring_scope. + +Notation mu := lebesgue_measure. + +Lemma ge0_integration_by_substitution_shift_itvNy (f : R -> R) (r e : R) : + {within `]-oo, r + e], continuous f} -> + {in `]-oo, r + e[, forall x : R, 0 <= f x} -> + (\int[mu]_(x in `]-oo, (r + e)%R]) (f x)%:E = + \int[mu]_(x in `]-oo, r]) ((f \o (shift e)) x)%:E)%E. +Proof. +move=> cf f0. +have := (derive_shift 1 e). +have <- := (funext (@derive1E R _ (shift e : R^o -> R^o))). +move=> dshiftE. +rewrite (@increasing_ge0_integration_by_substitutionNy _ (shift e))//; first last. +- exact: cvg_addrr_Ny. +- split. + move=> x _. + exact/ex_derive. + apply: cvg_at_left_filter. + apply: cvgD => //. + exact: cvg_cst. +- rewrite dshiftE. + exact: cvg_cst. +- rewrite dshiftE. + exact: is_cvg_cst. +- rewrite dshiftE. + move=> ? _; apply: cst_continuous. +- by move=> x y _ _ xy; rewrite ltr_leD. +by rewrite dshiftE mulr1/=. +Qed. + +Lemma ge0_integration_by_substitution_shift_itvy (f : R -> R) (r e : R) : + {within `[r + e, +oo[, continuous f} -> + {in `]r + e, +oo[, forall x : R, 0 <= f x} -> + (\int[mu]_(x in `[r + e, +oo[) (f x)%:E = + \int[mu]_(x in `[r, +oo[) ((f \o (shift e)) x)%:E)%E. +Proof. +move=> cf f0. +have := (derive_shift 1 e). +have <- := (funext (@derive1E R _ (shift e : R^o -> R^o))). +move=> dshiftE. +rewrite (@increasing_ge0_integration_by_substitutiony _ (shift e))//=; first last. +- exact: cvg_addrr. +- split. + move=> x _. + exact/ex_derive. + apply: cvg_at_right_filter. + apply: cvgD => //. + exact: cvg_cst. +- rewrite dshiftE. + exact: is_cvg_cst. +- rewrite dshiftE. + exact: is_cvg_cst. +- rewrite dshiftE. + move=> ? _; apply: cst_continuous. +- by move=> x y _ _ xy; rewrite ltr_leD. +by rewrite dshiftE mulr1/=. +Qed. + +End shift_properties. + +Section normal_kernel. +Variable R : realType. +Variables s : R. +Hypothesis s0 : s != 0. +Local Open Scope ring_scope. +Notation mu := lebesgue_measure. + +Let normal_pdfE m x : normal_pdf m s x = + (Num.sqrt (s^+2 * pi *+ 2))^-1 * expR (- (x - m) ^+ 2 / (s^+2 *+ 2)). +Proof. +rewrite /normal_pdf ifF//. +exact/negP/negP. +Qed. + +Local Definition normal_prob2 := (fun m => normal_prob m s) : _ -> pprobability _ _. + +Lemma bij_shift x : bijective (id \+ @cst R R x). +Proof. +apply: (@Bijective _ _ _ (id \- cst x)). +- by move=> z;rewrite /=addrK. +- by move=> z; rewrite /= subrK. +Qed. + +Lemma shift_ocitv (x a b : R) : + (shift x) @` `]a, b]%classic = `]a + x, b + x]%classic. +Proof. +rewrite eqEsubset; split => r/=. + move=> [r' + <-]. + rewrite in_itv/=; move/andP => [ar' r'b]. + by rewrite in_itv/=; apply/andP; split; rewrite ?lerD2 ?ltrD2. +rewrite in_itv/=; move/andP => [axr rbx]. +exists (r - x); last by rewrite subrK. +rewrite in_itv/=; apply/andP; split. +- by rewrite ltrBrDr. +- by rewrite lerBlDr. +Qed. + +Lemma shift_preimage (x : R) U : + (shift x) @^-1` U = (shift (- x)) @` U. +Proof. +rewrite eqEsubset; split => r. + rewrite /= => Urx. + by exists (r + x) => //; rewrite addrK. +by move=> [z Uz <-]/=; rewrite subrK. +Qed. + +Lemma pushforward_shift_itv (mu : measure (measurableTypeR R) R) (a b x : R) : + (pushforward mu (fun z => z + x) + `]a, b]) = + mu `]a - x, b - x]%classic. +Proof. +rewrite /pushforward. +rewrite shift_preimage. +by rewrite shift_ocitv. +Qed. + +Lemma pushforward_shift_measurable (mu : measure (measurableTypeR R) R) (x : R) (U : set R) : + (pushforward mu (fun z => z + x) + U) = + mu ((center x) @` U). +Proof. +by rewrite /pushforward shift_preimage. +Qed. + +From mathcomp Require Import charge. +Open Scope charge_scope. + +(* +Lemma radon_nikodym_crestr_fin U (mU : measurable U) +(Uoo : (@lebesgue_measure R U < +oo)%E) : + ae_eq lebesgue_measure setT ('d charge_of_finite_measure (mfrestr mU Uoo) '/d + [the sigma_finite_measure _ _ of @lebesgue_measure R]) + (EFin \o \1_U). +Proof. +apply: integral_ae_eq => //=. +- admit. +- admit. +move=> E _ mE. +rewrite -Radon_Nikodym_integral. +rewrite integral_indic/=. +by rewrite /mfrestr/mrestr setIC. +Admitted. +*) + +(* +Lemma radon_nikodym_crestr U (mU : measurable U) : + ae_eq lebesgue_measure setT ('d charge_of_finite_measure (mfrestr mU Uoo) '/d + [the sigma_finite_measure _ _ of @lebesgue_measure R]) + (EFin \o \1_U). +Proof. +*) + +(* +rewrite [RHS](_:_= ('d charge_of_finite_measure (mfrestr mU Uoo) '/d + [the sigma_finite_measure _ _ of @lebesgue_measure R]) + (EFin \o \1_U) + move=> x _. + rewrite epatch_indic. + rewrite -radon_nikodym_crestr. +rewrite [RHS]integral_mkcond. +under [RHS]eq_integral do rewrite epatch_indic. + +rewrite -integral_pushforward. +apply: eq_integral. +move=> x _. +Admitted. +*) +(*Local Definition normal_prob2 := + (fun m => normal_prob m s) : _ -> pprobability _ _. +*) +(* +Lemma normal_shift0 x : +normal_prob2 x = + @pushforward _ _ _ + (measurableTypeR R) _ (normal_prob2 0%R) (fun z => z + x) + :> (set R -> \bar R). +Proof. +apply: funext. +move=> U. +rewrite /normal_prob2/=. +rewrite /pushforward/=. +rewrite /normal_prob. +rewrite shift_preimage. +rewrite integration_by_substitution_shift/=. +apply: eq_integral. +move=> z Uz. +congr EFin. +rewrite /normal_pdf/=. +rewrite ifF; last exact/negP/negP. +rewrite ifF; last exact/negP/negP. +rewrite {2}/normal_fun. +by rewrite subr0. +Qed. +*) + +(* +Lemma measurable_normal_prob2_ocitv a b: + measurable_fun [set: R] (normal_prob2 ^~ `]a, b]%classic). +Proof. +apply: (@measurability _ _ _ _ _ _ + (@pset _ _ _ : set (set (pprobability _ R)))) => //. +move=> _ -[_ [r r01] [Ys mYs <-]] <-; apply: emeasurable_fun_infty_o => //=. + +rewrite /normal_prob2/=. +rewrite /normal_prob. + +under [X in measurable_fun _ X]eq_fun. + move=> x. + rewrite (_: normal_kernel _ _ = (fine (normal_kernel x `]a, b]%classic))%:E); last first. + rewrite fineK//. + rewrite ge0_fin_numE//. + apply: (@le_lt_trans _ _ 1%E); last exact: ltey. + exact: probability_le1. + rewrite normal_shift0/=. + over. +apply: measurableT_comp; last by []. +apply: measurableT_comp; first exact: EFin_measurable. +rewrite /=. +under [X in measurable_fun _ X]eq_fun. + move=> x. + rewrite /normal_prob. +(pushforward_shift_itv (normal_kernel 0) a b x). +apply: continuous_measurable_fun. +*) + +(* outline of proof: + 1. It is enough to prove that `(fun x => normal_prob x s Ys)` is continuous for + all measurable set `Ys`. + 2. Continuity is obtained by continuity under integral from continuity of + `normal_pdf`. + 3. Fix a point `a` in `R` and `e` with `0 < e`. Then take the function + `g : R -> R` as that `g x` is the maximum value of + `normal_pdf a s x` at a point within `e` of `x`. + Then `g x` is equal to `normal_pdf a s 0` if `x` in `ball a e`, + `normal_pdf a s (x - e)` for x > a + e, + and `normal_pdf a s (x + e)` for x < a - e. + 4. Integrability of `g` is checked by calculating integration. + By integration by substitution, the integral of `g` on ]-oo, a - e] + is equal to the integral of `normal_pdf a s` on `]-oo, a], + and it on `[a + e, +oo[ similarly. + So the integral of `g` on ]-oo, +oo[ is the integral of `f` on ]-oo, +oo[ + added by the integral of `normal_pdf a s x` on ]a - e, a + e[ + *) + +Let normal_pdf0 m s x : R := normal_peak s * normal_fun m s x. + +Let normal_pdf0_ge0 m x : 0 <= normal_pdf0 m s x. +Proof. by rewrite mulr_ge0 ?normal_peak_ge0 ?expR_ge0. Qed. + +Let continuous_normal_pdf0 m : continuous (normal_pdf0 m s). +Proof. +move=> x; apply: cvgM; first exact: cvg_cst. +apply: (@cvg_comp _ R^o _ _ _ _ + (nbhs (- (x - m) ^+ 2 / (s ^+ 2 *+ 2)))); last exact: continuous_expR. +apply: cvgM; last exact: cvg_cst; apply: (@cvgN _ R^o). +apply: (@cvg_comp _ _ _ _ (@GRing.exp R^~ 2) _ (nbhs (x - m))). + apply: (@cvgB _ R^o) => //; exact: cvg_cst. +exact: sqr_continuous. +Qed. + +Let normal_pdf0_ub m x : normal_pdf0 m s x <= normal_peak s. +Proof. +rewrite /normal_pdf0 ler_piMr ?normal_peak_ge0//. +rewrite -[leRHS]expR0 ler_expR mulNr oppr_le0 mulr_ge0// ?sqr_ge0//. +by rewrite invr_ge0 mulrn_wge0// sqr_ge0. +Qed. + +Let g' a e : R -> R := fun x => if x \in (ball a e : set R^o) then + normal_peak s else normal_pdf0 e s `|x - a|. + +Let ballFE_le (a e x : R) : x <= (a - e)%R -> + (x \in (ball a e : set R^o)) = false. +Proof. +move=> xae. +apply: memNset. +rewrite ball_itv/= in_itv/=; apply/negP/andP/not_andP; left. +by apply/negP; rewrite -leNgt. +Qed. + +Let ballFE_ge (a e x : R) : a + e <= x -> + (x \in (ball a e : set R^o)) = false. +Proof. +move=> xae. +apply: memNset. +rewrite ball_itv/= in_itv/=; apply/negP/andP/not_andP; right. +by apply/negP; rewrite -leNgt. +Qed. + +Let g'a0 (a : R) : g' a 0 = normal_pdf0 a s. +Proof. +apply/funext => x; rewrite /g'. +have /orP [x0|x0] := le_total x a. + rewrite ballFE_le; last by rewrite subr0. + by rewrite /normal_pdf0 /normal_fun subr0 real_normK// num_real. +rewrite ballFE_ge; last by rewrite addr0. +by rewrite /normal_pdf0 /normal_fun subr0 real_normK// num_real. +Qed. + +Let mg' a e : measurable_fun setT (g' a e). +Proof. +apply: measurable_fun_if => //. + apply: (measurable_fun_bool true) => /=. + rewrite setTI preimage_mem_true. + exact: measurable_ball. +apply: measurable_funTS => /=; apply: measurable_funM => //. +apply: measurableT_comp => //; first exact: measurable_normal_fun. +by apply: measurableT_comp => //; exact: measurable_funD. +Qed. + +Let g'_ge0 a e x : 0 <= g' a e x. +Proof. +rewrite /g'; case: ifP => _; first by rewrite normal_peak_ge0. +exact: normal_pdf0_ge0. +Qed. + +Let continuous_g' (a e : R) : 0 <= e -> continuous (g' a e). +Proof. +move=> e0. +have aNe k : k < a - e -> (`|k - a| - e) ^+ 2 = (k - (a - e)) ^+ 2. + move=> kae; rewrite ler0_norm; first by rewrite -sqrrN !opprB addrCA. + by rewrite subr_le0 (le_trans (ltW kae))// lerBlDl lerDr. +have aDe k : a + e < k -> (`|k - a| - e) ^+ 2 = (k - (a + e)) ^+ 2. + move=> kae; rewrite opprD addrA. + by rewrite ger0_norm// subr_ge0 (le_trans _ (ltW kae))// lerDl. +apply: (@in1TT R). +rewrite -continuous_open_subspace; last exact: openT. +rewrite (_ : [set: R] = + `]-oo, (a - e)%R] `|` `[(a - e)%R, a + e] `|` `[a + e, +oo[); last first. + rewrite -setUitv1// -setUA setUAC setUA -itv_bndbnd_setU//; last first. + by rewrite bnd_simp lerD// ge0_cp. + rewrite -setUitv1// -setUA setUAC setUA -itv_bndbnd_setU//. + by rewrite set_itvE !setTU. +apply: withinU_continuous. +- rewrite -setUitv1// -setUA setUCA -itv_bndbnd_setU//; last first. + by rewrite bnd_simp lerD// ge0_cp. + by rewrite setUidr// sub1set inE/= in_itv/= lerD// ge0_cp. +- exact: interval_closed. +- apply: withinU_continuous; first exact: interval_closed. + + exact: interval_closed. + + apply/continuous_within_itvNycP; split. + * move=> x. + rewrite in_itv/= => xae. + apply/(@cvgrPdist_le _ R^o R _ _ (g' a e) (g' a e x)) => /= eps eps0. + near=> t. + have tae : t < a - e by near: t; exact: lt_nbhsl. + rewrite /g'. + rewrite !ballFE_le ?(@ltW _ _ _ (a - e))//. + rewrite /normal_pdf0 /normal_fun !aNe//. + rewrite -!/(normal_fun _ _ _) -!/(normal_pdf0 _ _ _). + move=> {tae}; near: t. + move: eps eps0. + apply/(@cvgrPdist_le R^o). + exact: continuous_normal_pdf0. + * apply/(@cvgrPdist_lt _ R^o) => eps eps0. + near=> t. + rewrite /g' !ballFE_le//. + rewrite -addrAC subrr sub0r normrN (ger0_norm e0)//. + rewrite /normal_pdf0 /normal_fun subrr -(subrr (a - e)) aNe//. + near: t; move: eps eps0. + apply/(@cvgrPdist_lt _ R^o). + apply: cvg_at_left_filter. + exact: continuous_normal_pdf0. + move: e0; rewrite le_eqVlt => /predU1P[<-|e0]. + rewrite g'a0. + apply: continuous_subspaceT. + exact: continuous_normal_pdf0. + apply/continuous_within_itvP; first by rewrite -(opprK e) ler_ltB// opprK gtrN. + split. + + move=> x xae. + rewrite /continuous_at. + rewrite /g' ifT; last by rewrite ball_itv inE/=. + apply/(@cvgrPdist_le _ R^o) => eps eps0. + near=> t. + rewrite ifT; first by rewrite subrr normr0 ltW. + rewrite ball_itv inE/= in_itv/=; apply/andP; split. + near: t. + apply: lt_nbhsr. + by move: xae; rewrite in_itv/= => /andP[]. + near: t. + apply: lt_nbhsl. + by move: xae; rewrite in_itv/= => /andP[]. + + apply/(@cvgrPdist_le _ R^o) => eps eps0. + near=> t. + rewrite /g' ballFE_le// ifT; last first. + rewrite ball_itv inE/= in_itv/=; apply/andP => []; split => //. + near: t. + apply: nbhs_right_lt. + by rewrite -(opprK e) ler_ltB// opprK gtrN. + rewrite /normal_pdf0. + rewrite addrAC subrr sub0r normrN /normal_fun (gtr0_norm e0) subrr. + by rewrite expr0n/= oppr0 mul0r expR0 mulr1 subrr normr0 ltW. + + apply/(@cvgrPdist_le _ R^o) => eps eps0. + near=> t. + rewrite /g' ballFE_ge// ifT; last first. + rewrite ball_itv inE/= in_itv/=; apply/andP => []; split => //. + near: t. + apply: nbhs_left_gt. + by rewrite -(opprK e) ler_ltB// opprK gtrN. + rewrite /normal_pdf0. + rewrite addrAC subrr add0r /normal_fun (gtr0_norm e0) subrr expr0n oppr0. + by rewrite mul0r expR0 mulr1 subrr normr0 ltW. +- apply/continuous_within_itvcyP; split. + + move=> x. + rewrite in_itv/= andbT => aex. + apply/(@cvgrPdist_le _ R^o) => /= eps eps0. + near=> t. + have tae : a + e < t by near: t; exact: lt_nbhsr. + rewrite /g'. + rewrite !ballFE_ge ?(@ltW _ _ (a + e)%E)//. + rewrite /normal_pdf0 /normal_fun !aDe// ?(@ltW _ _ (a + e)). + move=> {tae}; near: t. + move: eps eps0. + apply/(@cvgrPdist_le _ R^o). + exact: continuous_normal_pdf0. + + apply/(@cvgrPdist_le _ R^o) => eps eps0. + near=> t. + rewrite /g' !ballFE_ge//. + rewrite addrAC subrr add0r (ger0_norm e0)//. + rewrite /normal_pdf0 /normal_fun subrr -(subrr (a + e)). + rewrite aDe//. + near: t. + move: eps eps0. + apply/cvgrPdist_le. + apply: cvg_at_right_filter. + exact: continuous_normal_pdf0. +Unshelve. all: end_near. Qed. + +Let gE_Ny a e : 0 <= e -> + (\int[mu]_(x in `]-oo, (a - e)%R]) `|g' a e x|%:E = + \int[mu]_(x in `]-oo, a]) `|normal_pdf a s x|%:E)%E. +Proof. +move=> e0. +rewrite ge0_integration_by_substitution_shift_itvNy => /=; first last. +- by move=> ? _; exact: normr_ge0. +- apply/continuous_subspaceT => x. + apply: continuous_comp; first exact: continuous_g'. + exact: (@norm_continuous _ R^o) . +under eq_integral. + move=> x. + rewrite inE/= in_itv/= => xae. + rewrite /g' ballFE_le//; last exact: lerB. + rewrite -(normrN (x - e - a)) !opprB addrA. + have /normr_idP -> : 0 <= a + e - x by rewrite subr_ge0 ler_wpDr. + rewrite /normal_pdf0 /normal_fun. + rewrite -(addrAC _ (- x)) addrK. + rewrite -(sqrrN (a - x)) opprB. + over. +by apply: eq_integral => /= x xay; rewrite /normal_pdf (negbTE s0). +Qed. + +Let gE_y a e : 0 <= e -> + (\int[mu]_(x in `[a + e, +oo[) `|g' a e x|%:E = + \int[mu]_(x in `[a, +oo[) `|normal_pdf a s x|%:E)%E. +Proof. +move=> e0. +rewrite ge0_integration_by_substitution_shift_itvy => /=; first last. +- by move=> ? _; exact: normr_ge0. +- apply/continuous_subspaceT => x. + apply: continuous_comp; first exact: continuous_g'. + exact: (@norm_continuous _ R^o). +under eq_integral. + move=> x. + rewrite inE/= in_itv/= andbT => aex. + rewrite /g' ballFE_ge//; last exact: lerD. + have /normr_idP -> : 0 <= x + e - a by rewrite subr_ge0 ler_wpDr. + rewrite /normal_pdf0 /normal_fun -(addrAC _ (- a)) addrK. + over. +rewrite /=. +by apply: eq_integral => /= x xay; rewrite /normal_pdf (negbTE s0). +Qed. + +Lemma normal_prob_continuous (V : set R) : measurable V -> + continuous (fun m => fine (normal_prob m s V)). +Proof. +move=> mV a. +near (0 : R)^'+ => e. +set g := g' a e. +have mg := mg' a e. +apply: (@continuity_under_integral _ _ _ mu _ _ _ _ (a - e) (a + e) _ _ _ g) => //=. +- rewrite in_itv/=. + by rewrite ltrDl gtrBl andbb. +- move=> x _. + by apply: (integrableS measurableT) => //=; exact: integrable_normal_pdf. +- apply/aeW => y _ x. + under [X in _ _ X]eq_fun. + move=> x0. + rewrite normal_pdfE// /normal_pdf0 /normal_fun -(sqrrN (y - _)) opprB. + over. + exact: continuous_normal_pdf0. +- apply: (integrableS measurableT) => //=. + apply/integrableP; split; first exact/measurable_EFinP. + rewrite -(setUv (ball a e)) ge0_integral_setU//=; last 4 first. + exact: measurable_ball. + by apply: measurableC; exact: measurable_ball. + rewrite setUv. + by apply/measurable_EFinP; exact: measurableT_comp. + exact/disj_setPCl. + apply: lte_add_pinfty. + under eq_integral. + move=> x xae. + rewrite /g /g' xae. + over. + rewrite integral_cst/=. + apply: lte_mul_pinfty => //. + rewrite ball_itv lebesgue_measure_itv/= ifT -?EFinD ?ltry// lte_fin. + by rewrite ltrBlDr -addrA -ltrBlDl subrr -mulr2n mulrn_wgt0. + exact: measurable_ball. + rewrite [ltLHS](_ : _ = \int[mu]_x `|normal_pdf a s x|%:E)%E; last first. + rewrite ball_itv setCitv ge0_integral_setU//=; first last. + apply/disj_setPRL. + rewrite setCitvl. + apply: subset_itvr; rewrite bnd_simp. + by rewrite -{2}(opprK e) ler_ltB// gtrN. + apply: measurable_funTS; apply/measurable_EFinP. + exact: measurableT_comp. + rewrite gE_Ny// gE_y// -integral_itv_obnd_cbnd; last first. + apply: measurableT_comp => //; apply: measurable_funTS. + exact: measurable_normal_pdf. + rewrite -ge0_integral_setU/= ?measurable_itv//; first last. + by apply/disj_setPRL; rewrite setCitvl. + rewrite -setCitvl setUv. + apply/measurable_EFinP; apply: measurableT_comp => //. + exact: measurable_normal_pdf. + by rewrite -setCitvl setUv. + under eq_integral do rewrite -abse_EFin. + apply/abse_integralP => //=. + by apply/measurable_EFinP; exact: measurable_normal_pdf. + by rewrite integral_normal_pdf ltry. +move=> x; rewrite in_itv/= => /andP[aex xae]. +apply: aeW => /= y Vy. +rewrite ger0_norm; last exact: normal_pdf_ge0. +rewrite normal_pdfE// /g /g'. +case: ifPn => [_|]; first exact: normal_pdf0_ub. +rewrite notin_setE/= ball_itv/= in_itv/= => aey. +rewrite /normal_pdf0 ler_pM//. +rewrite ler_expR !mulNr lerN2 ler_pM //. + exact: sqr_ge0. + by rewrite invr_ge0 mulrn_wge0// sqr_ge0. +move: aey; move/negP/nandP; rewrite -!leNgt => -[yae|aey]. + rewrite -normrN opprB ger0_norm; last first. + by rewrite subr_ge0 (le_trans yae)// gerBl. + rewrite -[leRHS]sqrrN opprB ler_sqr ?nnegrE; first last. + rewrite subr_ge0 ltW// (le_lt_trans yae)//. + by rewrite addrAC subr_ge0. + by rewrite addrAC lerD2r ltW. +rewrite ger0_norm; last first. + by rewrite subr_ge0 (le_trans _ aey)// lerDl. +rewrite ler_sqr ?nnegrE; last 2 first. + by rewrite -addrA -opprD subr_ge0. + by rewrite subr_ge0 (le_trans _ aey)// ltW. +by rewrite -addrA -opprD lerD2l lerN2 ltW. +Unshelve. end_near. Qed. + +Lemma measurable_normal_prob2 : + measurable_fun setT (normal_prob2 : R -> pprobability _ _). +Proof. +apply: (@measurability _ _ _ _ _ _ + (@pset _ _ _ : set (set (pprobability _ R)))) => //. +move=>_ -[_ [r r01] [Ys mYs <-]] <-. +apply: emeasurable_fun_infty_o => //=. +under [X in _ _ X]eq_fun. + move=> x. + rewrite -(@fineK _ (normal_prob x s Ys)); last first. + rewrite ge0_fin_numE => //. + apply: (@le_lt_trans _ _ (normal_prob x s setT)). + by rewrite le_measure ?inE. + exact: (le_lt_trans (probability_le1 _ _) (ltey _)). + over. +apply/measurable_EFinP; apply: continuous_measurable_fun. +exact: normal_prob_continuous. +Qed. + +End normal_kernel. + +(* lemmas about the function x -> (1 - x)^n *) +Section about_onemXn. +(* TODO: move? *) + +Lemma continuous_comp_cvg {R : numFieldType} (U V : pseudoMetricNormedZmodType R) + (f : U -> V) (g : R -> U) (r : R) (l : V) : continuous f -> + (f \o g) x @[x --> r] --> l -> f x @[x --> g r] --> l. +Proof. +move=> cf fgl; apply/(@cvgrPdist_le _ V) => /= e e0. +have e20 : 0 < e / 2 by rewrite divr_gt0. +move/(@cvgrPdist_le _ V) : fgl => /(_ _ e20) fgl. +have /(@cvgrPdist_le _ V) /(_ _ e20) fgf := cf (g r). +rewrite !near_simpl/=; near=> t. +rewrite -(@subrK _ (f (g r)) l) -(addrA (_ + _)) (le_trans (ler_normD _ _))//. +rewrite (splitr e) lerD//; last by near: t. +by case: fgl => d /= d0; apply; rewrite /ball_/= subrr normr0. +Unshelve. all: by end_near. Qed. + +Lemma continuous_onemXn {R : realType} (n : nat) x : + {for x, continuous (fun y : R => `1-y ^+ n)}. +Proof. +apply: (@continuous_comp _ _ _ (@onem R) (fun x => x ^+ n)). + by apply: (@cvgB _ R^o); [exact: cvg_cst|exact: cvg_id]. +exact: exprn_continuous. +Qed. + +Lemma onemXn_derivable {R : realType} n (x : R) : + derivable (fun y : R^o => `1-y ^+ n : R^o)%R x 1. +Proof. +have := @derivableX R R^o (@onem R) n x 1%R. +rewrite fctE. +apply. +exact: derivableB. +Qed. + +Lemma derivable_oo_continuous_bnd_onemXnMr {R : realType} (n : nat) (r : R) : + derivable_oo_continuous_bnd (fun x => `1-x ^+ n * r : R^o) 0 1. +Proof. +split. +- by move=> x x01; apply: derivableM => //=; exact: onemXn_derivable. +- apply: cvgM; last exact: cvg_cst. + apply: cvg_at_right_filter. + apply: (@cvg_comp _ _ _ (fun x => `1-x) (fun x => x ^+ n)). + by apply: (@cvgB _ R^o); [exact: cvg_cst|exact: cvg_id]. + exact: exprn_continuous. +- apply: cvg_at_left_filter. + apply: cvgM; last exact: cvg_cst. + apply: (@cvg_comp _ _ _ (fun x => `1-x) (fun x => x ^+ n)). + by apply: (@cvgB _ R^o); [exact: cvg_cst|exact: cvg_id]. + exact: exprn_continuous. +Qed. + +Lemma derive_onemXn {R : realType} (n : nat) x : + (fun y => `1-y ^+ n : R^o)^`()%classic x = - n%:R * `1-x ^+ n.-1. +Proof. +rewrite (@derive1_comp _ (@onem R) (fun x => x ^+ n))//; last first. + exact: exprn_derivable. +rewrite derive1E exp_derive// derive1E deriveB// -derive1E. +by rewrite derive1_cst derive_id sub0r mulrN1 [in RHS]mulNr scaler1. +Qed. + +Lemma Rintegral_onemXn {R : realType} n : + \int[lebesgue_measure]_(x in `[0, 1]) (`1-x ^+ n) = n.+1%:R^-1 :> R. +Proof. +rewrite /Rintegral. +rewrite (@continuous_FTC2 _ _ (fun x => `1-x ^+ n.+1 / - n.+1%:R))//=. +- rewrite onem1 expr0n/= mul0r onem0 expr1n mul1r sub0r. + by rewrite -invrN -2!mulNrn opprK. +- apply: continuous_in_subspaceT. +- by move=> x x01; exact: continuous_onemXn. +- exact: derivable_oo_continuous_bnd_onemXnMr. +- move=> x x01. + rewrite derive1Mr//; last exact: onemXn_derivable. + by rewrite derive_onemXn mulrAC divff// mul1r. Qed. + +End about_onemXn. + +(* TODO: move to derive.v *) +Lemma derive1_onem {R : numFieldType} : + (fun x => 1 - x : R^o)^`()%classic = cst (-1). +Proof. +by apply/funext => x; rewrite derive1E deriveB// derive_id derive_cst sub0r. +Qed. + +(* TODO: move to ftc.v *) +Section integration_by_substitution_onem. +Context {R : realType}. +Let mu := (@lebesgue_measure R). +Local Open Scope ereal_scope. + +Lemma integration_by_substitution_onem (G : R -> R) (r : R) : + (0 < r <= 1)%R -> + {within `[0%R, r], continuous G} -> + (\int[mu]_(x in `[0%R, r]) (G x)%:E = + \int[mu]_(x in `[(1 - r)%R, 1%R]) (G (1 - x))%:E). +Proof. +move=> r01 cG. +have := @integration_by_substitution_decreasing R (fun x => 1 - x)%R G (1 - r) 1. +rewrite subKr subrr => -> //. +- by apply: eq_integral => x xr; rewrite !fctE derive1_onem opprK mulr1. +- by rewrite ltrBlDl ltrDr; case/andP : r01. +- by move=> x y _ _ xy; rewrite ler_ltB. +- by rewrite derive1_onem; move=> ? ?; exact: cvg_cst. +- by rewrite derive1_onem; exact: is_cvg_cst. +- by rewrite derive1_onem; exact: is_cvg_cst. +- split => /=. + + by move=> x xr1; exact: derivableB. + + apply: cvg_at_right_filter; rewrite subKr. + apply: (@continuous_comp_cvg _ R^o R^o _ (fun x => 1 - x)%R)=> //=. + by move=> x; apply: (@continuousB _ R^o) => //; exact: cvg_cst. + by under eq_fun do rewrite subKr; exact: cvg_id. + + by apply: cvg_at_left_filter; apply: (@cvgB _ R^o) => //; exact: cvg_cst. +Qed. + +Lemma Rintegration_by_substitution_onem (G : R -> R) (r : R) : + (0 < r <= 1)%R -> + {within `[0%R, r], continuous G} -> + (\int[mu]_(x in `[0%R, r]) (G x) = + \int[mu]_(x in `[(1 - r)%R, 1%R]) (G (1 - x)))%R. +Proof. +by move=> r01 cG; rewrite [in LHS]/Rintegral integration_by_substitution_onem. +Qed. + +End integration_by_substitution_onem. + +(**md about the function $x \mapsto x^a * (1 - x)^b$ *) +Section XMonemX. +Context {R : numDomainType}. +Implicit Type x : R. + +Definition XMonemX a b := fun x => x ^+ a * `1-x ^+ b. + +Lemma XMonemX_ge0 a b x : x \in `[0, 1] -> 0 <= XMonemX a b x. +Proof. +by rewrite in_itv=> /andP[? ?]; rewrite mulr_ge0 ?exprn_ge0 ?subr_ge0. +Qed. + +Lemma XMonemX_le1 a b x : x \in `[0, 1] -> XMonemX a b x <= 1. +Proof. +rewrite in_itv/= => /andP[t0 t1]. +by rewrite mulr_ile1// ?(exprn_ge0,onem_ge0,exprn_ile1,onem_le1). +Qed. + +Lemma XMonemX0 n x : XMonemX 0 n x = `1-x ^+ n. +Proof. by rewrite /XMonemX/= expr0 mul1r. Qed. + +Lemma XMonemX0' n x : XMonemX n 0 x = x ^+ n. +Proof. by rewrite /XMonemX/= expr0 mulr1. Qed. + +Lemma XMonemX00 x : XMonemX 0 0 x = 1. +Proof. by rewrite XMonemX0 expr0. Qed. + +Lemma XMonemXC a b x : XMonemX a b (1 - x) = XMonemX b a x. +Proof. by rewrite /XMonemX [in LHS]/onem opprB addrCA subrr addr0 mulrC. Qed. + +Lemma XMonemX_XMonemX a b a' b' x : + XMonemX a' b' x * XMonemX a b x = XMonemX (a + a') (b + b') x. +Proof. by rewrite mulrCA -mulrA -exprD mulrA -exprD (addnC b'). Qed. + +End XMonemX. + +Section XMonemX_realType. +Context {R : realType}. +Local Notation XMonemX := (@XMonemX R). + +Lemma continuous_XMonemX a b : continuous (XMonemX a b). +Proof. +by move=> x; apply: cvgM; [exact: exprn_continuous|exact: continuous_onemXn]. +Qed. + +Lemma within_continuous_XMonemX A a b : {within A, continuous (XMonemX a b)}. +Proof. +by apply: continuous_in_subspaceT => x _; exact: continuous_XMonemX. +Qed. + +Lemma measurable_XMonemX A a b : measurable_fun A (XMonemX a b). +Proof. +apply/measurable_funM => //; apply/measurable_funX => //. +exact: measurable_funB. +Qed. + +Lemma bounded_XMonemX a b : + [bounded XMonemX a b x : R^o | x in `[0, 1]%classic]. +Proof. +exists 1; split; [by rewrite num_real|move=> x x1 /= y y01]. +rewrite ger0_norm//; last by rewrite XMonemX_ge0. +move: y01; rewrite in_itv/= => /andP[? ?]. +rewrite (le_trans _ (ltW x1))// mulr_ile1 ?exprn_ge0//. +- by rewrite subr_ge0. +- by rewrite exprn_ile1. +- by rewrite exprn_ile1 ?subr_ge0// lerBlDl addrC -lerBlDl subrr. +Qed. + +Local Notation mu := lebesgue_measure. + +Lemma integrable_XMonemX a b : mu.-integrable `[0, 1] (EFin \o XMonemX a b). +Proof. +apply: continuous_compact_integrable => //; first exact: segment_compact. +by apply: continuous_in_subspaceT => x _; exact: continuous_XMonemX. +Qed. + +Lemma integrable_XMonemX01 a b : + mu.-integrable [set: R] (EFin \o XMonemX a.-1 b.-1 \_`[0,1]). +Proof. +rewrite -restrict_EFin; apply/integrable_restrict => //=. +by rewrite setTI; exact: integrable_XMonemX. +Qed. + +Lemma integral_XMonemX01 U a b : + (\int[mu]_(x in U) (XMonemX a b \_ `[0, 1] x)%:E = + \int[mu]_(x in U `&` `[0%R, 1%R]) (XMonemX a b x)%:E)%E. +Proof. +rewrite [RHS]integral_mkcondr /=; apply: eq_integral => x xU /=. +by rewrite restrict_EFin. +Qed. + +End XMonemX_realType. + +Section beta_fun. +Context {R : realType}. +Notation mu := (@lebesgue_measure _). +Local Open Scope ring_scope. +Local Notation XMonemX := (@XMonemX R). + +Definition beta_fun a b : R := \int[mu]_x (XMonemX a.-1 b.-1 \_`[0,1]) x. + +Lemma EFin_beta_fun a b : + ((beta_fun a b)%:E = \int[mu]_x (XMonemX a.-1 b.-1 \_`[0,1] x)%:E)%E. +Proof. +rewrite fineK//; apply: integral_fune_fin_num => //=. +under eq_fun. + move=> x. + rewrite /= -/((EFin \o ((XMonemX a.-1 b.-1) \_ _)) x) -restrict_EFin. + over. +by apply/integrable_restrict => //=; rewrite setTI; exact: integrable_XMonemX. +Qed. + +Lemma beta_fun_sym a b : beta_fun a b = beta_fun b a. +Proof. +rewrite -[LHS]Rintegral_mkcond Rintegration_by_substitution_onem//=. +- rewrite subrr -[RHS]Rintegral_mkcond; apply: eq_Rintegral => x x01. + by rewrite XMonemXC. +- by rewrite ltr01 lexx. +- exact: within_continuous_XMonemX. +Qed. + +Lemma beta_fun0 b : (0 < b)%N -> beta_fun 0 b = b%:R^-1. +Proof. +move=> b0; rewrite -[LHS]Rintegral_mkcond. +under eq_Rintegral do rewrite XMonemX0. +by rewrite Rintegral_onemXn// prednK. +Qed. + +Lemma beta_fun00 : beta_fun 0 0 = 1%R. +Proof. +rewrite -[LHS]Rintegral_mkcond. +under eq_Rintegral do rewrite XMonemX00. +rewrite Rintegral_cst//= mul1r lebesgue_measure_itv/= lte_fin ltr01. +by rewrite oppr0 adde0. +Qed. + +Lemma beta_fun1S b : beta_fun 1 b.+1 = b.+1%:R^-1. +Proof. +rewrite /beta_fun -Rintegral_mkcond. +under eq_Rintegral do rewrite XMonemX0. +by rewrite Rintegral_onemXn. +Qed. + +Lemma beta_fun11 : beta_fun 1 1 = 1. +Proof. by rewrite (beta_fun1S O) invr1. Qed. + +Lemma beta_funSSS a b : + beta_fun a.+2 b.+1 = a.+1%:R / b.+1%:R * beta_fun a.+1 b.+2. +Proof. +rewrite -[LHS]Rintegral_mkcond. +rewrite (@Rintegration_by_parts _ _ (fun x => `1-x ^+ b.+1 / - b.+1%:R) + (fun x => a.+1%:R * x ^+ a)); last 7 first. + exact: ltr01. + apply/continuous_subspaceT => x. + by apply: cvgM; [exact: cvg_cst|exact: exprn_continuous]. + split. + by move=> x x01; exact: exprn_derivable. + by apply: cvg_at_right_filter; exact: exprn_continuous. + by apply: cvg_at_left_filter; exact: exprn_continuous. + by move=> x x01; rewrite derive1E exp_derive scaler1. + by apply/continuous_subspaceT => x x01; exact: continuous_onemXn. + exact: derivable_oo_continuous_bnd_onemXnMr. + move=> x x01; rewrite derive1Mr; last exact: onemXn_derivable. + by rewrite derive_onemXn mulrAC divff// mul1r. +rewrite {1}/onem !(expr1n,mul1r,expr0n,subr0,subrr,mul0r,oppr0,sub0r)/=. +transitivity (a.+1%:R / b.+1%:R * \int[mu]_(x in `[0, 1]) XMonemX a b.+1 x :> R). + under [in LHS]eq_Rintegral. + move=> x x01. + rewrite mulrA mulrC mulrA (mulrA _ a.+1%:R) -(mulrA (_ * _)%R). + over. + rewrite /=. + rewrite RintegralZl//=; last exact: integrable_XMonemX. + by rewrite -mulNrn -2!mulNr -invrN -mulNrn opprK (mulrC _ a.+1%:R). +by rewrite Rintegral_mkcond. +Qed. + +Lemma beta_funSS a b : beta_fun a.+1 b.+1 = + a`!%:R / (\prod_(b.+1 <= i < (a + b).+1) i)%:R * beta_fun 1 (a + b).+1. +Proof. +elim: a b => [b|a ih b]. + by rewrite fact0 mul1r add0n /index_iota subnn big_nil invr1 mul1r. +rewrite beta_funSSS [in LHS]ih !mulrA; congr *%R; last by rewrite addSnnS. +rewrite -mulrA mulrCA 2!mulrA. +rewrite -natrM (mulnC a`!) -factS -mulrA -invfM; congr (_ / _). +rewrite big_add1 [in RHS]big_nat_recl/=; last by rewrite addSn ltnS leq_addl. +by rewrite -natrM addSnnS. +Qed. + +Lemma beta_fun_fact a b : + beta_fun a.+1 b.+1 = (a`! * b`!)%:R / (a + b).+1`!%:R. +Proof. +rewrite beta_funSS beta_fun1S natrM -!mulrA; congr *%R. +(* (b+1 b+2 ... b+1 b+a)^-1 / (a+b+1) = b! / (a+b+1)! *) +rewrite factS [in RHS]mulnC natrM invfM mulrA; congr (_ / _). +rewrite -(@invrK _ b`!%:R) -invfM; congr (_^-1). +apply: (@mulfI _ b`!%:R). + by rewrite gt_eqF// ltr0n fact_gt0. +rewrite mulrA divff// ?gt_eqF// ?ltr0n ?fact_gt0 ?mul1r//. +rewrite [in RHS]fact_prod -natrM; congr (_%:R). +rewrite fact_prod -big_cat/= /index_iota subn1 -iotaD. +by rewrite subSS addnK subn1 addnC. +Qed. + +Lemma beta_funE a b : beta_fun a b = + if (a == 0)%N && (0 < b)%N then + b%:R^-1 + else if (b == 0)%N && (0 < a)%N then + a%:R^-1 + else + a.-1`!%:R * b.-1`!%:R / (a + b).-1`!%:R. +Proof. +case: a => [|a]. + rewrite eqxx/=; case: ifPn => [|]. + by case: b => [|b _] //; rewrite beta_fun0. + rewrite -leqNgt leqn0 => /eqP ->. + by rewrite beta_fun00 eqxx ltnn/= fact0 mul1r divr1. +case: b => [|b]. + by rewrite beta_fun_sym beta_fun0// fact0 addn0/= mulr1 divff. +by rewrite beta_fun_fact/= natrM// addnS. +Qed. + +Lemma beta_fun_gt0 a b : 0 < beta_fun a b. +Proof. +rewrite beta_funE. +case: ifPn => [/andP[_ b0]|]; first by rewrite invr_gt0 ltr0n. +rewrite negb_and => /orP[a0|]. + case: ifPn => [/andP[_]|]; first by rewrite invr_gt0// ltr0n. + rewrite negb_and => /orP[b0|]. + by rewrite divr_gt0// ?mulr_gt0 ?ltr0n ?fact_gt0. + by rewrite -leqNgt leqn0 (negbTE a0). +rewrite -leqNgt leqn0 => /eqP ->; rewrite eqxx/=. +case: ifPn; first by rewrite invr_gt0 ltr0n. +by rewrite -leqNgt leqn0 => /eqP ->; rewrite fact0 mul1r divr1. +Qed. + +Lemma beta_fun_ge0 a b : 0 <= beta_fun a b. +Proof. exact/ltW/beta_fun_gt0. Qed. + +End beta_fun. + +Section beta_pdf. +Local Open Scope ring_scope. +Context {R : realType}. +Variables a b : nat. + +Local Notation XMonemX := (@XMonemX R). + +Definition beta_pdf t : R := XMonemX a.-1 b.-1 \_`[0,1] t / beta_fun a b. + +Lemma measurable_beta_pdf : measurable_fun setT beta_pdf. +Proof. +apply: measurable_funM => //; apply/measurable_restrict => //. +by rewrite setTI; exact: measurable_XMonemX. +Qed. + +Lemma beta_pdf_ge0 t : 0 <= beta_pdf t. +Proof. +rewrite /beta_pdf divr_ge0 ?beta_fun_ge0//. +rewrite patchE; case: ifPn => //=. +by rewrite inE/= => ?; exact: XMonemX_ge0. +Qed. + +Lemma beta_pdf_le_beta_funV x : beta_pdf x <= (beta_fun a b)^-1. +Proof. +rewrite /beta_pdf ler_pdivrMr ?beta_fun_gt0// mulVf ?gt_eqF ?beta_fun_gt0//. +by rewrite patchE; case: ifPn => //; rewrite inE/= => ?; exact: XMonemX_le1. +Qed. + +Local Notation mu := lebesgue_measure. + +Lemma integrable_beta_pdf : mu.-integrable [set: _] (EFin \o beta_pdf). +Proof. +apply/integrableP; split. + by apply/measurable_EFinP; exact: measurable_beta_pdf. +under eq_integral. + move=> /= x _. + rewrite ger0_norm//; last by rewrite beta_pdf_ge0. + over. +rewrite /=. +apply: (@le_lt_trans _ _ (\int[mu]_(x in `[0%R, 1%R]) (beta_fun a b)^-1%:E)%E). + rewrite [in leRHS]integral_mkcond/=. + apply: ge0_le_integral => //=. + - by move=> x _; rewrite lee_fin beta_pdf_ge0. + - by apply/measurable_funTS/measurable_EFinP => /=; exact: measurable_beta_pdf. + - move=> x _; rewrite patchE; case: ifPn => // _. + by rewrite lee_fin invr_ge0// beta_fun_ge0. + - exact/measurable_restrict. + - move=> x _. + rewrite patchE; case: ifPn => x01. + by rewrite lee_fin beta_pdf_le_beta_funV. + by rewrite /beta_pdf patchE (negbTE x01) mul0r. +rewrite integral_cst//= lebesgue_measure_itv//=. +by rewrite lte01 oppr0 adde0 mule1 ltry. +Qed. + +Lemma bounded_beta_pdf_01 : + [bounded beta_pdf x : R^o | x in `[0%R, 1%R]%classic : set R]. +Proof. +exists (beta_fun a b)^-1; split; first by rewrite num_real. +move=> // y y1. +near=> M => /=. +rewrite (le_trans _ (ltW y1))//. +near: M. +move=> M /=. +rewrite in_itv/= => /andP[M0 M1]. +rewrite ler_norml; apply/andP; split. + rewrite lerNl (@le_trans _ _ 0%R)// ?invr_ge0 ?beta_fun_ge0//. + by rewrite lerNl oppr0 beta_pdf_ge0. +rewrite /beta_pdf ler_pdivrMr ?beta_fun_gt0//. +rewrite mulVf ?gt_eqF ?beta_fun_gt0//. +by rewrite patchE; case: ifPn => //; rewrite inE => ?; exact: XMonemX_le1. +Unshelve. all: by end_near. Qed. + +End beta_pdf. + +(* TODO: move *) +Lemma invr_nonneg_proof (R : numDomainType) (p : {nonneg R}) : + (0 <= (p%:num)^-1)%R. +Proof. by rewrite invr_ge0. Qed. + +Definition invr_nonneg (R : numDomainType) (p : {nonneg R}) := + NngNum (invr_nonneg_proof p). +(* /TODO: move *) + +Section beta. +Local Open Scope ring_scope. +Context {R : realType}. +Variables a b : nat. + +Local Notation mu := (@lebesgue_measure R). +Local Notation XMonemX := (@XMonemX R). + +Let beta_num (U : set _) : \bar R := + \int[mu]_(x in U) (XMonemX a.-1 b.-1 \_`[0,1] x)%:E. + +Let beta_numT : beta_num [set: _] = (beta_fun a b)%:E. +Proof. by rewrite /beta_num/= EFin_beta_fun. Qed. + +Let beta_num_lty U : measurable U -> (beta_num U < +oo)%E. +Proof. +move=> mU. +apply: (@le_lt_trans _ _ (\int[mu]_(x in U `&` `[0%R, 1%R]) 1)%E); last first. + rewrite integral_cst//= ?mul1e//. + rewrite (le_lt_trans (measureIr _ _ _))//= lebesgue_measure_itv//= lte01//. + by rewrite EFinN sube0 ltry. + exact: measurableI. +rewrite /beta_num integral_XMonemX01 ge0_le_integral//=. +- exact: measurableI. +- by move=> x [_ ?]; rewrite lee_fin XMonemX_ge0. +- by apply/measurable_funTS/measurableT_comp => //; exact: measurable_XMonemX. +- by move=> x [_ ?]; rewrite lee_fin XMonemX_le1. +Qed. + +Let beta_num0 : beta_num set0 = 0%:E. +Proof. by rewrite /beta_num integral_set0. Qed. + +Let beta_num_ge0 U : (0 <= beta_num U)%E. +Proof. +rewrite /beta_num integral_ge0//= => x Ux; rewrite lee_fin. +by rewrite patchE; case: ifPn => //; rewrite inE/= => x01; exact: XMonemX_ge0. +Qed. + +Let beta_num_sigma_additive : semi_sigma_additive beta_num. +Proof. +move=> /= F mF tF mUF; rewrite /beta_num; apply: cvg_toP. + apply: ereal_nondecreasing_is_cvgn => m n mn. + apply: lee_sum_nneg_natr => // k _ _; apply: integral_ge0 => /= x Fkx. + rewrite patchE; case: ifPn => //; rewrite inE/= => x01. + by rewrite lee_fin XMonemX_ge0. +rewrite ge0_integral_bigcup//=. +- apply/measurable_funTS/measurableT_comp => //=. + by apply/measurable_restrict => //=; rewrite setTI; exact: measurable_XMonemX. +- move=> x [? _ ?]; rewrite lee_fin. + by rewrite patchE; case: ifPn => //; rewrite inE/= => x0; exact: XMonemX_ge0. +Qed. + +HB.instance Definition _ := isMeasure.Build _ _ _ beta_num + beta_num0 beta_num_ge0 beta_num_sigma_additive. + +Definition beta_prob := + @mscale _ _ _ (invr_nonneg (NngNum (beta_fun_ge0 a b))) beta_num. + +HB.instance Definition _ := Measure.on beta_prob. + +Let beta_prob_setT : beta_prob setT = 1%:E. +Proof. +rewrite /beta_prob /= /mscale /= beta_numT. +by rewrite -EFinM mulVf// gt_eqF// beta_fun_gt0. +Qed. + +HB.instance Definition _ := + @Measure_isProbability.Build _ _ _ beta_prob beta_prob_setT. + +Lemma integral_beta_pdf U : measurable U -> + (\int[mu]_(x in U) (beta_pdf a b x)%:E = beta_prob U :> \bar R)%E. +Proof. +move=> mU. +rewrite /beta_pdf. +under eq_integral do rewrite EFinM/=. +rewrite ge0_integralZr//=. +- by rewrite /beta_prob/= /mscale/= muleC. +- apply/measurable_funTS/measurableT_comp => //. + by apply/measurable_restrict => //=; rewrite setTI; exact: measurable_XMonemX. +- move=> x Ux; rewrite patchE; case: ifPn => //; rewrite inE/= => x01. + by rewrite lee_fin XMonemX_ge0. +- by rewrite lee_fin invr_ge0// beta_fun_ge0. +Qed. + +Lemma beta_prob01 : beta_prob `[0, 1] = 1%:E. +Proof. +rewrite /beta_prob /= /mscale/= /beta_num. +rewrite (_ : integral _ _ _ = (beta_fun a b)%:E); last first. + rewrite fineK; last first. + by apply: integral_fune_fin_num => //; exact: integrable_XMonemX01. + rewrite [LHS]integral_mkcond/=. + apply: eq_integral => /= x _. + by rewrite !patchE; case: ifPn => // ->. +by rewrite -EFinM mulVf// gt_eqF// beta_fun_gt0. +Qed. + +Lemma beta_prob_fin_num U : measurable U -> beta_prob U \is a fin_num. +Proof. +move=> mU; rewrite ge0_fin_numE//. +rewrite /beta_prob/= /mscale/= /beta_num lte_mul_pinfty//. + by rewrite lee_fin// invr_ge0 beta_fun_ge0. +apply: (@le_lt_trans _ _ (beta_fun a b)%:E). + rewrite EFin_beta_fun; apply: ge0_subset_integral => //=. + apply/measurable_EFinP; apply/measurable_restrict => //=. + by rewrite setTI; exact: measurable_XMonemX. + move=> x _; rewrite patchE; case: ifPn => //; rewrite inE/= => x01. + by rewrite lee_fin XMonemX_ge0. +by rewrite ltry. +Qed. + +Lemma beta_prob_dom : beta_prob `<< mu. +Proof. +move=> A mA muA0; rewrite /beta_prob /mscale/=. +apply/eqP; rewrite mule_eq0 eqe invr_eq0 gt_eqF/= ?beta_fun_gt0//; apply/eqP. +rewrite /beta_num integral_XMonemX01. +apply/eqP; rewrite eq_le; apply/andP; split; last first. + by apply: integral_ge0 => x [_ ?]; rewrite lee_fin XMonemX_ge0. +apply: (@le_trans _ _ (\int[mu]_(x in A `&` `[0%R, 1%R]) 1)%E); last first. + rewrite integral_cst ?mul1e//=; last exact: measurableI. + by rewrite -[leRHS]muA0 measureIl. +apply: ge0_le_integral => //=; first exact: measurableI. +- by move=> x [_ x01]; rewrite lee_fin XMonemX_ge0. +- by apply/measurable_funTS/measurableT_comp => //; exact: measurable_XMonemX. +- by move=> x [_ ?]; rewrite lee_fin XMonemX_le1. +Qed. + +End beta. +Arguments beta_prob {R}. + +Lemma beta_prob_uniform {R : realType} : beta_prob 1 1 = uniform_prob (@ltr01 R). +Proof. +apply/funext => U. +rewrite /beta_prob /uniform_prob. +rewrite /mscale/= beta_fun11 invr1 !mul1e. +rewrite integral_XMonemX01 integral_uniform_pdf. +apply: eq_integral => /= x. +rewrite inE => -[Ux/=]; rewrite in_itv/= => x10. +rewrite /XMonemX !expr0 mul1r. +by rewrite /uniform_pdf x10 subr0 invr1. +Qed. + +Lemma integral_beta_prob_bernoulli_prob_lty {R : realType} a b (f : R -> R) U : + measurable_fun setT f -> + (forall x, x \in `[0%R, 1%R] -> 0 <= f x <= 1) -> + (\int[beta_prob a b]_x `|bernoulli_prob (f x) U| < +oo :> \bar R)%E. +Proof. +move=> mf /= f01. +apply: (@le_lt_trans _ _ (\int[beta_prob a b]_x cst 1 x))%E. + apply: ge0_le_integral => //=. + apply: measurableT_comp => //=. + by apply: (measurableT_comp (measurable_bernoulli_prob2 _)). + by move=> x _; rewrite gee0_abs// probability_le1. +by rewrite integral_cst//= mul1e -ge0_fin_numE// beta_prob_fin_num. +Qed. + +Lemma integral_beta_prob_bernoulli_prob_onemX_lty {R : realType} n a b U : + (\int[beta_prob a b]_x `|bernoulli_prob (`1-x ^+ n) U| < +oo :> \bar R)%E. +Proof. +apply: integral_beta_prob_bernoulli_prob_lty => //=. + by apply: measurable_funX => //; exact: measurable_funB. +move=> x; rewrite in_itv/= => /andP[x0 x1]. +rewrite exprn_ge0 ?subr_ge0//= exprn_ile1// ?subr_ge0//. +by rewrite lerBlDl -lerBlDr subrr. +Qed. + +Lemma integral_beta_prob_bernoulli_prob_onem_lty {R : realType} n a b U : + (\int[beta_prob a b]_x `|bernoulli_prob (1 - `1-x ^+ n) U| < +oo :> \bar R)%E. +Proof. +apply: integral_beta_prob_bernoulli_prob_lty => //=. + apply: measurable_funB => //. + by apply: measurable_funX => //; exact: measurable_funB. +move=> x; rewrite in_itv/= => /andP[x0 x1]. +rewrite -lerBlDr opprK add0r andbC lerBlDl -lerBlDr subrr. +rewrite exprn_ge0 ?subr_ge0//= exprn_ile1// ?subr_ge0//. +by rewrite lerBlDl -lerBlDr subrr. +Qed. + +Lemma beta_prob_integrable {R : realType} a b a' b' : + (beta_prob a b).-integrable `[0, 1] (fun x : R => (XMonemX a' b' x)%:E). +Proof. +apply/integrableP; split. + by apply/measurableT_comp => //; exact: measurable_XMonemX. +apply: (@le_lt_trans _ _ (\int[beta_prob a b]_(x in `[0%R, 1%R]) 1)%E). + apply: ge0_le_integral => //=. + by do 2 apply/measurableT_comp => //; exact: measurable_XMonemX. + move=> x; rewrite in_itv/= => /andP[x0 x1]. + rewrite lee_fin ger0_norm; last first. + by rewrite !mulr_ge0// exprn_ge0// onem_ge0. + by rewrite mulr_ile1// ?exprn_ge0 ?onem_ge0// exprn_ile1// ?onem_ge0// onem_le1. +rewrite integral_cst//= mul1e. +by rewrite -ge0_fin_numE// beta_prob_fin_num. +Qed. + +Lemma beta_prob_integrable_onem {R : realType} a b a' b' : + (beta_prob a b).-integrable `[0, 1] + (fun x : R => (`1-(XMonemX a' b' x))%:E). +Proof. +apply: (eq_integrable _ (cst 1 \- (fun x : g_sigma_algebraType (R.-ocitv.-measurable) => + (XMonemX a' b' x)%:E))%E) => //. +apply: (@integrableB _ (g_sigma_algebraType R.-ocitv.-measurable)) => //=. + apply/integrableP; split => //. + rewrite (eq_integral (fun x => (\1_setT x)%:E))/=; last first. + by move=> x _; rewrite /= indicT normr1. + rewrite integral_indic//= setTI /beta_prob /mscale/= lte_mul_pinfty//. + by rewrite lee_fin invr_ge0 beta_fun_ge0. + rewrite (_ : integral _ _ _ = \int[lebesgue_measure]_x + (((@XMonemX R a.-1 b.-1) \_ `[0, 1]) x)%:E)%E; last first. + rewrite integral_mkcond/=; apply: eq_integral => /= x _. + by rewrite !patchE; case: ifPn => // ->. + have /integrableP[_] := @integrable_XMonemX01 R a b. + under eq_integral. + move=> x _. + rewrite gee0_abs//; last first. + rewrite lee_fin patchE; case: ifPn => //; rewrite inE/= => x01. + by rewrite XMonemX_ge0. + over. + by []. +exact: beta_prob_integrable. +Qed. + +Lemma beta_prob_integrable_dirac {R : realType} a b a' b' (c : bool) U : + (beta_prob a b).-integrable `[0, 1] + (fun x : R => (XMonemX a' b' x)%:E * \d_c U)%E. +Proof. +apply: integrableMl => //=; last first. + exists 1; split => // x x1/= _ _; rewrite (le_trans _ (ltW x1))//. + by rewrite ger0_norm// indicE; case: (_ \in _). +exact: beta_prob_integrable. +Qed. + +Lemma beta_prob_integrable_onem_dirac {R : realType} a b a' b' (c : bool) U : + (beta_prob a b).-integrable `[0, 1] + (fun x : R => (`1-(XMonemX a' b' x))%:E * \d_c U)%E. +Proof. +apply: integrableMl => //=; last first. + exists 1; split => // x x1/= _ _; rewrite (le_trans _ (ltW x1))//. + by rewrite ger0_norm// indicE; case: (_ \in _). +exact: beta_prob_integrable_onem. +Qed. + +Section integral_beta_prob. +Context {R : realType}. +Local Notation mu := (@lebesgue_measure R). + +Lemma integral_beta_prob a b f U : measurable U -> measurable_fun U f -> + (\int[beta_prob a b]_(x in U) `|f x| < +oo)%E -> + (\int[beta_prob a b]_(x in U) f x = \int[mu]_(x in U) (f x * (beta_pdf a b x)%:E))%E. +Proof. +move=> mU mf finf. +rewrite -(Radon_Nikodym_change_of_variables (beta_prob_dom a b)) //=; last first. + by apply/integrableP; split. +apply: ae_eq_integral => //. +- apply: emeasurable_funM => //; apply: (measurable_int mu). + apply: (integrableS _ _ (@subsetT _ _)) => //=. + by apply: Radon_Nikodym_integrable; exact: beta_prob_dom. +- apply: emeasurable_funM => //=; apply/measurableT_comp => //=. + by apply/measurable_funTS; exact: measurable_beta_pdf. +- apply: ae_eqe_mul2l => /=. + rewrite Radon_NikodymE//=; first exact: beta_prob_dom. + move=> ?. + case: cid => /= h [h1 h2 h3]. +(* uniqueness of Radon-Nikodym derivertive up to equal on non null sets of mu *) + apply: integral_ae_eq => //. + + apply: integrableS h2 => //. (* integrableST? *) + apply/measurable_funTS/measurableT_comp => //. + exact: measurable_beta_pdf. + + by move=> E E01 mE; rewrite -h3//= integral_beta_pdf. +Qed. + +End integral_beta_prob. + +(* TODO: move *) +Lemma leq_prod2 (x y n m : nat) : (n <= x)%N -> (m <= y)%N -> + (\prod_(m <= i < y) i * \prod_(n <= i < x) i <= \prod_(n + m <= i < x + y) i)%N. +Proof. +move=> nx my; rewrite big_addn -addnBA//. +rewrite [in leqRHS]/index_iota -addnBAC// iotaD big_cat/=. +rewrite mulnC leq_mul//. + by apply: leq_prod; move=> i _; rewrite leq_addr. +rewrite subnKC//. +rewrite -[in leqLHS](add0n m) big_addn. +rewrite [in leqRHS](_ : y - m = ((y - m + x) - x))%N; last first. + by rewrite -addnBA// subnn addn0. +rewrite -[X in iota X _](add0n x) big_addn -addnBA// subnn addn0. +by apply: leq_prod => i _; rewrite leq_add2r leq_addr. +Qed. + +Lemma leq_fact2 (x y n m : nat) : (n <= x) %N -> (m <= y)%N -> + (x`! * y`! * ((n + m).+1)`! <= n`! * m`! * ((x + y).+1)`!)%N. +Proof. +move=> nx my. +rewrite (fact_split nx) -!mulnA leq_mul2l; apply/orP; right. +rewrite (fact_split my) mulnCA -!mulnA leq_mul2l; apply/orP; right. +rewrite [leqRHS](_ : _ = (n + m).+1`! * \prod_((n + m).+2 <= i < (x + y).+2) i)%N; last first. + by rewrite -fact_split// ltnS leq_add. +rewrite mulnA mulnC leq_mul2l; apply/orP; right. +do 2 rewrite -addSn -addnS. +exact: leq_prod2. +Qed. + +Lemma normr_onem {R : realType} (x : R) : (0 <= x <= 1 -> `| `1-x | <= 1)%R. +Proof. +move=> /andP[x0 x1]; rewrite ler_norml; apply/andP; split. + by rewrite lerBrDl lerBlDr (le_trans x1)// lerDl. +by rewrite lerBlDr lerDl. +Qed. +(* /TODO: move *) + +Section beta_prob_bernoulliE. +Context {R : realType}. +Local Notation mu := lebesgue_measure. +Local Open Scope ring_scope. + +Definition div_beta_fun a b c d : R := beta_fun (a + c) (b + d) / beta_fun a b. + +Lemma div_beta_fun_ge0 a b c d : 0 <= div_beta_fun a b c d. +Proof. by rewrite /div_beta_fun divr_ge0// beta_fun_ge0. Qed. + +Lemma div_beta_fun_le1 a b c d : (0 < a)%N -> (0 < b)%N -> + div_beta_fun a b c d <= 1. +Proof. +move=> a0 b0. +rewrite /div_beta_fun ler_pdivrMr// ?mul1r ?beta_fun_gt0//. +rewrite !beta_funE. +rewrite addn_eq0 (gtn_eqF a0)/=. +rewrite addn_eq0 (gtn_eqF b0)/=. +rewrite ler_pdivrMr ?ltr0n ?fact_gt0//. +rewrite mulrAC. +rewrite ler_pdivlMr ?ltr0n ?fact_gt0//. +rewrite -!natrM ler_nat. +move: a a0 => [//|a _]. +rewrite addSn. +move: b b0 => [//|b _]. +rewrite [(a + c).+1.-1]/=. +rewrite [a.+1.-1]/=. +rewrite [b.+1.-1]/=. +rewrite addnS. +rewrite [(_ + b).+1.-1]/=. +rewrite (addSn b d). +rewrite [(b + _).+1.-1]/=. +rewrite (addSn (a + c)). +rewrite [_.+1.-1]/=. +rewrite addSn addnS. +by rewrite leq_fact2// leq_addr. +Qed. + +Definition beta_prob_bernoulli_prob a b c d U : \bar R := + \int[beta_prob a b]_(y in `[0, 1]) + bernoulli_prob ((@XMonemX R c d \_`[0, 1])%R y) U. + +Lemma beta_prob_bernoulli_probE a b c d U : (a > 0)%N -> (b > 0)%N -> + beta_prob_bernoulli_prob a b c d U = bernoulli_prob (div_beta_fun a b c d) U. +Proof. +move=> a0 b0. +rewrite /beta_prob_bernoulli_prob. +under eq_integral => x. + rewrite inE/= in_itv/= => x01. + rewrite bernoulli_probE/=; last first. + rewrite patchE; case: ifPn => x01'. + by rewrite XMonemX_ge0//= XMonemX_le1. + by rewrite lexx ler01. + over. +rewrite /= [in RHS]bernoulli_probE/= ?div_beta_fun_ge0 ?div_beta_fun_le1//=. +under eq_integral => x x01. + rewrite patchE x01/=. + over. +rewrite /= integralD//=; last 2 first. + exact: beta_prob_integrable_dirac. + exact: beta_prob_integrable_onem_dirac. +congr (_ + _). + rewrite integralZr//=; last exact: beta_prob_integrable. + congr (_ * _)%E. + rewrite integral_beta_prob//; last 2 first. + by apply/measurableT_comp => //; exact: measurable_XMonemX. + by have /integrableP[_] := @beta_prob_integrable R a b c d. + rewrite /beta_pdf. + under eq_integral do rewrite EFinM -muleA muleC -muleA. + rewrite /=. + transitivity ((beta_fun a b)^-1%:E * \int[mu]_(x in `[0%R, 1%R]) + (@XMonemX R (a + c).-1 (b + d).-1 \_`[0,1] x)%:E)%E. + rewrite -integralZl//=; last first. + by apply: (integrableS measurableT) => //=; exact: integrable_XMonemX01. + apply: eq_integral => x x01. + (* TODO: lemma? property of XMonemX? *) + rewrite muleA muleC muleA -(EFinM (x ^+ c)) -/(XMonemX c d x) -EFinM mulrA. + rewrite !patchE x01 XMonemX_XMonemX// -EFinM mulrC. + by move: a a0 b b0 => [//|a] _ [|b]. + rewrite /div_beta_fun mulrC EFinM; congr (_ * _)%E. + rewrite /beta_fun integral_mkcond/= fineK; last first. + by apply: integral_fune_fin_num => //; exact: integrable_XMonemX01. + by apply: eq_integral => /= x _; rewrite !patchE; case: ifPn => // ->. +under eq_integral do rewrite muleC. +rewrite /= integralZl//=; last exact: beta_prob_integrable_onem. +rewrite muleC; congr (_ * _)%E. +rewrite integral_beta_prob//=; last 2 first. + apply/measurableT_comp => //=. + by apply/measurable_funB => //; exact: measurable_XMonemX. + by have /integrableP[] := @beta_prob_integrable_onem R a b c d. +rewrite /beta_pdf. +under eq_integral do rewrite EFinM muleA. +rewrite integralZr//=; last first. + apply: integrableMr => //=. + - by apply/measurable_funB => //=; exact: measurable_XMonemX. + - apply/ex_bound => //. + + apply: (@globally_properfilter _ _ 0%R) => //=. + by apply: inferP; rewrite in_itv/= lexx ler01. + + exists 1 => t. + rewrite /= in_itv/= => t01. + apply: normr_onem; apply/andP; split. + by rewrite mulr_ge0// exprn_ge0// ?onem_ge0//; case/andP: t01. + by rewrite mulr_ile1// ?exprn_ge0 ?exprn_ile1// ?onem_ge0 ?onem_le1//; + case/andP: t01. + - exact: integrableS (integrable_XMonemX01 _ _). +transitivity ((\int[mu]_x ((@XMonemX R a.-1 b.-1 \_`[0,1] x)%:E - + (@XMonemX R (a + c).-1 (b + d).-1 \_`[0,1] x)%:E)) * (beta_fun a b)^-1%:E)%E. + congr (_ * _)%E; rewrite integral_mkcond/=; apply: eq_integral => x _. + rewrite !patchE; case: ifPn => [->|]; last by rewrite EFinN subee. + rewrite /onem -EFinM mulrBl mul1r EFinB EFinN; congr (_ - _)%E. + rewrite XMonemX_XMonemX. + by move: a a0 b b0 => [|a]// _ [|b]. +rewrite integralB_EFin//=; last 2 first. + exact: integrableS (integrable_XMonemX01 _ _). + exact: integrableS (integrable_XMonemX01 _ _). +rewrite EFinB muleBl//; last by rewrite -!EFin_beta_fun. +by rewrite -!EFin_beta_fun -EFinM divff// gt_eqF// beta_fun_gt0. +Qed. + +End beta_prob_bernoulliE.