Wine Quality Analysis - Homework Assignment 🍷📊

In this homework, you’ll analyze wine quality data and build machine learning models. You’ll use Tablesaw for data analysis and both SMILE and Weka for machine learning.

Learning Objectives

  • Load and explore wine quality datasets
  • Create visualizations to understand data patterns
  • Build and compare machine learning models
  • Make predictions on wine quality

Setup

Before running the cells below, make sure the wine dataset is available at ~/wine-dataset/WineQT.csv (that is, in the wine-dataset folder of your home directory).

// Dependencies
%maven tech.tablesaw:tablesaw-core:0.43.1
%maven tech.tablesaw:tablesaw-jsplot:0.43.1
%maven com.github.haifengl:smile-core:3.0.1
%maven com.github.haifengl:smile-data:2.6.0
%maven nz.ac.waikato.cms.weka:weka-stable:3.8.6

// Imports
import tech.tablesaw.api.*;
import tech.tablesaw.io.csv.CsvReadOptions;
import tech.tablesaw.plotly.api.*;
import tech.tablesaw.plotly.components.Figure;
import tech.tablesaw.aggregate.AggregateFunctions;
import smile.classification.*;
import smile.data.*;
import smile.data.formula.Formula;
import smile.data.vector.IntVector;
import smile.math.MathEx;
import smile.validation.metric.Accuracy;
import weka.core.*;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.Evaluation;
import java.util.ArrayList;
import java.util.stream.IntStream;
import java.util.Arrays;

// Load the wine dataset
String winePath = System.getProperty("user.home") + "/wine-dataset/WineQT.csv";

// Choose the delimiter from the header row instead of hard-coding ';'.
// Wine-quality CSVs circulate in both comma- and semicolon-delimited forms; with the
// wrong separator the whole header is read as a single column and every later lookup
// by column name ("quality", "alcohol", ...) fails.
char wineSeparator = ',';
try (java.io.BufferedReader headerReader =
         java.nio.file.Files.newBufferedReader(java.nio.file.Paths.get(winePath))) {
    String headerLine = headerReader.readLine();
    if (headerLine != null) {
        long semicolons = headerLine.chars().filter(ch -> ch == ';').count();
        long commas = headerLine.chars().filter(ch -> ch == ',').count();
        // Column names contain spaces but no delimiters, so the more frequent
        // character in the header is taken as the separator.
        if (semicolons > commas) {
            wineSeparator = ';';
        }
    }
}

CsvReadOptions options =
    CsvReadOptions.builder(winePath)
        .separator(wineSeparator)
        .build();

Table wine = Table.read().usingOptions(options);

System.out.println("Dataset loaded: " + wine.rowCount() + " rows, " + wine.columnCount() + " columns");
System.out.println("First 5 rows:");
System.out.println(wine.first(5));

Question 1: Data Exploration (2 parts)

Complete the code below to explore the wine dataset.

// Part A: TODO - Display summary statistics for the wine dataset
// The label below is printed first; add the statement(s) that print the statistics
// for the `wine` table right after it.
System.out.println("Summary Statistics:");
// YOUR CODE HERE

// Part B: TODO - Create a histogram of wine quality distribution
// The assignment below is intentionally incomplete: replace the placeholder with an
// expression that produces a Figure (terminated by ';'). Plot.show then displays it.
Figure qualityHist = // YOUR CODE HERE
Plot.show(qualityHist);

// Part B continued: TODO - Create a scatter plot of alcohol vs quality
// Same pattern as above: supply a Figure-valued expression ending in ';'.
Figure alcoholScatter = // YOUR CODE HERE  
Plot.show(alcoholScatter);

// Provided: Group wines by quality level
// Computes the mean of alcohol, pH and volatile acidity for each distinct quality score.
// Tablesaw's summarize(...) expects all column names first and the aggregate
// function(s) last; it has no overload that alternates name/function pairs.
Table qualityGroups = wine
    .summarize("alcohol", "pH", "volatile acidity", AggregateFunctions.mean)
    .by("quality");
System.out.println("\nCharacteristics by quality level:");
System.out.println(qualityGroups);

Question 2: Machine Learning with SMILE (2 parts)

Build a Random Forest model using the SMILE library to predict wine quality.

// Convert Tablesaw table to SMILE DataFrame
// The table is exported as a plain double matrix and re-labelled with the original
// column names (assumes every column of `wine` is numeric, so names and matrix
// columns line up - TODO confirm for other datasets).
String[] colNames = wine.columnNames().toArray(new String[0]);
double[][] data = wine.as().doubleMatrix();
DataFrame df = DataFrame.of(data, colNames);

