User Defined Types

When you create a DTensor or DScalar variable, it internally implements a function called wrap(), which is invoked during differentiation operations. This internal representation is used both to evaluate the user defined function and to calculate its derivative. Alternatively, one can create one's own user defined type. A user defined type could be a class with DTensor or DScalar members, a list of DTensor or DScalar variables, or something even more complex. When defining a user created type, one has to implement the wrap() function as part of the type. There are a couple of ways to implement the wrap() function and have it called, which are discussed below.

The advantage of a user defined type is named member access: instead of placing all the variables in an array or tensor and using indices to access them, one accesses each variable by name through the class.
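As a sketch of what such a type can look like (the class name Inputs is illustrative; Differentiable<T> and Wrapper are the org.diffkt types discussed later in this section):

```kotlin
import org.diffkt.*

// A user defined input type with named scalar members. Implementing
// Differentiable<T> supplies the wrap() function that the differentiation
// machinery invokes internally.
class Inputs(val x: DScalar, val y: DScalar) : Differentiable<Inputs> {
    override fun wrap(wrapper: Wrapper) = Inputs(wrapper.wrap(x), wrapper.wrap(y))
}
```

With this, a computation can read `inputs.x` and `inputs.y` by name rather than indexing into a tensor.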

The purpose of primalAndForwardDerivative() and primalAndReverseDerivative() is to calculate derivatives of user defined types. The functions take a user defined input type, a user defined output type, and a user defined derivative type. In addition, the user defines a function for the calculation, and possibly a function that extracts the derivatives from the calculation and places the results into the user defined derivative type. The lambdas wrapInput and wrapOutput might also need to be defined so that the wrap() function is called internally in the code.

primalAndForwardDerivative() and primalAndReverseDerivative() have essentially the same function signature:

```kotlin
fun <Input : Any, Output : Any, Derivative : Any> primalAndForwardDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, Wrapper) -> Input)? = null,
    wrapOutput: ((Output, Wrapper) -> Output)? = null,
    extractDerivative: (Input, Output, (input: DTensor, output: DTensor) -> DTensor) -> Derivative,
): Pair<Output, Derivative>
```

and

```kotlin
fun <Input : Any, Output : Any, Derivative : Any> primalAndReverseDerivative(
    x: Input,
    f: (Input) -> Output,
    wrapInput: ((Input, Wrapper) -> Input)? = null,
    wrapOutput: ((Output, Wrapper) -> Output)? = null,
    extractDerivative: (Input, Output, (input: DTensor, output: DTensor) -> DTensor) -> Derivative,
): Pair<Output, Derivative>
```

The Input, Output, and Derivative types are user defined. They could be a class with DScalar or DTensor members, a list with DScalar or DTensor elements, or something more complex.

The function f: (Input) -> Output has to know how to access the variables in the Input type and produce a return value of the Output type.
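For example, a sketch of such a function over a hypothetical Inputs class with two named DScalar members (assuming the org.diffkt API; the names Inputs and f are illustrative):

```kotlin
import org.diffkt.*

// Hypothetical user defined input type.
class Inputs(val x: DScalar, val y: DScalar) : Differentiable<Inputs> {
    override fun wrap(wrapper: Wrapper) = Inputs(wrapper.wrap(x), wrapper.wrap(y))
}

// f reads the named members of Inputs and returns the Output type,
// here simply a DScalar: f(x, y) = x^2 * y + y.
fun f(input: Inputs): DScalar = input.x * input.x * input.y + input.y
```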

The Derivative type has to define all the possible derivatives that can be produced by taking the derivative of f() with respect to the Input type.
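Continuing the sketch, the Derivative type holds one member per partial derivative, and the extractDerivative lambda fills it in. The names Inputs, Derivatives, and f below are illustrative, not part of the API:

```kotlin
import org.diffkt.*

class Inputs(val x: DScalar, val y: DScalar) : Differentiable<Inputs> {
    override fun wrap(wrapper: Wrapper) = Inputs(wrapper.wrap(x), wrapper.wrap(y))
}

// One member per partial derivative of f with respect to Inputs.
class Derivatives(val dx: DScalar, val dy: DScalar)

fun f(input: Inputs): DScalar = input.x * input.x * input.y + input.y

fun main() {
    val input = Inputs(FloatScalar(2f), FloatScalar(3f))
    val (primal, derivs) = primalAndReverseDerivative(
        x = input,
        f = ::f,
        extractDerivative = { inp, out, extract ->
            // extract(input, output) yields d(output)/d(input) as a DTensor;
            // for scalar members it can be narrowed back to DScalar.
            Derivatives(
                extract(inp.x, out) as DScalar,
                extract(inp.y, out) as DScalar,
            )
        },
    )
}
```

Because Inputs implements Differentiable<Inputs>, no wrapInput lambda is needed here.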

The Input or Output types can implement the Differentiable<T> interface, which knows how to call the wrap() function. If the Input and Output types do not implement the Differentiable<T> interface, then a lambda expression needs to be written for the wrapInput and/or wrapOutput parameters to call wrap() for the Input or Output type.
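For instance, a plain List<DScalar> cannot implement Differentiable<T>, so a wrapInput lambda does the job instead. A sketch (assuming Wrapper.wrap() accepts a DScalar, as in the examples above):

```kotlin
import org.diffkt.*

fun main() {
    val xs: List<DScalar> = listOf(FloatScalar(1f), FloatScalar(2f), FloatScalar(3f))

    // List<DScalar> does not implement Differentiable, so wrapInput
    // calls wrap() on every element instead.
    val (primal, grad) = primalAndReverseDerivative(
        x = xs,
        f = { list: List<DScalar> -> list.reduce { a, b -> a * b } },
        wrapInput = { list, wrapper -> list.map { wrapper.wrap(it) } },
        extractDerivative = { list, out, extract ->
            list.map { extract(it, out) as DScalar }
        },
    )
}
```

Here the Derivative type is also a List<DScalar>, holding one partial derivative per input element.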