batchNormTrainV1

fun batchNormTrainV1(input: DTensor, scaleShift: DTensor, runningMean: DTensor, runningVariance: DTensor, momentum: Float): Triple<DTensor, DTensor, DTensor>

The batchNorm op for training, in its V1 compatibility version.

Return

Triple of:

  • output, with shape NHWC

  • per-channel mean of input (taken over N, H and W), with shape C

  • per-channel sample variance of input (taken over N, H and W), with shape C
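This plain-Kotlin sketch illustrates what the two statistics contain. It computes the per-channel mean and sample variance of an NHWC tensor stored as a flat, row-major FloatArray. It does not use DTensor. `channelStats` is a hypothetical helper, not part of the library. The n − 1 (Bessel-corrected) denominator is an assumption based on the term "sample variance".

```kotlin
// Hypothetical helper (not part of the library): per-channel mean and sample
// variance of an NHWC tensor stored as a flat, row-major FloatArray.
fun channelStats(data: FloatArray, n: Int, h: Int, w: Int, c: Int): Pair<FloatArray, FloatArray> {
    require(data.size == n * h * w * c) { "data size must equal n*h*w*c" }
    val count = n * h * w
    require(count > 1) { "sample variance needs at least two values per channel" }

    // In NHWC layout the channel index varies fastest, so i % c is the channel of element i.
    val mean = FloatArray(c)
    for (i in data.indices) mean[i % c] += data[i]
    for (ch in 0 until c) mean[ch] /= count

    val variance = FloatArray(c)
    for (i in data.indices) {
        val d = data[i] - mean[i % c]
        variance[i % c] += d * d
    }
    // Assumed sample (Bessel-corrected) variance: divide by count - 1.
    for (ch in 0 until c) variance[ch] /= (count - 1)
    return mean to variance
}

fun main() {
    // N=2, H=1, W=1, C=2: channel 0 holds {1, 3}, channel 1 holds {2, 6}.
    val x = floatArrayOf(1f, 2f, 3f, 6f)
    val (mean, variance) = channelStats(x, n = 2, h = 1, w = 1, c = 2)
    println(mean.toList())      // [2.0, 4.0]
    println(variance.toList())  // [2.0, 8.0]
}
```

Each channel is reduced over every N, H and W position. That is why both statistics have shape C.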

Parameters

input

: an NHWC tensor

scaleShift

: the combined scale and shift tensor, with shape (2, C)
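This sketch shows one way a (2, C) scaleShift could be applied after per-channel normalization, under several assumptions. Row 0 is taken as the scale and row 1 as the shift. The default epsilon of 1e-5 is an assumed value. These docs do not say which variance batchNormTrainV1 normalizes with, so the caller passes it in. `normalizeNhwc` is a hypothetical helper that works on flat FloatArrays rather than DTensor.

```kotlin
import kotlin.math.sqrt

// Hypothetical helper (not part of the library): normalize an NHWC tensor per channel,
// then apply a (2, C) scaleShift stored row-major as a flat FloatArray of size 2*C.
// Assumption: row 0 holds the scale, row 1 holds the shift.
fun normalizeNhwc(
    data: FloatArray,
    mean: FloatArray,
    variance: FloatArray,
    scaleShift: FloatArray,
    epsilon: Float = 1e-5f, // assumed default, not taken from the library docs
): FloatArray {
    val c = mean.size
    require(variance.size == c) { "variance must have shape C" }
    require(scaleShift.size == 2 * c) { "scaleShift must have shape (2, C)" }
    require(data.size % c == 0) { "data size must be a multiple of C" }

    val scale = scaleShift.copyOfRange(0, c)     // row 0
    val shift = scaleShift.copyOfRange(c, 2 * c) // row 1
    return FloatArray(data.size) { i ->
        val ch = i % c // NHWC: channel is the innermost dimension
        scale[ch] * (data[i] - mean[ch]) / sqrt(variance[ch] + epsilon) + shift[ch]
    }
}

fun main() {
    val x = floatArrayOf(1f, 2f, 3f, 6f)          // N=2, H=1, W=1, C=2
    val mean = floatArrayOf(2f, 4f)
    val variance = floatArrayOf(1f, 4f)           // population variance per channel
    val scaleShift = floatArrayOf(1f, 2f, 0f, 10f) // scale = [1, 2], shift = [0, 10]
    println(normalizeNhwc(x, mean, variance, scaleShift, epsilon = 0f).toList())
    // [-1.0, 8.0, 1.0, 12.0]
}
```

The usage passes epsilon = 0f so that the arithmetic is exact. Packing scale and shift into one (2, C) tensor keeps both learnable parameters in a single argument.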