User Defined Types
When you create a DTensor or DScalar variable, it internally implements a function called wrap(), which is invoked during differentiation operations. This internal representation is used both to calculate the user defined function and to calculate its derivative. Alternatively, you can create your own user defined type. A user defined type could be a class with DTensor or DScalar members, a list of DTensor or DScalar variables, or an even more complex type. When defining such a type, you have to implement the wrap() function as part of the type. There are a couple of ways to implement the wrap() function and have it called, which are discussed below.
The advantage of a user defined type is named-member access. Instead of placing all the variables in an array or tensor and accessing them by index, you access them by name as members of a class.
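To make the wrap() idea concrete, here is a minimal, self-contained Kotlin sketch that does not depend on DiffKt. Dual, ToyWrapper, ToyDifferentiable, and LineParams are hypothetical stand-ins for DScalar, the Wrapper type, Differentiable<T>, and a user defined type. They are not the library's actual definitions. In this toy, wrapping tags each differentiable member with a fresh id and seeds its derivative with respect to itself to 1. The user defined type's wrap() wraps each named member and rebuilds the object.

```kotlin
// Toy differentiable scalar (hypothetical stand-in for DScalar): a value plus
// its partial derivatives, keyed by the id of each seeded input.
data class Dual(val value: Double, val partials: Map<Int, Double> = emptyMap())

// Hypothetical wrapper: tags each input scalar with a fresh id and seeds d(x)/d(x) = 1.
class ToyWrapper {
    private var nextId = 0
    fun wrap(x: Dual): Dual = Dual(x.value, mapOf(nextId++ to 1.0))
}

// Hypothetical stand-in for Differentiable<T>: the type knows how to wrap itself.
interface ToyDifferentiable<T> {
    fun wrap(wrapper: ToyWrapper): T
}

// User defined type with named members (params.slope rather than params[0]).
data class LineParams(val slope: Dual, val intercept: Dual) : ToyDifferentiable<LineParams> {
    // Wrap every differentiable member and rebuild the object.
    override fun wrap(wrapper: ToyWrapper) = LineParams(wrapper.wrap(slope), wrapper.wrap(intercept))
}

// Each member keeps its value but now carries its own seed.
val wrapped = LineParams(Dual(2.0), Dual(0.5)).wrap(ToyWrapper())
```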
The purpose of primalAndForwardDerivative() and primalAndReverseDerivative() is to calculate the derivatives of user defined types. The functions are generic over a user defined input type, a user defined output type, and a user defined derivative type. In addition, the user defines a function f for the calculations and a function extractDerivative that extracts the derivatives from the calculations and places the results into the user defined derivative type. The lambdas wrapInput and wrapOutput might also need to be defined so that the wrap() function gets called internally. Notice the similarity in names to primalAndForwardDerivative() and primalAndReverseDerivative(), as an "s" has been added to the end of the function names primalAndForwardDerivative() and primalAndReverseDerivative().
primalAndForwardDerivative() and primalAndReverseDerivative() have essentially the same function signature. The signatures are:
```kotlin
fun <Input : Any, Output : Any, Derivative : Any> primalAndForwardDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, Wrapper) -> Input)? = null,
    wrapOutput: ((Output, Wrapper) -> Output)? = null,
    extractDerivative: (Input, Output, (input: DTensor, output: DTensor) -> DTensor) -> Derivative,
): Pair<Output, Derivative>
```
and
```kotlin
fun <Input : Any, Output : Any, Derivative : Any> primalAndReverseDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, Wrapper) -> Input)? = null,
    wrapOutput: ((Output, Wrapper) -> Output)? = null,
    extractDerivative: (Input, Output, (input: DTensor, output: DTensor) -> DTensor) -> Derivative,
): Pair<Output, Derivative>
```
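The following sketch is self-contained and independent of DiffKt. It mirrors the shape of the primalAndForwardDerivative() signature to show how the pieces fit together: wrap the input, run f, then let extractDerivative pull the partials it needs into the user defined derivative type. It is a toy forward-mode implementation under stated assumptions, not DiffKt's implementation. Dual, ToyWrapper, ToyDifferentiable, and toyPrimalAndForwardDerivative are hypothetical. Derivatives are plain Doubles rather than DTensors. wrapOutput is omitted because this forward-mode toy never needs to wrap the output.

```kotlin
// Toy forward-mode scalar (hypothetical stand-in for DScalar): a value plus its
// partial derivatives with respect to each seeded input, keyed by input id.
data class Dual(val value: Double, val partials: Map<Int, Double> = emptyMap()) {
    operator fun plus(o: Dual) = Dual(value + o.value, combine(partials, 1.0, o.partials, 1.0))

    // Product rule: d(uv) = v*du + u*dv
    operator fun times(o: Dual) = Dual(value * o.value, combine(partials, o.value, o.partials, value))
}

// Returns a*l + b*r, computed key by key.
fun combine(l: Map<Int, Double>, a: Double, r: Map<Int, Double>, b: Double): Map<Int, Double> {
    val out = mutableMapOf<Int, Double>()
    for ((k, v) in l) out[k] = (out[k] ?: 0.0) + a * v
    for ((k, v) in r) out[k] = (out[k] ?: 0.0) + b * v
    return out
}

// Hypothetical wrapper: tags each input scalar with a fresh id and seeds d(x)/d(x) = 1.
class ToyWrapper {
    private var nextId = 0
    fun wrap(x: Dual): Dual = Dual(x.value, mapOf(nextId++ to 1.0))
}

// Hypothetical stand-in for Differentiable<T>: the type knows how to wrap itself.
interface ToyDifferentiable<T> {
    fun wrap(wrapper: ToyWrapper): T
}

// Same shape as primalAndForwardDerivative(), minus wrapOutput (not needed in this toy).
fun <Input : Any, Output : Any, Derivative : Any> toyPrimalAndForwardDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, ToyWrapper) -> Input)? = null,
    extractDerivative: (Input, Output, (input: Dual, output: Dual) -> Double) -> Derivative,
): Pair<Output, Derivative> {
    val wrapper = ToyWrapper()
    // Use the wrapInput lambda if given, otherwise the type's own wrap().
    @Suppress("UNCHECKED_CAST")
    val wrapped: Input = when {
        wrapInput != null -> wrapInput(x, wrapper)
        x is ToyDifferentiable<*> -> x.wrap(wrapper) as Input
        else -> error("Input must implement ToyDifferentiable or wrapInput must be given")
    }
    val y = f(wrapped)
    // d(output)/d(input): look up the wrapped input's id in the output's partials.
    val derivative = extractDerivative(wrapped, y) { input, output ->
        output.partials[input.partials.keys.single()] ?: 0.0
    }
    return y to derivative
}

// User defined input type with named members.
data class Point(val x: Dual, val y: Dual) : ToyDifferentiable<Point> {
    override fun wrap(wrapper: ToyWrapper) = Point(wrapper.wrap(x), wrapper.wrap(y))
}

// User defined derivative type: one partial per input member.
data class PointGrad(val dx: Double, val dy: Double)

// f(p) = x*y + x, so df/dx = y + 1 and df/dy = x.
val xyPlusX: (Point) -> Dual = { p -> p.x * p.y + p.x }

val result = toyPrimalAndForwardDerivative(
    x = Point(Dual(3.0), Dual(4.0)),
    f = xyPlusX,
    extractDerivative = { p, out, d -> PointGrad(d(p.x, out), d(p.y, out)) },
)
```

With x = 3 and y = 4, the primal value is 15.0 and the derivative is PointGrad(dx = 5.0, dy = 3.0).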
The types Input, Output, and Derivative are user defined. Each could be a class with DScalar or DTensor members, a list with DScalar or DTensor elements, or something more complex.
The function f: (Input) -> Output has to know how to access the variables in the Input type and return a value of the Output type.
The Derivative type has to define all the possible derivatives that can be produced by taking the derivative of f() with respect to the Input type.
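When the output has several differentiable members, the Derivative type has to hold one entry per (output member, input member) pair. The sketch below uses hypothetical toy pieces that are independent of DiffKt: Dual, ToyWrapper, ToyDifferentiable, and toyPrimalAndForwardDerivative, a forward-mode stand-in that omits wrapOutput and uses Doubles instead of DTensors. Here f returns both the sum and the product of the two inputs. The derivative type, a hypothetical SumAndProductJacobian, therefore needs 2 x 2 = 4 entries.

```kotlin
// Toy forward-mode scalar (hypothetical stand-in for DScalar): value plus partials by input id.
data class Dual(val value: Double, val partials: Map<Int, Double> = emptyMap()) {
    operator fun plus(o: Dual) = Dual(value + o.value, combine(partials, 1.0, o.partials, 1.0))

    // Product rule: d(uv) = v*du + u*dv
    operator fun times(o: Dual) = Dual(value * o.value, combine(partials, o.value, o.partials, value))
}

// Returns a*l + b*r, computed key by key.
fun combine(l: Map<Int, Double>, a: Double, r: Map<Int, Double>, b: Double): Map<Int, Double> {
    val out = mutableMapOf<Int, Double>()
    for ((k, v) in l) out[k] = (out[k] ?: 0.0) + a * v
    for ((k, v) in r) out[k] = (out[k] ?: 0.0) + b * v
    return out
}

// Hypothetical wrapper: tags each input scalar with a fresh id and seeds d(x)/d(x) = 1.
class ToyWrapper {
    private var nextId = 0
    fun wrap(x: Dual): Dual = Dual(x.value, mapOf(nextId++ to 1.0))
}

// Hypothetical stand-in for Differentiable<T>.
interface ToyDifferentiable<T> {
    fun wrap(wrapper: ToyWrapper): T
}

// Toy forward-mode driver with the same shape as primalAndForwardDerivative(), minus wrapOutput.
fun <Input : Any, Output : Any, Derivative : Any> toyPrimalAndForwardDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, ToyWrapper) -> Input)? = null,
    extractDerivative: (Input, Output, (input: Dual, output: Dual) -> Double) -> Derivative,
): Pair<Output, Derivative> {
    val wrapper = ToyWrapper()
    @Suppress("UNCHECKED_CAST")
    val wrapped: Input = when {
        wrapInput != null -> wrapInput(x, wrapper)
        x is ToyDifferentiable<*> -> x.wrap(wrapper) as Input
        else -> error("Input must implement ToyDifferentiable or wrapInput must be given")
    }
    val y = f(wrapped)
    // d(output)/d(input): look up the wrapped input's id in the output's partials.
    val derivative = extractDerivative(wrapped, y) { input, output ->
        output.partials[input.partials.keys.single()] ?: 0.0
    }
    return y to derivative
}

data class Point(val x: Dual, val y: Dual) : ToyDifferentiable<Point> {
    override fun wrap(wrapper: ToyWrapper) = Point(wrapper.wrap(x), wrapper.wrap(y))
}

// User defined output type with two differentiable members.
data class SumAndProduct(val sum: Dual, val product: Dual)

// Derivative type: every (output member, input member) pair, 2 x 2 = 4 entries.
data class SumAndProductJacobian(
    val dSumDx: Double, val dSumDy: Double,
    val dProductDx: Double, val dProductDy: Double,
)

val sumAndProduct: (Point) -> SumAndProduct = { p -> SumAndProduct(p.x + p.y, p.x * p.y) }

val result = toyPrimalAndForwardDerivative(
    x = Point(Dual(3.0), Dual(4.0)),
    f = sumAndProduct,
    extractDerivative = { p, out, d ->
        SumAndProductJacobian(d(p.x, out.sum), d(p.y, out.sum), d(p.x, out.product), d(p.y, out.product))
    },
)
```

At (3, 4) this yields dSumDx = 1.0, dSumDy = 1.0, dProductDx = 4.0, and dProductDy = 3.0.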
The Input or Output types can implement the Differentiable<T> interface, through which the wrap() function is called. If the Input and Output types do not implement the Differentiable<T> interface, then a lambda expression needs to be written for wrapInput and/or wrapOutput that calls wrap() for the Input or Output type.
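A standard List<Dual> cannot be made to implement an interface of your own. When the input is a plain list, the wrapping has to be supplied as a wrapInput lambda instead. This self-contained sketch uses hypothetical toy pieces, none of which are DiffKt APIs. toyPrimalAndForwardDerivative uses wrapInput when it is given, falls back to the type's own wrap() when the input implements ToyDifferentiable, and fails otherwise. Only the input side is shown, since this forward-mode toy has no wrapOutput.

```kotlin
// Toy forward-mode scalar (hypothetical stand-in for DScalar): value plus partials by input id.
data class Dual(val value: Double, val partials: Map<Int, Double> = emptyMap()) {
    operator fun plus(o: Dual) = Dual(value + o.value, combine(partials, 1.0, o.partials, 1.0))

    // Product rule: d(uv) = v*du + u*dv
    operator fun times(o: Dual) = Dual(value * o.value, combine(partials, o.value, o.partials, value))
}

// Returns a*l + b*r, computed key by key.
fun combine(l: Map<Int, Double>, a: Double, r: Map<Int, Double>, b: Double): Map<Int, Double> {
    val out = mutableMapOf<Int, Double>()
    for ((k, v) in l) out[k] = (out[k] ?: 0.0) + a * v
    for ((k, v) in r) out[k] = (out[k] ?: 0.0) + b * v
    return out
}

// Hypothetical wrapper: tags each input scalar with a fresh id and seeds d(x)/d(x) = 1.
class ToyWrapper {
    private var nextId = 0
    fun wrap(x: Dual): Dual = Dual(x.value, mapOf(nextId++ to 1.0))
}

// Hypothetical stand-in for Differentiable<T>.
interface ToyDifferentiable<T> {
    fun wrap(wrapper: ToyWrapper): T
}

// Toy forward-mode driver with the same shape as primalAndForwardDerivative(), minus wrapOutput.
fun <Input : Any, Output : Any, Derivative : Any> toyPrimalAndForwardDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, ToyWrapper) -> Input)? = null,
    extractDerivative: (Input, Output, (input: Dual, output: Dual) -> Double) -> Derivative,
): Pair<Output, Derivative> {
    val wrapper = ToyWrapper()
    // Lambda first, then the type's own wrap(), otherwise fail.
    @Suppress("UNCHECKED_CAST")
    val wrapped: Input = when {
        wrapInput != null -> wrapInput(x, wrapper)
        x is ToyDifferentiable<*> -> x.wrap(wrapper) as Input
        else -> error("Input must implement ToyDifferentiable or wrapInput must be given")
    }
    val y = f(wrapped)
    val derivative = extractDerivative(wrapped, y) { input, output ->
        output.partials[input.partials.keys.single()] ?: 0.0
    }
    return y to derivative
}

// The list type cannot wrap itself, so a lambda wraps each element.
val wrapList: (List<Dual>, ToyWrapper) -> List<Dual> = { xs, w -> xs.map { w.wrap(it) } }

// f(xs) = sum of xs[i]^2, so the gradient is 2 * xs[i].
val sumOfSquares: (List<Dual>) -> Dual = { xs -> xs.map { it * it }.reduce { a, b -> a + b } }

val result = toyPrimalAndForwardDerivative(
    x = listOf(Dual(1.0), Dual(2.0), Dual(3.0)),
    f = sumOfSquares,
    wrapInput = wrapList,
    extractDerivative = { xs, out, d -> xs.map { d(it, out) } },
)
```

For the input [1, 2, 3], the primal value is 14.0 and the derivative is [2.0, 4.0, 6.0]. Omitting wrapInput here makes the toy fail, because List<Dual> does not implement ToyDifferentiable.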