RecurrentBase

interface RecurrentBase<Recurrent : RecurrentBase<Recurrent, T>, T> : TrainableLayer<Recurrent>

Types

RecurrentBase
Link copied to clipboard
object RecurrentBase

Functions

accMap
Link copied to clipboard
open fun accMap(t: DTensor, sequenceAxis: Int, initialState: T): Pair<T, DTensor>
cell
Link copied to clipboard
abstract fun cell(state: Pair<T, DTensor>, x: DTensor): Pair<T, DTensor>
cpu
Link copied to clipboard
open override fun cpu(): Recurrent
doRecurrence
Link copied to clipboard
open fun doRecurrence(x: DTensor, initialState: T = this.initialState): DTensor

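Runs the recurrence over `x`, starting from `initialState` (which defaults to this layer's `initialState` property), and returns the resulting tensor.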

extractTangent
Link copied to clipboard
open override fun extractTangent(output: DTensor, extractor: (DTensor, DTensor) -> DTensor): TrainableComponent.Companion.Tangent
fold
Link copied to clipboard
open fun fold(t: DTensor, sequenceAxis: Int, initialState: T): Pair<T, DTensor>
getSingleInput
Link copied to clipboard
open fun getSingleInput(inputs: Array<out DTensor>): DTensor

Helper that checks the layer was called with exactly one input. Returns that input if so; otherwise, raises an error.

gpu
Link copied to clipboard
open override fun gpu(): Recurrent
invoke
Link copied to clipboard
open operator override fun invoke(vararg inputs: DTensor): DTensor
load
Link copied to clipboard
open override fun load(from: ByteBuffer): Recurrent
processForBatching
Link copied to clipboard
abstract fun processForBatching(initialState: T, initialOutput: DTensor, batchSize: Int): Pair<T, DTensor>
store
Link copied to clipboard
open override fun store(into: ByteBuffer): ByteBuffer
to
Link copied to clipboard
open fun to(device: Device): OnDevice
trainingStep
Link copied to clipboard
open override fun trainingStep(optim: Optimizer<*>, tangent: Trainable.Tangent): Recurrent
withTrainables
Link copied to clipboard
abstract fun withTrainables(trainables: List<Trainable<*>>): Recurrent
wrap
Link copied to clipboard
open override fun wrap(wrapper: Wrapper): Recurrent

The `wrap` function should return the same static type as the type on which it is declared.

Properties

accType
Link copied to clipboard
abstract val accType: RecurrentBase.RecurrentBase.AccType
batchAxis
Link copied to clipboard
open val batchAxis: Int
initialOutput
Link copied to clipboard
abstract val initialOutput: DTensor
initialState
Link copied to clipboard
abstract val initialState: T
sequenceAxis
Link copied to clipboard
open val sequenceAxis: Int
trainables
Link copied to clipboard
abstract val trainables: List<Trainable<*>>

Inheritors

GRU
Link copied to clipboard