// Replace the double-valued "quality" column with an int-valued one of the same name;
// the truncating cast is exact as long as the scores are whole numbers.
int[] qualityLabels = df.doubleVector("quality")
    .stream()
    .mapToInt(score -> (int) score)
    .toArray();
IntVector quality = IntVector.of("quality", qualityLabels);
df = df.drop("quality").merge(quality);

// Split data into training and test sets (80/20 split)
int n = df.nrows();
int[] indices = IntStream.range(0, n).toArray();
MathEx.permutate(indices);
int splitIndex = (int)(n * 0.8);

// Select rows through the shuffled index array. Slicing df directly by position
// would ignore the permutation above and split the rows in file order.
DataFrame trainDf = df.of(Arrays.copyOfRange(indices, 0, splitIndex));
DataFrame testDf = df.of(Arrays.copyOfRange(indices, splitIndex, n));

// Part A: TODO - Train a Random Forest model using SMILE
// The type is written fully qualified because the simple name RandomForest is bound
// to weka.classifiers.trees.RandomForest by the single-type import at the top.
// Replace the placeholder with an expression that yields the trained model (end with ';').
smile.classification.RandomForest rf = // YOUR CODE HERE

// Part B: TODO - Calculate and display model accuracy
// yTrue holds the actual "quality" label of every row in testDf, in row order.
int[] yTrue = testDf.stream().mapToInt(r -> r.getInt("quality")).toArray();
// yPred should hold the model's prediction for each test row, in the same order as yTrue.
int[] yPred = // YOUR CODE HERE

// accuracy is multiplied by 100 when printed below, so it is expected to be a
// fraction (e.g. 0.65), not a percentage.
double accuracy = // YOUR CODE HERE
System.out.printf("SMILE Random Forest Accuracy: %.2f%%\n", accuracy * 100);
// Convert to Weka format
// Builds a Weka Instances set with one numeric attribute per non-"quality" column and
// a final nominal "quality" attribute used as the class.

// Feature columns are selected by name, and the same list drives both the attribute
// definitions and the value copy below. "quality" is not guaranteed to be the last
// column (e.g. when the CSV carries a trailing Id column); copying values by position
// would then write the label into a feature attribute.
ArrayList<String> featureNames = new ArrayList<>(wine.columnNames());
featureNames.remove("quality");

ArrayList<Attribute> attributes = new ArrayList<>();
ArrayList<NumberColumn<?,?>> featureCols = new ArrayList<>();
for (String col : featureNames) {
    attributes.add(new Attribute(col));
    featureCols.add((NumberColumn<?,?>) wine.column(col));
}

// Nominal class attribute: one label per integer between the observed min and max.
IntColumn qualityCol = wine.intColumn("quality");
int minQuality = (int) qualityCol.min();
int maxQuality = (int) qualityCol.max();
ArrayList<String> qualityVals = new ArrayList<>();
for (int i = minQuality; i <= maxQuality; i++) {
    qualityVals.add(String.valueOf(i));
}
attributes.add(new Attribute("quality", qualityVals));

Instances wData = new Instances("Wine", attributes, wine.rowCount());
wData.setClassIndex(wData.numAttributes() - 1);

for (int i = 0; i < wine.rowCount(); i++) {
    double[] vals = new double[wData.numAttributes()];
    for (int j = 0; j < featureCols.size(); j++) {
        vals[j] = featureCols.get(j).getDouble(i);
    }
    // A nominal value is stored as the index of its label in qualityVals.
    vals[wData.numAttributes() - 1] = qualityVals.indexOf(String.valueOf(qualityCol.get(i)));
    wData.add(new DenseInstance(1.0, vals));
}

// Split data
int trainSize = (int) Math.round(wData.numInstances() * 0.8);
Instances train = new Instances(wData, 0, trainSize);
Instances test = new Instances(wData, trainSize, wData.numInstances() - trainSize);

// TODO - Train Weka Random Forest and calculate accuracy
// RandomForest here is weka.classifiers.trees.RandomForest (single-type import).
// Replace the placeholder with an expression that creates the classifier (end with ';').
RandomForest wekaRf = // YOUR CODE HERE
try {
    // YOUR CODE HERE - build the classifier
    
    // The evaluation is set up from the training set and then scores wekaRf on the
    // held-out test set.
    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(wekaRf, test);
    // pctCorrect() is printed directly with a percent sign (no extra scaling applied).
    System.out.printf("Weka Random Forest Accuracy: %.2f%%\n", eval.pctCorrect());
    
    // Compare models
    System.out.println("\nModel Comparison Complete!");
    System.out.println("Which model performed better? Analyze the results above.");
    
} catch (Exception e) {
    // Any failure while building or evaluating the model is reported with its stack trace.
    e.printStackTrace();
}