batchNormTrainV2

fun batchNormTrainV2(input: DTensor, scaleShift: DTensor, runningN: Float, runningSum: DTensor, runningSumOfSquares: DTensor, momentum: Float): Pair<DTensor, Triple<Float, DTensor, DTensor>>

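The batch normalization (batchNorm) op for training. It returns the normalized output together with the updated running statistics.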

Return

Pair of:

  • the output, an NHWC tensor with the same shape as input

  • Triple of:

    • New running N

    • New running sum

    • New running sum of squares
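The running statistics in the returned Triple are a sample count plus per-channel sums of x and x². A minimal sketch of that bookkeeping on plain Kotlin arrays follows. It assumes that each old statistic is decayed by multiplying it by momentum before the batch's contribution is added; this page does not state the exact update rule. `RunningStats`, `updateRunningStats`, and `meanAndVariance` are hypothetical names, not DiffKt API.

```kotlin
// Hypothetical container: a decayed sample count plus per-channel decayed
// sums of x and x^2. Not a DiffKt type.
class RunningStats(val n: Float, val sum: FloatArray, val sumOfSquares: FloatArray)

/**
 * Folds one flattened NHWC batch into the running statistics.
 * ASSUMPTION: old statistics are scaled by `momentum`, then the batch's
 * count and sums are added. DiffKt's actual rule may differ.
 */
fun updateRunningStats(
    input: FloatArray,
    channels: Int,
    stats: RunningStats,
    momentum: Float,
): RunningStats {
    require(input.size % channels == 0) { "input size must be a multiple of C" }
    val count = input.size / channels // N * H * W values per channel
    val sum = FloatArray(channels) { c -> momentum * stats.sum[c] }
    val sumSq = FloatArray(channels) { c -> momentum * stats.sumOfSquares[c] }
    for (i in input.indices) {
        val c = i % channels // NHWC: the channel index varies fastest
        sum[c] += input[i]
        sumSq[c] += input[i] * input[i]
    }
    return RunningStats(momentum * stats.n + count, sum, sumSq)
}

/** Per-channel mean and biased variance recovered from the running statistics. */
fun meanAndVariance(stats: RunningStats): Pair<FloatArray, FloatArray> {
    val mean = FloatArray(stats.sum.size) { c -> stats.sum[c] / stats.n }
    val variance = FloatArray(stats.sum.size) { c ->
        stats.sumOfSquares[c] / stats.n - mean[c] * mean[c]
    }
    return mean to variance
}
```

Keeping raw sums rather than a mean and variance lets the statistics be combined across batches by simple addition; the mean and variance are recovered on demand as sum / N and sumOfSquares / N minus mean².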

Parameters

input

: an NHWC tensor

scaleShift

: the combined scale and shift tensor, with shape (2, C), where C is the number of channels
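For illustration, a training-mode forward pass over a flattened NHWC array might look like the sketch below. It assumes, without confirmation from this page, that the output is normalized with the current batch's per-channel mean and biased variance, that scaleShift stores the scale in row 0 and the shift in row 1 (row-major), and that a small eps is added to the variance. `batchNormNhwc` is a hypothetical helper, not part of DiffKt.

```kotlin
import kotlin.math.sqrt

/**
 * Sketch of a batch norm forward pass on a flattened NHWC array.
 * ASSUMPTIONS (not taken from DiffKt): batch statistics are used, the
 * (2, C) scaleShift is row-major with scale first and shift second,
 * and `eps` keeps the square root away from zero.
 */
fun batchNormNhwc(input: FloatArray, scaleShift: FloatArray, eps: Float = 1e-5f): FloatArray {
    require(scaleShift.size % 2 == 0) { "scaleShift must have shape (2, C)" }
    val channels = scaleShift.size / 2
    require(input.size % channels == 0) { "input size must be a multiple of C" }
    val count = input.size / channels // N * H * W values per channel

    // Per-channel mean over the N, H, and W axes.
    val mean = FloatArray(channels)
    for (i in input.indices) mean[i % channels] += input[i] / count

    // Per-channel biased variance.
    val variance = FloatArray(channels)
    for (i in input.indices) {
        val d = input[i] - mean[i % channels]
        variance[i % channels] += d * d / count
    }

    // Normalize, then apply the per-channel scale and shift.
    return FloatArray(input.size) { i ->
        val c = i % channels
        val normalized = (input[i] - mean[c]) / sqrt(variance[c] + eps)
        scaleShift[c] * normalized + scaleShift[channels + c]
    }
}
```

For example, an input of [1, 10, 3, 20] with C = 2, scale (1, 2), shift (0, 5), and eps = 0 gives [-1, 3, 1, 7]: each channel is standardized to (-1, 1) before the scale and shift are applied.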