Machine learning for Java developers

Set up a machine learning algorithm and develop your first prediction function in Java


Recall that to train a target function, you first have to choose a machine learning algorithm. The code below creates an instance of Weka's LinearRegression classifier and trains it by calling buildClassifier(). The buildClassifier() method tunes the theta parameters based on the training data to find the best-fitting model. With Weka, you do not have to set a learning rate or an iteration count. Weka also handles feature scaling internally.

Classifier targetFunction = new LinearRegression();
// train the classifier on the training data set
targetFunction.buildClassifier(trainingDataset);
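To make concrete what kind of work a library call like buildClassifier() takes off your hands, here is a minimal plain-Java sketch of a hand-written training loop: feature scaling followed by gradient descent with an explicit learning rate and iteration count. It is not how Weka implements LinearRegression internally. The class name, method names, learning rate, iteration count, and the house data in main() are all hypothetical and chosen only for illustration.

```java
import java.util.Arrays;

final class GradientDescentSketch {

    // Scales every feature column to zero mean and unit standard deviation.
    // Column 0 is assumed to be the constant bias term and is copied unchanged.
    static double[][] scale(double[][] x) {
        int m = x.length, n = x[0].length;
        double[][] scaled = new double[m][n];
        for (int i = 0; i < m; i++) scaled[i][0] = x[i][0];
        for (int j = 1; j < n; j++) {
            double sum = 0;
            for (double[] row : x) sum += row[j];
            double mean = sum / m;
            double squares = 0;
            for (double[] row : x) squares += (row[j] - mean) * (row[j] - mean);
            double std = Math.sqrt(squares / m);
            for (int i = 0; i < m; i++) {
                scaled[i][j] = std == 0 ? 0 : (x[i][j] - mean) / std;
            }
        }
        return scaled;
    }

    // Linear hypothesis: the dot product of theta and the feature vector.
    static double predict(double[] theta, double[] features) {
        double sum = 0;
        for (int j = 0; j < theta.length; j++) sum += theta[j] * features[j];
        return sum;
    }

    // Batch gradient descent on the mean squared error cost.
    static double[] train(double[][] x, double[] y, double alpha, int iterations) {
        int m = x.length, n = x[0].length;
        double[] theta = new double[n];
        for (int it = 0; it < iterations; it++) {
            double[] gradient = new double[n];
            for (int i = 0; i < m; i++) {
                double error = predict(theta, x[i]) - y[i];
                for (int j = 0; j < n; j++) gradient[j] += error * x[i][j] / m;
            }
            for (int j = 0; j < n; j++) theta[j] -= alpha * gradient[j];
        }
        return theta;
    }

    public static void main(String[] args) {
        // Hypothetical example records: house sizes and prices.
        double[] sizes = {60, 80, 100, 120, 140};
        double[] prices = {150_000, 190_000, 240_000, 270_000, 330_000};
        double[][] x = new double[sizes.length][];
        for (int i = 0; i < sizes.length; i++) {
            // bias term, size, squared size
            x[i] = new double[] {1.0, sizes[i], sizes[i] * sizes[i]};
        }
        double[][] scaled = scale(x);
        double[] theta = train(scaled, prices, 0.1, 10_000);
        System.out.println("theta = " + Arrays.toString(theta));
        System.out.println("prediction for first house = " + predict(theta, scaled[0]));
    }
}
```

With a library such as Weka, the scaling step and the two tuning values passed to train() above are decisions you no longer have to make by hand.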

Once it has been trained, the target function can be used to predict the price of a house, as shown below:

// create a one-element data set that holds the instance to predict
Instances unlabeledInstances = new Instances("predictionset", attributes, 1);
unlabeledInstances.setClassIndex(trainingSet.numAttributes() - 1);
Instance unlabeled = new DenseInstance(3);
unlabeled.setValue(sizeAttribute, 1330.0);
unlabeled.setValue(squaredSizeAttribute, Math.pow(1330.0, 2));
// add the instance to the data set before classifying it
unlabeledInstances.add(unlabeled);

double prediction = targetFunction.classifyInstance(unlabeledInstances.get(0));

Weka provides an Evaluation class to validate the trained classifier or model. In the code below, a dedicated validation data set, separate from the training data, is used to avoid biased results. Measures such as the cost or error rate are printed to the console. Typically, evaluation results are used to compare models that have been trained using different machine learning algorithms, or different variants of the same algorithm:

Evaluation evaluation = new Evaluation(trainingDataset);
evaluation.evaluateModel(targetFunction, validationDataset);
System.out.println(evaluation.toSummaryString("Results", false));
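For a numeric prediction such as a house price, the error measures in an evaluation summary are typically values like the mean absolute error and the root mean squared error. The sketch below shows how these two measures can be computed by hand in plain Java. It is an illustration only, not Weka's code; the class name, method names, and the prices in main() are hypothetical.

```java
final class RegressionMetricsSketch {

    // Average of the absolute differences between actual and predicted values.
    static double meanAbsoluteError(double[] actual, double[] predicted) {
        double sum = 0;
        for (int i = 0; i < actual.length; i++) {
            sum += Math.abs(actual[i] - predicted[i]);
        }
        return sum / actual.length;
    }

    // Square root of the average squared difference; penalizes large errors more.
    static double rootMeanSquaredError(double[] actual, double[] predicted) {
        double sum = 0;
        for (int i = 0; i < actual.length; i++) {
            double diff = actual[i] - predicted[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum / actual.length);
    }

    public static void main(String[] args) {
        // Hypothetical validation prices and model predictions.
        double[] actual = {100_000, 200_000, 300_000};
        double[] predicted = {110_000, 190_000, 300_000};
        System.out.println("MAE  = " + meanAbsoluteError(actual, predicted));
        System.out.println("RMSE = " + rootMeanSquaredError(actual, predicted));
    }
}
```

When comparing two models on the same validation data, the one with the lower error values fits that data better.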

The examples above use linear regression, which predicts a continuous, numeric output such as a house price based on input values. To predict binary yes/no values or other class labels, you could use a machine learning algorithm such as a decision tree, a neural network, or logistic regression:

// using logistic regression
Classifier targetFunction = new Logistic();
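As a rough illustration of how logistic regression turns a weighted sum of features into a yes/no answer, the sketch below applies the sigmoid function to the linear combination and compares the resulting probability with a threshold of 0.5. This is the textbook binary form, not Weka's Logistic implementation; the class name, method names, theta values, and the 0.5 threshold are assumptions made for the example.

```java
final class LogisticSketch {

    // Maps any real number to a value between 0 and 1.
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Estimated probability that the features belong to the positive class.
    static double probability(double[] theta, double[] features) {
        double z = 0;
        for (int j = 0; j < theta.length; j++) {
            z += theta[j] * features[j];
        }
        return sigmoid(z);
    }

    // Predicts "yes" when the estimated probability is at least 0.5.
    static boolean classify(double[] theta, double[] features) {
        return probability(theta, features) >= 0.5;
    }

    public static void main(String[] args) {
        // Hypothetical parameters: bias term and one feature weight.
        double[] theta = {-1.0, 2.0};
        double[] features = {1.0, 0.8}; // the first entry is the constant bias input
        System.out.println("probability = " + probability(theta, features));
        System.out.println("positive class? " + classify(theta, features));
    }
}
```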

You might use one of these learning algorithms to predict whether an email was spam or ham, or to predict whether a house for sale could be a top-seller or not. If you wanted to train your algorithm to predict whether a house is likely to sell quickly, you would need to label your example records with a new classifying label such as topseller:

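// using the topseller label attribute instead of the price label attribute
ArrayList<String> classVal = new ArrayList<>();
// the two possible class values (example labels)
classVal.add("true");
classVal.add("false");

Attribute topsellerAttribute = new Attribute("topsellerLabel", classVal);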

A training set labeled this way could be used to train a new classifier that predicts the topseller label. Once the classifier is trained, the prediction call returns the index of the predicted class label as a double, which can be used to look up the predicted value:

int idx = (int) targetFunction.classifyInstance(unlabeledInstances.get(0));
String prediction = classVal.get(idx);
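The index-to-label lookup can be pictured as follows. Many classifiers estimate a probability for each class value (Weka exposes such estimates through distributionForInstance()), and the predicted index is usually the position of the most probable class. The plain-Java sketch below shows that idea; the class name, method names, and the probabilities in main() are hypothetical.

```java
import java.util.Arrays;
import java.util.List;

final class ClassLabelSketch {

    // Returns the position of the largest value in the distribution.
    static int indexOfMax(double[] distribution) {
        int best = 0;
        for (int i = 1; i < distribution.length; i++) {
            if (distribution[i] > distribution[best]) {
                best = i;
            }
        }
        return best;
    }

    // Looks up the class label that belongs to the most probable class.
    static String labelFor(double[] distribution, List<String> classValues) {
        return classValues.get(indexOfMax(distribution));
    }

    public static void main(String[] args) {
        List<String> classVal = Arrays.asList("true", "false");
        // Hypothetical per-class probabilities for one unlabeled house.
        double[] distribution = {0.3, 0.7};
        System.out.println("predicted label = " + labelFor(distribution, classVal));
    }
}
```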


Although machine learning is closely related to statistics and uses many mathematical concepts, machine learning tools make it possible to start integrating machine learning into your programs without knowing a great deal about mathematics. That said, the better you understand the inner workings of machine learning algorithms such as linear regression, which we explored in this article, the better you will be able to choose the right algorithm and configure it for optimal performance.

Related links

A good way to get deeper into machine learning is to take an online course such as Andrew Ng's Machine Learning course or Udacity's Intro to Machine Learning.
