
Evaluating the model

Once training has been completed, the next task is to evaluate the model's performance on the test set. For this, we will use the Evaluation class, creating an evaluation object for two possible classes (survived or not survived). More technically, the Evaluation class computes evaluation metrics such as precision, recall, F1, accuracy, and the Matthews correlation coefficient (MCC); the last one is particularly suited to evaluating a binary classifier. Let's take a brief look at these metrics, where TP, FP, TN, and FN denote the numbers of true positives, false positives, true negatives, and false negatives:

Accuracy is the ratio of correctly predicted samples to the total number of samples:

$$\text{Accuracy} = \frac{TP + TN}{TP + FP + TN + FN}$$

Precision is the ratio of correctly predicted positive samples to the total number of predicted positive samples:

$$\text{Precision} = \frac{TP}{TP + FP}$$

Recall is the ratio of correctly predicted positive samples to all samples that actually belong to the positive class (yes):

$$\text{Recall} = \frac{TP}{TP + FN}$$

The F1 score is the harmonic mean of precision and recall:

$$F_1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$

The Matthews correlation coefficient (MCC) is a measure of the quality of binary (two-class) classifications. It can be calculated directly from the confusion matrix as follows (given that TP, FP, TN, and FN are already available):

$$\text{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$
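As a cross-check of these definitions, here is a minimal plain-Java sketch that computes all five metrics from raw confusion-matrix counts, treating class 1 as the positive class. The class BinaryMetrics, its method names, and the counts in main are hypothetical illustrations, not part of DL4J; returning 0 when a denominator is zero is an assumed convention.

```java
/**
 * Hypothetical helper (not part of DL4J): computes binary classification
 * metrics from raw confusion-matrix counts, with class 1 as "positive".
 */
public class BinaryMetrics {

    // (TP + TN) / (TP + FP + TN + FN)
    public static double accuracy(long tp, long fp, long tn, long fn) {
        return (double) (tp + tn) / (tp + fp + tn + fn);
    }

    // TP / (TP + FP); assumed to be 0 when nothing was predicted positive
    public static double precision(long tp, long fp) {
        return (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
    }

    // TP / (TP + FN); assumed to be 0 when there are no actual positives
    public static double recall(long tp, long fn) {
        return (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
    }

    // Harmonic mean of precision and recall
    public static double f1(long tp, long fp, long fn) {
        double p = precision(tp, fp);
        double r = recall(tp, fn);
        return (p + r) == 0 ? 0.0 : 2 * p * r / (p + r);
    }

    // (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)); 0 if the denominator is 0
    public static double mcc(long tp, long fp, long tn, long fn) {
        double denom = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return denom == 0 ? 0.0 : ((double) tp * tn - (double) fp * fn) / denom;
    }

    public static void main(String[] args) {
        // Made-up counts for illustration only; these are not the Titanic results
        long tp = 40, fp = 10, tn = 35, fn = 15;
        System.out.printf("Accuracy:  %.4f%n", accuracy(tp, fp, tn, fn));
        System.out.printf("Precision: %.4f%n", precision(tp, fp));
        System.out.printf("Recall:    %.4f%n", recall(tp, fn));
        System.out.printf("F1 score:  %.4f%n", f1(tp, fp, fn));
        System.out.printf("MCC:       %.4f%n", mcc(tp, fp, tn, fn));
        // Treat class 0 as positive instead: TP<->TN and FP<->FN swap roles
        System.out.printf("MCC (class 0 as positive): %.4f%n", mcc(tn, fn, tp, fp));
    }
}
```

The last line swaps which class is treated as positive. Precision and recall would change under that swap, but MCC does not, which is one reason it is a convenient single number for binary problems.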

Unlike with the Apache Spark-based classification evaluator, special care should be taken with binary classification metrics such as F1, precision, and recall when solving a binary classification problem with the DL4J-based evaluator, because DL4J reports them for the positive class (class 1) only.

We will see these later on. First, let's iterate over the test set and get the network's predictions from the trained model. Finally, the eval() method checks the predictions against the true classes:

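// Package used by older DL4J releases; later releases move Evaluation to org.nd4j.evaluation.classification
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

// model (the trained network), testDataIt (a DataSetIterator over the test set)
// and log (a logger) come from the earlier training code
log.info("Evaluate model....");
Evaluation eval = new Evaluation(2); // two classes: survived / not survived

while (testDataIt.hasNext()) {
    DataSet next = testDataIt.next();
    // Class probabilities predicted by the network for this batch
    INDArray output = model.output(next.getFeatureMatrix());
    // Compare predictions with the one-hot ground-truth labels and update the counts
    eval.eval(next.getLabels(), output);
}
log.info(eval.stats());
log.info("****************Example finished********************");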
>>>
==========================Scores========================================
# of classes: 2
Accuracy: 0.6496
Precision: 0.6155
Recall: 0.5803
F1 Score: 0.3946
Precision, recall & F1: reported for positive class (class 1 - "1") only
=======================================================================
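To make the role of the positive class concrete, the following plain-Java sketch mimics, under our own assumptions, what an evaluator does with each batch: it takes the most probable class in every output row, compares it with the index of the 1 in the one-hot label row, and counts the result in a confusion matrix. ConfusionSketch and the toy numbers are hypothetical; this is not DL4J's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DL4J source) of per-batch evaluation:
 * argmax of the output vs. argmax of the one-hot label, counted in
 * confusion[actual][predicted].
 */
public class ConfusionSketch {

    // Index of the largest value in a row (ties go to the lowest index)
    public static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) best = i;
        }
        return best;
    }

    // Increments confusion[actual][predicted] once per sample
    public static void accumulate(long[][] confusion, double[][] labels, double[][] outputs) {
        for (int i = 0; i < labels.length; i++) {
            confusion[argMax(labels[i])][argMax(outputs[i])]++;
        }
    }

    // Precision for one class, treating that class as "positive"
    public static double precision(long[][] confusion, int cls) {
        long predicted = 0;
        for (long[] row : confusion) predicted += row[cls];
        return predicted == 0 ? 0.0 : (double) confusion[cls][cls] / predicted;
    }

    public static void main(String[] args) {
        long[][] confusion = new long[2][2];
        // Toy batch: one-hot labels and softmax-style outputs (made-up numbers)
        double[][] labels  = {{0, 1}, {1, 0}, {0, 1}, {1, 0}};
        double[][] outputs = {{0.2, 0.8}, {0.6, 0.4}, {0.7, 0.3}, {0.9, 0.1}};
        accumulate(confusion, labels, outputs);
        System.out.println(Arrays.deepToString(confusion));
        double p1 = precision(confusion, 1);
        double p0 = precision(confusion, 0);
        System.out.printf("Precision class 1: %.2f, class 0: %.2f, macro: %.2f%n",
                p1, p0, (p0 + p1) / 2);
    }
}
```

With these toy numbers, class 1 gets a precision of 1.0 while class 0 gets about 0.67, so a single "precision" value depends on which class is treated as positive. This is why the stats output above states that precision, recall, and F1 are reported for class 1 only.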

Oops! Unfortunately, the model has not achieved a very high classification accuracy (about 65%), and the F1 score reported for class 1 is low as well. Now, let's compute another metric, MCC, for this binary classification problem.

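import org.deeplearning4j.eval.EvaluationAveraging;

// Compute the Matthews correlation coefficient.
// Macro averaging takes the mean of the per-class (one-vs-rest) values; with two
// classes both values are identical, so the result equals the binary MCC
EvaluationAveraging averaging = EvaluationAveraging.Macro;
double mcc = eval.matthewsCorrelation(averaging);
System.out.println("Matthews correlation coefficient: " + mcc);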
>>>
Matthews correlation coefficient: 0.22308172619187497

Now let's interpret this result based on Matthews' paper (see www.sciencedirect.com/science/article/pii/0005279575901099), which describes the following properties: a correlation of C = 1 indicates perfect agreement, C = 0 is expected for a prediction no better than random, and C = -1 indicates total disagreement between prediction and observation.
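Plugging boundary cases into the MCC formula shows where these three reference values come from (a worked check of our own, not taken from the paper); the last line uses a balanced confusion matrix with all four counts equal to some k > 0, one example of a prediction no better than random:

$$
\begin{aligned}
\text{MCC} &= \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}\\
FP = FN = 0:&\quad \text{MCC} = \frac{TP \cdot TN}{\sqrt{TP \cdot TP \cdot TN \cdot TN}} = 1\\
TP = TN = 0:&\quad \text{MCC} = \frac{-FP \cdot FN}{\sqrt{FP \cdot FN \cdot FP \cdot FN}} = -1\\
TP = FP = TN = FN = k:&\quad \text{MCC} = \frac{k^2 - k^2}{\sqrt{(2k)^4}} = 0
\end{aligned}
$$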

Following this, our result (about 0.22) shows a weak positive relationship. Alright! Although we have not achieved good accuracy, you can still try tuning the hyperparameters or switching to other networks, such as an LSTM, which we discuss in the next section. There, however, we'll apply it to our cancer prediction problem, which is the main goal of this chapter. So stay with me!
