BatchNormTrainingV1

A trainable Batch Normalization transform, as described in https://arxiv.org/abs/1502.03167. When training is complete, use its inferenceMode property to get the computed affine transform. This version imitates the behavior of V1, the previous implementation: it calculates a running mean and running variance rather than gathering the raw input to compute the mean and variance. It applies Bessel's correction (https://en.wikipedia.org/wiki/Bessel%27s_correction) to the sample variance of each batch to estimate the population variance, and uses an exponential moving average of those per-batch estimates as the estimate of the population variance when inferenceMode is applied.
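The following is a minimal sketch of the running-statistics scheme described above, written on plain DoubleArray rows instead of DTensor so that it compiles on its own. RunningStats, batchStats, updateRunning and the momentum default of 0.1 are hypothetical; the class's actual momentum value and tensor layout are not documented on this page.

```kotlin
// Hypothetical sketch: per-feature batch statistics with Bessel's correction,
// folded into exponential moving averages. The momentum default is an assumption.

/** Running estimates of the per-feature population mean and variance. */
class RunningStats(val mean: DoubleArray, val variance: DoubleArray)

/**
 * Computes the per-feature mean and Bessel-corrected variance of a batch.
 * batch[i][j] is feature j of sample i; at least two samples are required.
 */
fun batchStats(batch: Array<DoubleArray>): Pair<DoubleArray, DoubleArray> {
    val n = batch.size
    require(n >= 2) { "Bessel's correction needs at least two samples" }
    val features = batch[0].size
    val mean = DoubleArray(features) { j -> batch.sumOf { row -> row[j] } / n }
    // Dividing by (n - 1) instead of n gives an unbiased estimate of the population variance.
    val variance = DoubleArray(features) { j ->
        batch.sumOf { row -> (row[j] - mean[j]) * (row[j] - mean[j]) } / (n - 1)
    }
    return mean to variance
}

/** Blends one batch into the running estimates: new = (1 - momentum) * old + momentum * batch. */
fun updateRunning(stats: RunningStats, batch: Array<DoubleArray>, momentum: Double = 0.1): RunningStats {
    val (mean, variance) = batchStats(batch)
    return RunningStats(
        mean = DoubleArray(mean.size) { j -> (1 - momentum) * stats.mean[j] + momentum * mean[j] },
        variance = DoubleArray(variance.size) { j -> (1 - momentum) * stats.variance[j] + momentum * variance[j] },
    )
}
```

For the one-feature batch [1.0, 3.0], the sample mean is 2.0 and the Bessel-corrected variance is 2.0 rather than the biased 1.0; starting from a running mean of 0.0 and a running variance of 1.0, one update moves them to roughly 0.2 and 1.1.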

Epsilon is hardcoded to 1.e-5f.
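A sketch of where such an epsilon usually enters, assuming it is used as in the original paper: it is added to the variance before the square root, which keeps the result finite even when a feature has zero variance. The normalize function and its gamma and beta parameters are hypothetical stand-ins operating on a single feature column.

```kotlin
import kotlin.math.sqrt

// Epsilon value as documented for BatchNormTrainingV1.
val EPSILON = 1e-5f

/**
 * Hypothetical single-feature normalization:
 * gamma * (x - mean) / sqrt(variance + EPSILON) + beta.
 */
fun normalize(x: FloatArray, mean: Float, variance: Float, gamma: Float = 1f, beta: Float = 0f): FloatArray {
    // EPSILON prevents division by zero when variance == 0.
    val invStd = 1f / sqrt(variance + EPSILON)
    return FloatArray(x.size) { i -> gamma * (x[i] - mean) * invStd + beta }
}
```

With zero variance, an input one unit above the mean maps to about 1 / sqrt(1e-5), roughly 316.23, instead of infinity.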

Types

Companion
object Companion

Functions

cpu
open override fun cpu(): BatchNormTrainingV1
equals
open operator override fun equals(other: Any?): Boolean
extractTangent
open override fun extractTangent(output: DTensor, extractor: (DTensor, DTensor) -> DTensor): TrainableComponent.Companion.Tangent
getSingleInput
open fun getSingleInput(inputs: Array<out DTensor>): DTensor

Helper that checks that the layer was called with a single input. Returns that input if successful; otherwise, it fails with an error.
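A minimal sketch of such a helper, with a generic T standing in for DTensor so it compiles without the library. The name singleInput is hypothetical, and throwing IllegalArgumentException via require is an assumption about how the real helper reports the error.

```kotlin
/**
 * Hypothetical stand-in for getSingleInput: returns the only element of inputs,
 * or throws IllegalArgumentException when zero or several inputs were passed.
 */
fun <T> singleInput(inputs: Array<out T>): T {
    require(inputs.size == 1) { "Expected exactly one input, got ${inputs.size}" }
    return inputs[0]
}
```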

gpu
open override fun gpu(): BatchNormTrainingV1
hashCode
open override fun hashCode(): Int
invoke
open operator override fun invoke(input: DTensor): DTensor
abstract operator fun invoke(vararg inputs: DTensor): DTensor
load
open override fun load(from: ByteBuffer): BatchNormTrainingV1
store
open override fun store(into: ByteBuffer): ByteBuffer
to
open fun to(device: Device): OnDevice
trainingStep
open override fun trainingStep(optim: Optimizer<*>, tangent: Trainable.Tangent): BatchNormTrainingV1
withTrainables
open override fun withTrainables(trainables: List<Trainable<*>>): BatchNormTrainingV1
wrap
open override fun wrap(wrapper: Wrapper): BatchNormTrainingV1

The wrap function should return the same static type it is declared on.
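The signatures above (cpu, gpu, load, trainingStep, withTrainables and wrap all return BatchNormTrainingV1) follow the pattern of a covariant override: Kotlin lets an override narrow its declared return type. The sketch below illustrates this with hypothetical Wrapper, Component and Scale types; they are not the library's real interfaces, whose members are not shown on this page.

```kotlin
// Hypothetical minimal types standing in for the library's Wrapper and component interfaces.
fun interface Wrapper {
    fun wrapValue(x: Float): Float
}

interface Component {
    fun wrap(wrapper: Wrapper): Component
}

class Scale(val factor: Float) : Component {
    // Covariant override: the return type narrows from Component to Scale,
    // so a caller holding a Scale gets a Scale back without a cast.
    override fun wrap(wrapper: Wrapper): Scale = Scale(wrapper.wrapValue(factor))
}
```

Because of the narrowed return type, val s: Scale = Scale(2f).wrap { it * 3f } compiles without a cast, and s.factor is 6f.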

Properties

inferenceMode
open override val inferenceMode: AffineTransform

Freeze the batch norm transform that was used during training, returning an affine transform to be used for inference.
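Freezing can be sketched as folding the statistics into one per-feature scale and shift, assuming the learned scale (gamma) and shift (beta) of the original paper: gamma * (x - mean) / sqrt(variance + eps) + beta equals scale * x + shift with scale = gamma / sqrt(variance + eps) and shift = beta - mean * scale. The Affine class and freeze function are hypothetical and are not the library's AffineTransform API.

```kotlin
import kotlin.math.sqrt

/** Hypothetical per-feature affine transform y = scale * x + shift. */
class Affine(val scale: FloatArray, val shift: FloatArray) {
    fun transform(x: FloatArray): FloatArray = FloatArray(x.size) { j -> scale[j] * x[j] + shift[j] }
}

/**
 * Folds running statistics and learned gamma/beta into a single affine transform,
 * so inference needs one multiply-add per feature instead of a full normalization.
 */
fun freeze(mean: FloatArray, variance: FloatArray, gamma: FloatArray, beta: FloatArray, eps: Float = 1e-5f): Affine {
    val scale = FloatArray(mean.size) { j -> gamma[j] / sqrt(variance[j] + eps) }
    val shift = FloatArray(mean.size) { j -> beta[j] - mean[j] * scale[j] }
    return Affine(scale, shift)
}
```

For mean 2, variance 4, gamma 1 and beta 0, the frozen transform maps an input of 4 to approximately 1.0, the same value as (4 - 2) / sqrt(4 + 1e-5).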

stats
open override val stats: Pair<DTensor, DTensor>

The computed running mean and variance.

trainables
open override val trainables: List<Trainable<*>>