BatchNormTraining

open class BatchNormTraining : BatchNormTrainingBase<BatchNormTraining>

A trainable Batch Normalization transform, as described in https://arxiv.org/abs/1502.03167. When training is complete, use its inferenceMode property to get the computed affine transform. This version maintains an exponential moving average of the sum of the samples, the sum of the squared samples, and the sample count, which are used to estimate the population mean and variance.

Epsilon is hardcoded to 1.e-5f.
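The running-statistics scheme described above can be sketched as follows. This is an illustrative, self-contained Kotlin model of the idea (the class and field names are assumptions, not the library's implementation): per batch, exponential moving averages of the sample sum, squared-sample sum, and sample count are updated with `momentum`, and the population mean and variance are estimated from their ratios.

```kotlin
// Illustrative sketch of EMA-based batch-norm statistics for a single
// feature; not the library's actual implementation.
class RunningStats(val momentum: Float = 0.1f) {
    var sum = 0.0f    // EMA of the per-batch sum of samples
    var sumSq = 0.0f  // EMA of the per-batch sum of squared samples
    var count = 0.0f  // EMA of the per-batch sample count

    fun update(batch: FloatArray) {
        sum = (1 - momentum) * sum + momentum * batch.sum()
        sumSq = (1 - momentum) * sumSq + momentum * batch.map { it * it }.sum()
        count = (1 - momentum) * count + momentum * batch.size
    }

    // Population estimates: E[x] = sum/count, Var[x] = E[x^2] - E[x]^2.
    val mean get() = sum / count
    val variance get() = sumSq / count - mean * mean
}
```

Because mean and variance are ratios of EMAs with the same decay, they converge to the batch statistics even while the individual EMAs are still warming up.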

Constructors

BatchNormTraining
fun BatchNormTraining(numFeatures: Int, momentum: Float = 0.1f)

Functions

cpu
open override fun cpu(): BatchNormTraining
equals
open operator override fun equals(other: Any?): Boolean
extractTangent
open override fun extractTangent(output: DTensor, extractor: (DTensor, DTensor) -> DTensor): TrainableComponent.Companion.Tangent
getSingleInput
open fun getSingleInput(inputs: Array<out DTensor>): DTensor

Helper to check that the layer was called with a single input. Returns that input on success; otherwise throws an error.

gpu
open override fun gpu(): BatchNormTraining
hashCode
open override fun hashCode(): Int
invoke
open operator override fun invoke(input: DTensor): DTensor
abstract operator fun invoke(vararg inputs: DTensor): DTensor
load
open override fun load(from: ByteBuffer): BatchNormTraining
store
open override fun store(into: ByteBuffer): ByteBuffer
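The store/load pair follows the usual ByteBuffer round-trip pattern: store writes the parameters into the buffer and returns it, and load reads them back in the same order. A minimal sketch of that pattern, assuming a flat float parameter layout (the library defines its own layout):

```kotlin
import java.nio.ByteBuffer

// Illustrative store/load round trip for a flat float parameter vector;
// the actual parameter layout is defined by the library.
fun store(params: FloatArray, into: ByteBuffer): ByteBuffer {
    params.forEach { into.putFloat(it) }
    return into
}

fun load(n: Int, from: ByteBuffer): FloatArray =
    FloatArray(n) { from.getFloat() }
```

A caller allocates a buffer, stores into it, flips it, and loads back the same number of values.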
to
open fun to(device: Device): OnDevice
trainingStep
open override fun trainingStep(optim: Optimizer<*>, tangent: Trainable.Tangent): BatchNormTraining
withTrainables
open override fun withTrainables(trainables: List<Trainable<*>>): BatchNormTraining
wrap
open override fun wrap(wrapper: Wrapper): BatchNormTraining

The wrap function should return the same static type it is declared on.

Properties

inferenceMode
open override val inferenceMode: AffineTransform

Freeze the batch norm transform that was used during training, returning an affine transform to be used for inference.
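Freezing works because, once the statistics are fixed, batch normalization collapses to a per-feature affine map y = scale * x + shift. A hedged sketch of that reduction (the function and parameter names are illustrative, not the library's AffineTransform API), using the hardcoded epsilon of 1e-5f:

```kotlin
// Illustrative reduction of frozen batch-norm statistics to an affine
// transform, per feature:
//   scale = gamma / sqrt(variance + eps)
//   shift = beta - mean * scale
fun freezeToAffine(
    mean: Float, variance: Float,
    gamma: Float, beta: Float,
    eps: Float = 1e-5f,
): Pair<Float, Float> {
    val scale = gamma / kotlin.math.sqrt(variance + eps)
    val shift = beta - mean * scale
    return scale to shift
}
```

The frozen transform produces the same output as normalizing with the running statistics: scale * x + shift equals gamma * (x - mean) / sqrt(variance + eps) + beta.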

stats
open override val stats: Pair<DTensor, DTensor>

The computed running mean and variance.

trainables
open override val trainables: List<Trainable<*>>

Inheritors

BatchNorm2d