Re: Problem applying generics to my code. Is there a better solution?
Daniel,
Thanks to your suggestion, I have a working prototype, which I've
included in this post. One question: for each class to dispatch
properly, *every* class has to implement the visitor's visit() method
(called dispatch() in the code below). If a subclass relies on a
dispatch() inherited from its parent class, the generic posterior()
method in the abstract Distribution class is invoked instead, which
calls dispatch() again and leads to an infinite loop.
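The root cause is that Java picks an overload at compile time from the static types of the arguments, and inside an inherited dispatch() the static type of `this` is the class that declared that method, not the runtime subclass. The minimal sketch below shows that rule with hypothetical classes (Prior, Likelihood, KnownLikelihood, KnownChild, NewLikelihood are not from the posted code). Its catch-all overload returns a string instead of re-dispatching, so it does not loop.

```java
public class StaticOverloadDemo {

    // Hypothetical stand-in for Distribution: one catch-all overload
    // and one dedicated overload for a specific likelihood type.
    static class Prior {
        // In the real code this overload re-dispatches, which is where the loop starts.
        String combine(Likelihood l) { return "generic"; }
        String combine(KnownLikelihood l) { return "known"; }
    }

    static class Likelihood {
        // Inside this method "this" is statically a Likelihood,
        // so the compiler always binds to combine(Likelihood).
        String dispatch(Prior p) { return p.combine(this); }
    }

    static class KnownLikelihood extends Likelihood {
        // Re-declared so that "this" is statically a KnownLikelihood here,
        // which makes combine(KnownLikelihood) the most specific overload.
        @Override
        String dispatch(Prior p) { return p.combine(this); }
    }

    // Inherits KnownLikelihood.dispatch(); "this" is still statically
    // a KnownLikelihood inside that method, so the dedicated overload is used.
    static class KnownChild extends KnownLikelihood { }

    // Inherits Likelihood.dispatch(); the dedicated overload is never selected.
    static class NewLikelihood extends Likelihood { }

    public static void main(String[] args) {
        Prior prior = new Prior();
        System.out.println(new KnownLikelihood().dispatch(prior)); // known
        System.out.println(new KnownChild().dispatch(prior));      // known
        System.out.println(new NewLikelihood().dispatch(prior));   // generic
    }
}
```

So a subclass of a class that already declares dispatch() and has a matching posterior() overload (like KnownChild) does not need its own dispatch(); only a class whose inherited dispatch() sees `this` as a type with no matching overload falls through to the catch-all.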
It's just syntactic sugar, but is there a way around this problem?
The only solution I could come up with is for the abstract class to
implement a posterior() overload for every class that extends
Distribution, which is not feasible because new Distribution
subclasses will keep being added.
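One possible way around writing an overload (or a dispatch()) for every new subclass, offered only as a hedged alternative and not something proposed in this thread, is to trade compile-time overloads for a run-time lookup: each prior keeps a table keyed by likelihood class and walks up getSuperclass() until it finds a rule, so a new subclass inherits its parent's conjugate update automatically. The likelihood class names mirror the post, but the Prior class, its register() method and the rule table are hypothetical, and the sketch drops the type parameter T and returns strings for brevity.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class RegistryDispatchSketch {

    // Hypothetical likelihood hierarchy mirroring the names in the post.
    static class Likelihood { }
    static class NormalUnknownMean extends Likelihood { }
    static class NormalUnknownVariance extends Likelihood { }
    // A later subclass that declares no dispatch code of its own.
    static class NormalUnknownMeanVariant extends NormalUnknownMean { }

    // A prior holds its conjugate updates in a table keyed by likelihood class.
    static class Prior {
        private final String name;
        private final Map<Class<?>, Function<Likelihood, String>> rules = new HashMap<>();

        Prior(String name) {
            this.name = name;
        }

        // Hypothetical registration method; returns this to allow chaining.
        Prior register(Class<? extends Likelihood> type, Function<Likelihood, String> rule) {
            rules.put(type, rule);
            return this;
        }

        // Try the likelihood's runtime class, then each superclass in turn.
        // The loop always terminates: getSuperclass() returns null after Object.
        String posterior(Likelihood likelihood) {
            for (Class<?> c = likelihood.getClass(); c != null; c = c.getSuperclass()) {
                Function<Likelihood, String> rule = rules.get(c);
                if (rule != null) {
                    return rule.apply(likelihood);
                }
            }
            return "FAIL: no conjugate update for a " + name + " prior";
        }
    }

    public static void main(String[] args) {
        Prior normal = new Prior("Normal")
                .register(NormalUnknownMean.class, l -> "mu ~ Normal ~ Normal-Normal");

        System.out.println(normal.posterior(new NormalUnknownMean()));        // mu ~ Normal ~ Normal-Normal
        System.out.println(normal.posterior(new NormalUnknownMeanVariant())); // mu ~ Normal ~ Normal-Normal
        System.out.println(normal.posterior(new NormalUnknownVariance()));    // FAIL: no conjugate update for a Normal prior
    }
}
```

The trade-off is that a missing prior/likelihood pairing is only detected at run time, and recovering a typed Distribution<T> result would take extra work that this sketch leaves out.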
Anyway, thanks for the insight. I couldn't have gotten to this point
without everyone's help.
P.S. The semantics of the Distribution.posterior() method are that
"this" is the prior distribution, which we combine with a likelihood.
A nice side effect is that every posterior() method returns a
Distribution<T> object, rather than doing it the other way around:
treating "this" as the likelihood and having to take and return a
Distribution<?> object.
import java.util.*;

/**
 * Experiment with a conjugate class hierarchy
 */
public class Conjugacy
{
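    static abstract class Distribution<T> {
        // Entry point: "this" is the prior. The likelihood's dispatch()
        // calls back into posterior() with its own concrete static type.
        public <S> Distribution<T> posterior( List<S> data,
                                              Distribution<S> likelihood ) {
            return likelihood.dispatch( data, this );
        }
        // Fallbacks for priors that are not conjugate to the given
        // likelihood; conjugate priors override the matching overload.
        public Distribution<T> posterior( List<Double> data,
                                          NormalUnknownMean likelihood ) {
            System.out.println( "FAIL" );
            return null;
        }
        public Distribution<T> posterior( List<Double> data,
                                          NormalUnknownVariance likelihood ) {
            System.out.println( "FAIL" );
            return null;
        }
        public Distribution<T> posterior( List<Integer> data,
                                          PoissonUnknownRate likelihood ) {
            System.out.println( "FAIL" );
            return null;
        }
        // Overridden by every concrete class; likelihood classes call
        // prior.posterior( data, this ) so "this" carries their static type.
        public abstract <S> Distribution<S> dispatch( List<T> data,
                                                      Distribution<S> prior );
    }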
    /**
     * The normal, gamma and Poisson distributions
     */
    static class Normal extends Distribution<Double> {
        // Normal prior combined with a normal likelihood of unknown mean.
        public Distribution<Double> posterior( List<Double> data,
                                               NormalUnknownMean likelihood ) {
            System.out.println( "mu ~ Normal ~ Normal-Normal" );
            return new Normal();
        }
        // A plain Normal is not usable as a likelihood.
        public <S> Distribution<S> dispatch( List<Double> data,
                                             Distribution<S> prior ) {
            System.out.println( "FAIL" );
            return null;
        }
    }
    static class NormalUnknownMean extends Normal {
        // "this" is statically NormalUnknownMean here, so the
        // posterior( List<Double>, NormalUnknownMean ) overload is chosen.
        public <S> Distribution<S> dispatch( List<Double> data,
                                             Distribution<S> prior ) {
            return prior.posterior( data, this );
        }
    }
    static class NormalUnknownVariance extends Normal {
        public <S> Distribution<S> dispatch( List<Double> data,
                                             Distribution<S> prior ) {
            return prior.posterior( data, this );
        }
    }
    // Gamma distribution
    static class Gamma extends Distribution<Double> {
        public Distribution<Double> posterior( List<Double> data,
                                               NormalUnknownVariance likelihood ) {
            System.out.println( "sigma ~ Gamma ~ Normal-Gamma" );
            return new Gamma();
        }
        public Distribution<Double> posterior( List<Integer> data,
                                               PoissonUnknownRate likelihood ) {
            System.out.println( "lambda ~ Gamma ~ Poisson-Gamma" );
            return new Gamma();
        }
        public <S> Distribution<S> dispatch( List<Double> data,
                                             Distribution<S> prior ) {
            System.out.println( "FAIL" );
            return null;
        }
    }
    // Poisson distribution
    static class Poisson extends Distribution<Integer> {
        public <S> Distribution<S> dispatch( List<Integer> data,
                                             Distribution<S> prior ) {
            System.out.println( "FAIL" );
            return null;
        }
    }
    static class PoissonUnknownRate extends Poisson {
        public <S> Distribution<S> dispatch( List<Integer> data,
                                             Distribution<S> prior ) {
            return prior.posterior( data, this );
        }
    }
    public static void main( String[] args )
    {
        Distribution<Double> normal = new Normal();
        Distribution<Double> gamma = new Gamma();
        Distribution<Integer> poisson = new Poisson();
        Distribution<Double> normalMu = new NormalUnknownMean();
        Distribution<Double> normalSig = new NormalUnknownVariance();
        Distribution<Integer> poissonRate = new PoissonUnknownRate();
        List<Double> data1 = new ArrayList<Double>();
        List<Integer> data2 = new ArrayList<Integer>();

        System.out.println( "Using a Normal as a prior distribution" );
        normal.posterior( data1, normal );      // Fail
        normal.posterior( data1, gamma );       // Fail
        normal.posterior( data2, poisson );     // Fail
        normal.posterior( data1, normalMu );    // Pass
        normal.posterior( data1, normalSig );   // Fail
        normal.posterior( data2, poissonRate ); // Fail
        System.out.println();

        System.out.println( "Using a Gamma as a prior distribution" );
        gamma.posterior( data1, normal );       // Fail
        gamma.posterior( data1, gamma );        // Fail
        gamma.posterior( data2, poisson );      // Fail
        gamma.posterior( data1, normalMu );     // Fail
        gamma.posterior( data1, normalSig );    // Pass
        gamma.posterior( data2, poissonRate );  // Pass
    }
}