LinearAfterResetGru

class LinearAfterResetGru(numInputs: Int, numHidden: Int, initialHidden: DTensor?, accType: RecurrentBase.RecurrentBase.AccType, xh2u: Dense, xh2r: Dense, xh2n: Dense) : GRU

Linear-after-reset GRU

In the computation of the candidate activation vector, the hidden-to-hidden linear transform is applied after the previous hidden state has been multiplied elementwise by the reset gate.

\hat{h}_t = \tanh(W_h x_t + U_h (r_t * h_{t-1}) + \mathrm{bias})
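
For context, a full GRU step pairs this candidate with update and reset gates. The equations below are the standard formulation, not taken from this class's source; this class parameterizes the three transforms through its xh2u, xh2r, and xh2n Dense sub-layers, and the interpolation convention in the last line may be flipped in the actual implementation.

z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)
r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)
h_t = (1 - z_t) * h_{t-1} + z_t * \hat{h}_t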

Constructors

LinearAfterResetGru
fun LinearAfterResetGru(numInputs: Int, numHidden: Int, random: Random, initialHidden: DTensor? = null, acc: RecurrentBase.RecurrentBase.AccType = AccType.Fold)
LinearAfterResetGru
fun LinearAfterResetGru(numInputs: Int, numHidden: Int, initialHidden: DTensor? = null, accType: RecurrentBase.RecurrentBase.AccType = AccType.Fold, xh2u: Dense, xh2r: Dense, xh2n: Dense)
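
A minimal construction and usage sketch based on the Random-based constructor and the invoke operator shown on this page. The import paths and the concrete Random type are assumptions; check them against the library's actual packages.

```kotlin
import kotlin.random.Random   // assumed Random type; match the actual constructor signature
import org.diffkt.DTensor     // assumed import path for DTensor

// Build a layer with 16 input features and 32 hidden units; the Random-based
// constructor initializes the xh2u, xh2r, and xh2n Dense sub-layers.
val gru = LinearAfterResetGru(numInputs = 16, numHidden = 32, random = Random(0))

// `invoke` takes the input sequence as a single DTensor and returns the output;
// with the default AccType.Fold the recurrence is folded along the sequence axis.
fun encode(sequence: DTensor): DTensor = gru(sequence)
```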

Functions

accMap
open fun accMap(t: DTensor, sequenceAxis: Int, initialState: DTensor): Pair<DTensor, DTensor>
cell
open override fun cell(state: Pair<DTensor, DTensor>, x: DTensor): Pair<DTensor, DTensor>
cpu
open override fun cpu(): GRU
doRecurrence
open fun doRecurrence(x: DTensor, initialState: DTensor = this.initialState): DTensor

Run the recurrence over the input sequence x, starting from initialState.

equals
open operator override fun equals(other: Any?): Boolean
extractTangent
open override fun extractTangent(output: DTensor, extractor: (DTensor, DTensor) -> DTensor): TrainableComponent.Companion.Tangent
fold
open fun fold(t: DTensor, sequenceAxis: Int, initialState: DTensor): Pair<DTensor, DTensor>
getSingleInput
open fun getSingleInput(inputs: Array<out DTensor>): DTensor

Helper to check that the layer was called with exactly one input. Returns that input if so; otherwise raises an error.
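
A sketch of the documented contract, not the library's actual implementation; the DTensor import path is an assumption.

```kotlin
import org.diffkt.DTensor  // assumed import path

// Mirrors the documented behavior: accept exactly one input and return it,
// failing loudly otherwise.
fun getSingleInputSketch(inputs: Array<out DTensor>): DTensor {
    require(inputs.size == 1) { "expected exactly one input, got ${inputs.size}" }
    return inputs[0]
}
```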

gpu
open override fun gpu(): GRU
hashCode
open override fun hashCode(): Int
invoke
open operator override fun invoke(vararg inputs: DTensor): DTensor
load
open override fun load(from: ByteBuffer): GRU
processForBatching
open override fun processForBatching(initialState: DTensor, initialOutput: DTensor, batchSize: Int): Pair<DTensor, DTensor>
store
open override fun store(into: ByteBuffer): ByteBuffer
to
open fun to(device: Device): OnDevice
trainingStep
open override fun trainingStep(optim: Optimizer<*>, tangent: Trainable.Tangent): GRU
withTrainables
open override fun withTrainables(trainables: List<Trainable<*>>): GRU
wrap
open override fun wrap(wrapper: Wrapper): GRU

The wrap function should return the same static type as the class on which it is declared.
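
A minimal sketch of that convention using placeholder types (illustrative only, not library code):

```kotlin
// Placeholder stand-ins for the library's Wrapper and component types.
interface Wrapper

open class BaseComponent {
    open fun wrap(wrapper: Wrapper): BaseComponent = this
}

class MyComponent : BaseComponent() {
    // Kotlin permits covariant return types, so the override can narrow the
    // return type to the class it is declared on.
    override fun wrap(wrapper: Wrapper): MyComponent = this
}
```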

Properties

accType
open override val accType: RecurrentBase.RecurrentBase.AccType
batchAxis
open val batchAxis: Int
initialHidden
val initialHidden: DTensor? = null
initialOutput
open override val initialOutput: FloatTensor
initialState
open override val initialState: DTensor
numHidden
val numHidden: Int
numInputs
val numInputs: Int
sequenceAxis
open val sequenceAxis: Int
trainables
open override val trainables: List<Trainable<*>>