
Metrics wishlist #6

Open
4 of 7 tasks
dfalbel opened this issue May 11, 2021 · 5 comments

Comments

@dfalbel
Member

dfalbel commented May 11, 2021

  • AUC
  • Precision
  • Recall
  • MSE
  • RMSE
  • MAE
  • Cohen's Kappa
@mattwarkentin
Contributor

mattwarkentin commented May 13, 2021

Are you looking to create native torch implementations of any metric added to luz? I wonder if the yardstick package could possibly be used for adding immediate support for a huge number of metrics.

You would basically just need a thin wrapper around the yardstick functions to convert back and forth between torch tensors and R vectors. It would mean absorbing the yardstick dependencies, but perhaps dependencies are less of a concern for a high-level package like luz.

Thoughts?

@dfalbel
Member Author

dfalbel commented May 13, 2021

I think the main issue is that yardstick doesn't support streaming metrics, so it must see the predictions and targets for the full dataset in order to compute the value. We could store the predictions/targets, but we can rapidly run out of memory when doing something like U-Net, or even a classification problem with thousands of classes.
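To make the distinction concrete, here is a minimal base R sketch of a streaming metric. The helper name `make_streaming_mae` is hypothetical and not part of luz; it only mirrors the `update()`/`compute()` split by keeping two running numbers instead of every prediction and target.

```r
# Hypothetical streaming MAE accumulator (not a luz API). Only running sums are
# kept between calls, so memory use does not grow with the dataset size.
make_streaming_mae <- function() {
  abs_error_sum <- 0
  n <- 0
  list(
    update = function(preds, target) {
      abs_error_sum <<- abs_error_sum + sum(abs(preds - target))
      n <<- n + length(target)
      invisible(NULL)
    },
    compute = function() abs_error_sum / n
  )
}

set.seed(1)
preds <- runif(10)
target <- runif(10)

metric <- make_streaming_mae()
metric$update(preds[1:7], target[1:7])    # first batch
metric$update(preds[8:10], target[8:10])  # smaller last batch
streaming_value <- metric$compute()
full_value <- mean(abs(preds - target))   # same value, computed on all data at once
```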

We could also compute the metric per batch and then average over all batches, but this makes it very hard to compare metrics between runs, and it was a significant source of confusion in Keras, where the reported metric value was not identical to computing it yourself.
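As an illustration of why the per-batch average can mislead, the sketch below computes a rank-based (Mann-Whitney) AUC in base R on a small made-up dataset split into two batches. The `auc` helper and the data are illustrative assumptions, not yardstick or luz code.

```r
# Rank-based (Mann-Whitney) AUC; rank() assigns average ranks to ties.
auc <- function(scores, labels) {
  r <- rank(scores)
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Made-up scores and labels, split into two batches of four.
scores <- c(0.9, 0.8, 0.3, 0.2, 0.7, 0.6, 0.5, 0.1)
labels <- c(1,   0,   1,   0,   1,   0,   0,   1)
batch  <- c(1,   1,   1,   1,   2,   2,   2,   2)

per_batch <- sapply(split(seq_along(scores), batch),
                    function(i) auc(scores[i], labels[i]))
averaged <- mean(per_batch)       # 0.625 (batch AUCs are 0.75 and 0.5)
full     <- auc(scores, labels)   # 0.5
```

The per-batch average ignores every positive/negative pair that falls in different batches, so it can differ substantially from the full-dataset value.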

We could definitely have a luz_metric_yardstick wrapper that takes any yardstick metric and computes the results, but it would average the per-batch values instead of computing the metric on the full dataset. This would allow us to use metrics that can't be implemented in streaming mode (or are too hard to implement that way) and metrics that we haven't implemented yet.

What do you think?

@mattwarkentin
Contributor

mattwarkentin commented May 13, 2021

Sorry, just so I understand: does streaming mode mean that batch-wise calculations are accumulated with update() along the way, but the final metric calculation is done at the end of the epoch? So each update() never needs to see the full data set, only batches.

Whereas for yardstick, you have to make the "final" computation for each and every batch, and then average batch-wise metrics at the end of the epoch to get the epoch metric. Am I understanding correctly?

@dfalbel
Member Author

dfalbel commented May 13, 2021

Yes, exactly! The problem is that mean(mae(x, y)) can be slightly different from mae(x, y) and also depends on the dataset order.

Edit: sorry, mae is not a good example, but auc or other metrics that aren't a linear transformation of the per-sample values are.
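A small base R check shows both sides of this: the mean of per-batch MAEs only differs from the full MAE when batches have unequal sizes (for example, a short last batch), presumably why MAE is a weak example. The error values below are made up for illustration.

```r
# Absolute errors for a dataset split into a batch of 3 and a final batch of 1.
errors <- list(c(1, 1, 1), 5)

mean_of_batch_maes <- mean(sapply(errors, mean))  # (1 + 5) / 2 = 3
full_mae <- mean(unlist(errors))                  # 8 / 4 = 2

# Weighting each batch MAE by its size recovers the full-dataset value.
# AUC has no such weighting, because it depends on pairs of samples that
# fall in different batches.
sizes <- lengths(errors)
weighted <- sum(sapply(errors, mean) * sizes) / sum(sizes)  # 2
```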

@maxwell-geospatial

I have been working on some metrics for semantic segmentation tasks using luz_metric(): F1-score, precision, and recall. The goal is for each metric to support both multiclass and binary classification problems. For multiclass problems, both micro and macro averaging are implemented. However, note that for single-label multiclass problems, micro-averaged recall, precision, and F1-score are all equivalent and equal to overall accuracy, so it may only make sense to implement macro averaging. It would also be good to generalize these beyond 2D semantic segmentation (e.g., to scene labeling and 3D semantic segmentation tasks).

I have posted the code for each metric below. I am not sure yet whether they are working correctly and would appreciate any feedback. If they turn out to be working properly, or can be corrected/improved with input from others, it would be great to add them to luz if there is interest. I have commented the precision metric throughout; in the other two metrics I have only added comments where they differ.
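The micro-averaging observation can be checked with a confusion matrix in base R: in a single-label problem, every misclassified sample is exactly one false positive (for the predicted class) and one false negative (for the true class), so the micro totals of FP and FN both equal the number of errors. The labels below are made up for illustration.

```r
# Made-up labels for a 3-class, single-label problem.
truth <- factor(c(1, 1, 2, 2, 2, 3, 3, 1, 3, 2), levels = 1:3)
pred  <- factor(c(1, 2, 2, 2, 3, 3, 1, 1, 3, 2), levels = 1:3)

cm <- table(pred, truth)   # rows = predicted class, columns = true class
tp <- diag(cm)
fp <- rowSums(cm) - tp     # predicted as class k, truly another class
fn <- colSums(cm) - tp     # truly class k, predicted as another class

micro_precision <- sum(tp) / (sum(tp) + sum(fp))
micro_recall    <- sum(tp) / (sum(tp) + sum(fn))
micro_f1        <- 2 * sum(tp) / (2 * sum(tp) + sum(fp) + sum(fn))
accuracy        <- mean(pred == truth)   # 7 of 10 correct
```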

I am working on a package for geospatial deep learning where I have integrated these metrics: https://github.com/maxwell-geospatial/geodl. This is a work in progress.

Precision draft:

#' luz_metric_precision
#'
#' luz_metric function to calculate precision
#'
#' Calculates precision based on luz_metric() for use within training and validation
#' loops.
#'
#' @param preds Tensor of predicted class logits with shape
#' [Batch, Class Logits, Height, Width] for a multiclass classification. For a
#' binary classification, you can provide logits for the positive class as
#' [Batch, Positive Class Logit, Height, Width] or [Batch, Height, Width].
#' @param target Tensor of target class labels with shape
#' [Batch, Class Indices, Height, Width] for a multiclass classification. For a
#' binary classification, you can provide targets as
#' [Batch, Positive Class Index, Height, Width] or [Batch, Height, Width]. For
#' binary classification, the class index must be 1 for the positive class and
#' 0 for the background case.
#' @param nCls number of classes being differentiated. Should be 1 for a binary classification
#' where only the positive case logit is returned. Default is 1.
#' @param smooth a smoothing factor to avoid divide by zero errors. Default is 1.
#' @param mode Either "binary" or "multiclass". If "binary", only the logit for
#' positive class prediction should be provided. If both the positive and negative
#' or background class probability is provided for a binary classification, use
#' the "multiclass" mode.
#' @param average Either "micro" or "macro". Whether to use micro- or macro-averaging
#' for multiclass metric calculation. Ignored when mode is "binary". Default is
#' "macro". Macro averaging consists of calculating the metric separately for each class
#' and averaging the results such that all classes are equally weighted. Micro-averaging calculates the 
#' metric for all classes collectively, and classes with a larger number of samples will have a larger
#' weight in the final metric. 
#' @param zeroStart TRUE or FALSE. If class indices start at 0 as opposed to 1, this should be set to
#' TRUE. This is required to implement one-hot encoding since R starts indexing at 1. Default is TRUE.
#' @param chnDim TRUE or FALSE. Whether the channel dimension is included in the target tensor:
#' [Batch, Channel, Height, Width] as opposed to [Batch, Height, Width]. If the channel dimension
#' is included, this should be set to TRUE. If it is not, this should be set to FALSE. Default is TRUE.
#' @param usedDS TRUE or FALSE. If deep supervision was implemented and masks are produced at varying scales using
#' the defineSegDataSetDS() function, this should be set to TRUE. Only the original resolution is used
#' to calculate assessment metrics. Default is FALSE.
#' @return The calculated metric, returned as a base R numeric value rather than a tensor.
#' @export
luz_metric_precision <- luz::luz_metric(

  abbrev = "precision",

  initialize = function(nCls=1,
                        smooth=1,
                        mode = "multiclass",
                        average="macro",
                        zeroStart=TRUE,
                        chnDim=TRUE,
                        usedDS=FALSE){

    self$nCls <- nCls
    self$smooth <- smooth
    self$mode <- mode
    self$average <- average
    self$zeroStart <- zeroStart
    self$chnDim <- chnDim
    self$usedDS <- usedDS

    #Initialize R vectors to store running true positive, false negative, and false positive counts.
    #Each vector has length nCls. For macro averaging, each element holds the count for one class.
    #For binary (nCls = 1) or micro-averaged metrics, the summed scalar is recycled across the vector,
    #so every element holds the same total and the mean in compute() is unaffected.
    self$tps <- rep(0.0, nCls)
    self$fns <- rep(0.0, nCls)
    self$fps <- rep(0.0, nCls)
  },

  update = function(preds, target){
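    #If deep supervision is used, extract only the prediction and target at the original resolution.
    #This is done first because, with deep supervision, the outputs are assumed to be lists of tensors
    #(one per scale), so [[1]] is needed rather than [1].
    if(self$usedDS == TRUE){
      preds <- preds[[1]]
      target <- target[[1]]
    }

    #If the channel dimension is not included in the target, add it.
    if(self$chnDim == FALSE){
      target <- target$unsqueeze(dim=2)
    }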
    
    #For multiclass problems
    if(self$mode == "multiclass"){
      #Get index of class with largest logit. 
      predsMax <- torch::torch_argmax(preds, dim = 2)
      #Make sure target is of type torch_long(). $to() keeps the tensor on its current device.
      target1 <- target$to(dtype=torch::torch_long())

      if(self$zeroStart == TRUE){
        #If class indices start at zero, add 1.
        target1 <- target1 + 1L
      }

      #One-hot encode the prediction results that have been passed through argmax().
      preds_one_hot <- torch::nnf_one_hot(predsMax, num_classes = self$nCls)
      #Permute the results so that the dimension order is [batch, encodings, height, width].
      preds_one_hot <- preds_one_hot$permute(c(1,4,2,3))

      #one-hot encode the targets. 
      target_one_hot <- torch::nnf_one_hot(target1, num_classes = self$nCls)
      #Remove only the channel dimension (dim 2), so a batch of size 1 is not squeezed away.
      target_one_hot <- target_one_hot$squeeze(dim=2)
      #permute the results so that the order is [batch, encoding, height, width]
      target_one_hot <- target_one_hot$permute(c(1,4,2,3))

      
      #If using macro averaging, calculation of tps, fps, and fns should be performed separately for each class. 
      dims <- c(1, 3, 4)
      #If using micro averaging, calculation of tps, fps, and fns should happen collectively. 
      if(self$average == "micro"){
        dims <- c(1,2,3,4)
      }

      #Calculate true positives and add to running true positives count. 
      self$tps <- self$tps + (torch::torch_sum(preds_one_hot * target_one_hot, dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())
      #Calculate false positives and add to running false positives count.
      self$fps <- self$fps + (torch::torch_sum(preds_one_hot * (-target_one_hot + 1.0), dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())
      #Calculate false negatives and add to running false negative count.
      self$fns <- self$fns + (torch::torch_sum((-preds_one_hot + 1.0) * target_one_hot, dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())

    #Processing for binary classification
    }else{
      #Convert logits to probs using sigmoid function 
      preds <- torch::nnf_sigmoid(preds)
      #Round to convert probs to 0 and 1. 
      preds <- torch::torch_round(preds)
      #Flatten arrays (this generalizes the problem so that the shape of the input tensors does not matter).
      preds <- preds$flatten()
      target <- target$flatten()

      #Calculate true positives and add to running true positives count.
      #$item() copies the scalar into R, which also works for tensors on the GPU.
      self$tps <- self$tps + (preds * target)$sum()$item()
      #Calculate false positives and add to running false positives count.
      self$fps <- self$fps + ((1.0 - target) * preds)$sum()$item()
      #Calculate false negatives and add to running false negative count.
      self$fns <- self$fns + (target * (1.0 - preds))$sum()$item()
    }
  },

  compute = function(){
    #Calculate precision: (tps + smooth)/(tps + fps + smooth)
    mean((self$tps + self$smooth)/(self$tps + self$fps + self$smooth))
  }
)
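One way to check whether the draft is working correctly is to compare it against a plain R computation on a small batch. The function below, `reference_precision`, is a hypothetical helper (not part of luz or geodl) that applies the same smoothed formula as `compute()` to integer class labels.

```r
# Hypothetical reference implementation: same smoothed precision as the draft's
# compute(), but from plain integer labels (1..n_classes).
reference_precision <- function(pred, truth, n_classes, smooth = 1,
                                average = c("macro", "micro")) {
  average <- match.arg(average)
  pred  <- factor(pred,  levels = seq_len(n_classes))
  truth <- factor(truth, levels = seq_len(n_classes))
  cm <- table(pred, truth)   # rows = predicted, columns = true
  tp <- diag(cm)
  fp <- rowSums(cm) - tp
  if (average == "micro") {
    tp <- sum(tp)
    fp <- sum(fp)
  }
  mean((tp + smooth) / (tp + fp + smooth))
}

pred  <- c(1, 1, 2, 3, 3, 3)
truth <- c(1, 2, 2, 3, 3, 1)
macro_smoothed <- reference_precision(pred, truth, 3)              # 29/36
macro_raw      <- reference_precision(pred, truth, 3, smooth = 0)  # 13/18
micro_raw      <- reference_precision(pred, truth, 3, smooth = 0, average = "micro")  # 4/6
```

One design point this makes visible: with `smooth = 1`, a class that is never predicted gets precision (0 + 1) / (0 + 0 + 1) = 1, which can inflate a macro average when rare classes are missed entirely. A smaller smoothing value may be worth considering.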

Recall draft:

#' luz_metric_recall
#'
#' luz_metric function to calculate recall
#'
#' Calculates recall based on luz_metric() for use within training and validation
#' loops.
#'
#' @param preds Tensor of predicted class logits with shape
#' [Batch, Class Logits, Height, Width] for a multiclass classification. For a
#' binary classification, you can provide logits for the positive class as
#' [Batch, Positive Class Logit, Height, Width] or [Batch, Height, Width].
#' @param target Tensor of target class labels with shape
#' [Batch, Class Indices, Height, Width] for a multiclass classification. For a
#' binary classification, you can provide targets as
#' [Batch, Positive Class Index, Height, Width] or [Batch, Height, Width]. For
#' binary classification, the class index must be 1 for the positive class and
#' 0 for the background case.
#' @param nCls number of classes being differentiated. Should be 1 for a binary classification
#' where only the positive case logit is returned. Default is 1.
#' @param smooth a smoothing factor to avoid divide by zero errors. Default is 1.
#' @param mode Either "binary" or "multiclass". If "binary", only the logit for
#' positive class prediction should be provided. If both the positive and negative
#' or background class probability is provided for a binary classification, use
#' the "multiclass" mode.
#' @param average Either "micro" or "macro". Whether to use micro- or macro-averaging
#' for multiclass metric calculation. Ignored when mode is "binary". Default is
#' "macro". Macro averaging consists of calculating the metric separately for each class
#' and averaging the results such that all classes are equally weighted. Micro-averaging calculates the 
#' metric for all classes collectively, and classes with a larger number of samples will have a larger
#' weight in the final metric. 
#' @param zeroStart TRUE or FALSE. If class indices start at 0 as opposed to 1, this should be set to
#' TRUE. This is required to implement one-hot encoding since R starts indexing at 1. Default is TRUE.
#' @param chnDim TRUE or FALSE. Whether the channel dimension is included in the target tensor:
#' [Batch, Channel, Height, Width] as opposed to [Batch, Height, Width]. If the channel dimension
#' is included, this should be set to TRUE. If it is not, this should be set to FALSE. Default is TRUE.
#' @param usedDS TRUE or FALSE. If deep supervision was implemented and masks are produced at varying scales using
#' the defineSegDataSetDS() function, this should be set to TRUE. Only the original resolution is used
#' to calculate assessment metrics. Default is FALSE.
#' @return The calculated metric, returned as a base R numeric value rather than a tensor.
#' @export
luz_metric_recall <- luz::luz_metric(

  abbrev = "recall",

  initialize = function(nCls=1,
                        smooth=1,
                        mode = "multiclass",
                        average="macro",
                        zeroStart=TRUE,
                        chnDim=TRUE,
                        usedDS=FALSE){

    self$nCls <- nCls
    self$smooth <- smooth
    self$mode <- mode
    self$average <- average
    self$zeroStart <- zeroStart
    self$chnDim <- chnDim
    self$usedDS <- usedDS

    self$tps <- rep(0.0, nCls)
    self$fns <- rep(0.0, nCls)
    self$fps <- rep(0.0, nCls)
  },

  update = function(preds, target){
    if(self$usedDS == TRUE){
      preds <- preds[[1]]
      target <- target[[1]]
    }

    if(self$chnDim == FALSE){
      target <- target$unsqueeze(dim=2)
    }
    if(self$mode == "multiclass"){
      predsMax <- torch::torch_argmax(preds, dim = 2)
      target1 <- target$to(dtype=torch::torch_long())

      if(self$zeroStart == TRUE){
        target1 <- target1 + 1L
      }

      preds_one_hot <- torch::nnf_one_hot(predsMax, num_classes = self$nCls)
      preds_one_hot <- preds_one_hot$permute(c(1,4,2,3))

      target_one_hot <- torch::nnf_one_hot(target1, num_classes = self$nCls)
      target_one_hot <- target_one_hot$squeeze(dim=2)
      target_one_hot <- target_one_hot$permute(c(1,4,2,3))

      dims <- c(1, 3, 4)
      if(self$average == "micro"){
        dims <- c(1,2,3,4)
      }

      self$tps <- self$tps + (torch::torch_sum(preds_one_hot * target_one_hot, dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())
      self$fps <- self$fps + (torch::torch_sum(preds_one_hot * (-target_one_hot + 1.0), dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())
      self$fns <- self$fns + (torch::torch_sum((-preds_one_hot + 1.0) * target_one_hot, dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())

    }else{
      preds <- torch::nnf_sigmoid(preds)
      preds <- torch::torch_round(preds)
      preds <- preds$flatten()
      target <- target$flatten()

      self$tps <- self$tps + (preds * target)$sum()$item()
      self$fps <- self$fps + ((1.0 - target) * preds)$sum()$item()
      self$fns <- self$fns + (target * (1.0 - preds))$sum()$item()
    }
  },

  compute = function(){
    #Calculate recall: (tps + smooth)/(tps + fns + smooth)
    mean((self$tps + self$smooth)/(self$tps + self$fns + self$smooth))
  }
)
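The drafts sum per-class counts across `update()` calls and only divide in `compute()`. The base R sketch below, using hypothetical helpers `batch_counts` and `recall_from_counts` on simulated labels, checks that this design gives the same recall regardless of how the data are split into batches or in what order they arrive, which is exactly the streaming property discussed earlier in the thread.

```r
# Hypothetical helper: per-class TP and FN counts for one batch, mirroring
# what the draft's update() accumulates.
batch_counts <- function(pred, truth, n_classes) {
  cm <- table(factor(pred,  levels = seq_len(n_classes)),
              factor(truth, levels = seq_len(n_classes)))
  tp <- diag(cm)
  list(tp = tp, fn = colSums(cm) - tp)
}

# Same smoothed formula as the draft's compute().
recall_from_counts <- function(tp, fn, smooth = 1) {
  mean((tp + smooth) / (tp + fn + smooth))
}

set.seed(42)
truth <- sample(1:3, 50, replace = TRUE)
pred  <- ifelse(runif(50) < 0.7, truth, sample(1:3, 50, replace = TRUE))

# Accumulate over uneven batches in shuffled order, as update() would.
batches <- split(sample(50), rep(1:4, c(16, 16, 16, 2)))
tp <- 0
fn <- 0
for (idx in batches) {
  counts <- batch_counts(pred[idx], truth[idx], 3)
  tp <- tp + counts$tp
  fn <- fn + counts$fn
}

full <- batch_counts(pred, truth, 3)
streamed_recall <- recall_from_counts(tp, fn)
full_recall     <- recall_from_counts(full$tp, full$fn)
```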

F1-Score draft:

#' luz_metric_f1score
#'
#' luz_metric function to calculate the F1-score
#'
#' Calculates F1-score based on luz_metric() for use within training and validation
#' loops.
#'
#' @param preds Tensor of predicted class logits with shape
#' [Batch, Class Logits, Height, Width] for a multiclass classification. For a
#' binary classification, you can provide logits for the positive class as
#' [Batch, Positive Class Logit, Height, Width] or [Batch, Height, Width].
#' @param target Tensor of target class labels with shape
#' [Batch, Class Indices, Height, Width] for a multiclass classification. For a
#' binary classification, you can provide targets as
#' [Batch, Positive Class Index, Height, Width] or [Batch, Height, Width]. For
#' binary classification, the class index must be 1 for the positive class and
#' 0 for the background case.
#' @param nCls number of classes being differentiated. Should be 1 for a binary classification
#' where only the positive case logit is returned. Default is 1.
#' @param smooth a smoothing factor to avoid divide by zero errors. Default is 1.
#' @param mode Either "binary" or "multiclass". If "binary", only the logit for
#' positive class prediction should be provided. If both the positive and negative
#' or background class probability is provided for a binary classification, use
#' the "multiclass" mode.
#' @param average Either "micro" or "macro". Whether to use micro- or macro-averaging
#' for multiclass metric calculation. Ignored when mode is "binary". Default is
#' "macro". Macro averaging consists of calculating the metric separately for each class
#' and averaging the results such that all classes are equally weighted. Micro-averaging calculates the 
#' metric for all classes collectively, and classes with a larger number of samples will have a larger
#' weight in the final metric. 
#' @param zeroStart TRUE or FALSE. If class indices start at 0 as opposed to 1, this should be set to
#' TRUE. This is required to implement one-hot encoding since R starts indexing at 1. Default is TRUE.
#' @param chnDim TRUE or FALSE. Whether the channel dimension is included in the target tensor:
#' [Batch, Channel, Height, Width] as opposed to [Batch, Height, Width]. If the channel dimension
#' is included, this should be set to TRUE. If it is not, this should be set to FALSE. Default is TRUE.
#' @param usedDS TRUE or FALSE. If deep supervision was implemented and masks are produced at varying scales using
#' the defineSegDataSetDS() function, this should be set to TRUE. Only the original resolution is used
#' to calculate assessment metrics. Default is FALSE.
#' @return The calculated metric, returned as a base R numeric value rather than a tensor.
#' @export
luz_metric_f1score <- luz::luz_metric(

  abbrev = "F1Score",

  initialize = function(nCls=1,
                        smooth=1,
                        mode = "multiclass",
                        average="macro",
                        zeroStart=TRUE,
                        chnDim=TRUE,
                        usedDS=FALSE){

    self$nCls <- nCls
    self$smooth <- smooth
    self$mode <- mode
    self$average <- average
    self$zeroStart <- zeroStart
    self$chnDim <- chnDim
    self$usedDS <- usedDS

    self$tps <- rep(0.0, nCls)
    self$fns <- rep(0.0, nCls)
    self$fps <- rep(0.0, nCls)
  },

  update = function(preds, target){
    if(self$usedDS == TRUE){
      preds <- preds[[1]]
      target <- target[[1]]
    }

    if(self$chnDim == FALSE){
      target <- target$unsqueeze(dim=2)
    }
    if(self$mode == "multiclass"){
      predsMax <- torch::torch_argmax(preds, dim = 2)
      target1 <- target$to(dtype=torch::torch_long())

      if(self$zeroStart == TRUE){
        target1 <- target1 + 1L
      }

      preds_one_hot <- torch::nnf_one_hot(predsMax, num_classes = self$nCls)
      preds_one_hot <- preds_one_hot$permute(c(1,4,2,3))

      target_one_hot <- torch::nnf_one_hot(target1, num_classes = self$nCls)
      target_one_hot <- target_one_hot$squeeze(dim=2)
      target_one_hot <- target_one_hot$permute(c(1,4,2,3))

      dims <- c(1, 3, 4)
      if(self$average == "micro"){
        dims <- c(1,2,3,4)
      }

      self$tps <- self$tps + (torch::torch_sum(preds_one_hot * target_one_hot, dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())
      self$fps <- self$fps + (torch::torch_sum(preds_one_hot * (-target_one_hot + 1.0), dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())
      self$fns <- self$fns + (torch::torch_sum((-preds_one_hot + 1.0) * target_one_hot, dims)$cpu() |>
                                torch::as_array() |>
                                as.vector())

    }else{
      preds <- torch::nnf_sigmoid(preds)
      preds <- torch::torch_round(preds)
      preds <- preds$flatten()
      target <- target$flatten()

      self$tps <- self$tps + (preds * target)$sum()$item()
      self$fps <- self$fps + ((1.0 - target) * preds)$sum()$item()
      self$fns <- self$fns + (target * (1.0 - preds))$sum()$item()
    }
  },

  compute = function(){
    #Calculate f1-score: (2*tps + smooth)/(2*tps + fns + fps + smooth)
    #Not sure if this is the best way to do this or if it should be calculated directly from precision and recall as (2*precision*recall)/(precision+recall)
    mean(((2.0*self$tps) + self$smooth)/((2.0*self$tps) + self$fns + self$fps + self$smooth))
  }
)
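On the open question in the F1 `compute()` comment: for a single class and with no smoothing, the count form 2TP / (2TP + FP + FN) is algebraically identical to the harmonic mean 2PR / (P + R). The base R check below uses made-up counts for one class.

```r
precision <- function(tp, fp, smooth = 0) (tp + smooth) / (tp + fp + smooth)
recall    <- function(tp, fn, smooth = 0) (tp + smooth) / (tp + fn + smooth)

# Made-up counts for a single class.
tp <- 30
fp <- 10
fn <- 20

p <- precision(tp, fp)                      # 30 / 40 = 0.75
r <- recall(tp, fn)                         # 30 / 50 = 0.6
harmonic  <- 2 * p * r / (p + r)            # 2/3
count_f1  <- 2 * tp / (2 * tp + fp + fn)    # 60 / 90 = 2/3
smooth_f1 <- (2 * tp + 1) / (2 * tp + fp + fn + 1)  # draft's formula with smooth = 1: 61/91
```

The count form has two practical advantages: it stays defined when TP = 0 (the harmonic form becomes 0/0 when both precision and recall are 0), and it works directly on the accumulated counts. Note also that for macro averaging, computing 2PR / (P + R) from macro-averaged precision and recall is generally not the same as averaging the per-class F1 values, which is what the draft's `mean()` over per-class counts does.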
