batchNormTrainV2
fun batchNormTrainV2(input: DTensor, scaleShift: DTensor, runningN: Float, runningSum: DTensor, runningSumOfSquares: DTensor, momentum: Float): Pair<DTensor, Triple<Float, DTensor, DTensor>>
The batch normalization (batchNorm) op for use during training.
Return
A Pair whose first element is the output, with shape NHWC, and whose second element is a Triple of the new running N, the new running sum, and the new running sum of squares.
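The returned Triple tracks a count, a per-channel sum and a per-channel sum of squares rather than a mean and variance. The sketch below shows one way such accumulators could be updated from an NHWC batch and turned back into a running mean and variance. It is hypothetical: DTensor is replaced by a flat FloatArray, `RunningStats`, `updateRunningStats` and `runningMeanAndVariance` are illustrative names, and the rule "decay each accumulator by `momentum`, then add the batch's contribution" is an assumption, not necessarily what batchNormTrainV2 does.

```kotlin
// Hypothetical stand-in for Triple<Float, DTensor, DTensor>, using flat arrays.
data class RunningStats(val n: Float, val sum: FloatArray, val sumOfSquares: FloatArray)

// ASSUMPTION: each accumulator decays by `momentum` and then adds the batch's
// own count, per-channel sum and per-channel sum of squares.
fun updateRunningStats(
    input: FloatArray,   // flattened NHWC tensor
    channels: Int,       // C, the size of the innermost dimension
    stats: RunningStats,
    momentum: Float,
): RunningStats {
    require(input.size % channels == 0) { "input size must be a multiple of C" }
    val batchN = input.size / channels // N * H * W values per channel
    val batchSum = FloatArray(channels)
    val batchSumSq = FloatArray(channels)
    for (i in input.indices) {
        val c = i % channels // in NHWC the channel index varies fastest
        batchSum[c] += input[i]
        batchSumSq[c] += input[i] * input[i]
    }
    return RunningStats(
        n = momentum * stats.n + batchN,
        sum = FloatArray(channels) { c -> momentum * stats.sum[c] + batchSum[c] },
        sumOfSquares = FloatArray(channels) { c -> momentum * stats.sumOfSquares[c] + batchSumSq[c] },
    )
}

// Recover the per-channel mean and biased variance from the accumulators.
fun runningMeanAndVariance(stats: RunningStats): Pair<FloatArray, FloatArray> {
    val mean = FloatArray(stats.sum.size) { c -> stats.sum[c] / stats.n }
    val variance = FloatArray(stats.sum.size) { c ->
        stats.sumOfSquares[c] / stats.n - mean[c] * mean[c]
    }
    return mean to variance
}
```

Keeping sums instead of means lets batches of different sizes be merged by simple addition; the mean and variance are only derived when needed.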
Parameters
input
: an NHWC tensor
scaleShift
: the combined scale and shift tensor, with shape (2, C)
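To illustrate how an NHWC input and a (2, C) scaleShift tensor could fit together, here is a minimal sketch of training-mode normalization using the batch's own per-channel statistics. It assumes flat FloatArrays instead of DTensor, that row 0 of scaleShift holds the scale and row 1 the shift, and an illustrative epsilon of 1e-5f; `batchNormForwardSketch` is a hypothetical name, not part of the library.

```kotlin
import kotlin.math.sqrt

// Hypothetical sketch of training-mode batch normalization on an NHWC tensor.
// ASSUMPTIONS: scaleShift is (2, C) row-major, row 0 = scale, row 1 = shift;
// epsilon = 1e-5f is illustrative.
fun batchNormForwardSketch(
    input: FloatArray,       // flattened NHWC tensor
    channels: Int,           // C, the size of the innermost dimension
    scaleShift: FloatArray,  // flattened (2, C) tensor
    epsilon: Float = 1e-5f,
): FloatArray {
    require(scaleShift.size == 2 * channels) { "scaleShift must have shape (2, C)" }
    require(input.size % channels == 0) { "input size must be a multiple of C" }
    val count = input.size / channels // N * H * W values per channel
    val mean = FloatArray(channels)
    val variance = FloatArray(channels)
    // First pass: per-channel mean (the channel index varies fastest in NHWC).
    for (i in input.indices) mean[i % channels] += input[i] / count
    // Second pass: per-channel biased variance of the batch.
    for (i in input.indices) {
        val d = input[i] - mean[i % channels]
        variance[i % channels] += d * d / count
    }
    // Normalize each value, then apply scale * normalized + shift per channel.
    return FloatArray(input.size) { i ->
        val c = i % channels
        val normalized = (input[i] - mean[c]) / sqrt(variance[c] + epsilon)
        scaleShift[c] * normalized + scaleShift[channels + c]
    }
}
```

Packing scale and shift into a single (2, C) tensor means one trainable parameter carries both per-channel affine terms.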