Package org.diffkt.tracing

Types

DedaggedTracingTensor
data class DedaggedTracingTensor<T : Any>(numInputs: Int, numTemps: Int, numResults: Int, assignments: List<Pair<Int, Traceable>>, value: T, traceId: TraceId, canScalarEval: Boolean)

The result of removing reused nodes from a data structure. Each reused node is replaced by a TracingTensor.Variable, and an assignment to that variable is recorded in assignments.

JitEvaluatorToUse
enum JitEvaluatorToUse : Enum<JitEvaluatorToUse>

Lets the caller specify which evaluator to use. This is temporary; the intent is to select the evaluator automatically.

JittedFunction
interface JittedFunction<TInput, TOutput> : Function1<TInput, TOutput>

PrintedTensor
class PrintedTensor(printed: String, shape: Shape) : TracingTensor

This class is used when "printing" complex data structures by tracingPrintedForm. Tracing tensors are replaced by an instance of this class to hold the printed form.

Result
typealias Result = Pair<String, Boolean>

The result type of the tracing printer is a pair: a string containing the printed value, and a boolean that is true if the resulting expression is not a primary and may therefore need parentheses when it appears inside an enclosing expression.
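To make the role of the boolean concrete, here is a minimal sketch of a printer that returns the same kind of pair. It does not use the diffkt printer: the `Expr`, `Name`, `Neg`, and `Add` types and the `printExpr` and `parenthesize` functions are hypothetical names invented for this illustration.

```kotlin
// Hypothetical toy expression tree; not part of org.diffkt.tracing.
sealed class Expr
data class Name(val id: String) : Expr()
data class Neg(val operand: Expr) : Expr()
data class Add(val left: Expr, val right: Expr) : Expr()

// Wrap the text in parentheses only when the sub-result is not a primary.
fun parenthesize(printed: Pair<String, Boolean>): String =
    if (printed.second) "(${printed.first})" else printed.first

// Returns the printed text and a flag that is true when the text is not a
// primary, so an enclosing expression may need to parenthesize it.
fun printExpr(e: Expr): Pair<String, Boolean> = when (e) {
    is Name -> Pair(e.id, false)
    is Neg -> Pair("-" + parenthesize(printExpr(e.operand)), true)
    // The left operand of a left-associative "+" is printed as-is.
    is Add -> Pair(printExpr(e.left).first + " + " + parenthesize(printExpr(e.right)), true)
}
```

With these definitions, `printExpr(Neg(Add(Name("a"), Name("b"))))` yields the text `-(a + b)`, while a bare `Name("x")` is a primary and is never wrapped.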

Traceable
interface Traceable

TraceId
class TraceId

A TraceId is used to distinguish one trace from another. That way the variables from one trace will not be confused with the variables from another possibly nested trace, as they may share identifiers.

TracingRandomKey
interface TracingRandomKey : RandomKey, Traceable

TracingScalar
interface TracingScalar : TracingTensor, DScalar

TracingTensor
interface TracingTensor : DTensor, Traceable

TracingTensorOperations
object TracingTensorOperations : Operations

TracingVisitor
interface TracingVisitor<R>

Functions

dedag
fun <T : Any> dedag(value: T, numInputs: Int, traceId: TraceId, rewriteVariableReferences: Boolean = true): DedaggedTracingTensor<T>

Dedag (remove reused tracing tensor nodes from) an arbitrary data structure.

eval
fun <W : Wrappable<W>> W.eval(variables: Array<DTensor?>, traceId: TraceId): W

jit
fun <TInput : Any, TOutput : Any> jit(f: (TInput) -> TOutput): JittedFunction<TInput, TOutput>
fun <TInput : Any, TOutput : Any> jit(f: (TInput) -> TOutput, wrapInput: (TInput, Wrapper) -> TInput? = null, cacheSizeLimit: Int = 5, evaluatorToUse: JitEvaluatorToUse = JitEvaluatorToUse.BestAvailable, loggingName: String = "jit", shouldLog: Boolean = false): JittedFunction<TInput, TOutput>

Transform a (differentiable) function into a function with the same signature that unrolls all loops and control constructs and performs a set of local optimizations on the result. Because control flow is removed, any data-dependent control flow in the function no longer depends on the input data in the second and subsequent invocations of the returned function.
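The caveat about data-dependent control flow can be illustrated without diffkt. The sketch below is a toy stand-in, not the library's implementation: `FrozenBranch` is a hypothetical class that records which branch the first call took and replays that branch on every later call. It ignores whatever caching or retracing the real jit may perform.

```kotlin
// Hypothetical stand-in for a traced function; it does not call diffkt.
class FrozenBranch(
    private val condition: (Double) -> Boolean,
    private val ifTrue: (Double) -> Double,
    private val ifFalse: (Double) -> Double
) : (Double) -> Double {
    // Branch chosen during the first ("tracing") call; null until then.
    private var recorded: Boolean? = null

    override fun invoke(x: Double): Double {
        // The condition is evaluated only once; later calls reuse its result.
        val taken = recorded ?: condition(x).also { recorded = it }
        return if (taken) ifTrue(x) else ifFalse(x)
    }
}
```

For example, an absolute-value function written as `FrozenBranch({ it >= 0.0 }, { it }, { -it })` returns 2.0 for a first input of 2.0, and then returns -3.0 for a later input of -3.0, because the non-negative branch was frozen by the first call. Functions whose branches depend on input values are therefore risky candidates for this kind of transformation.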

printedForm
fun Traceable.printedForm(numInputs: Int? = null): String

Print a tracing tensor, with reused nodes represented using an assignment to a temporary variable.

rawPrintedForm
fun Traceable.rawPrintedForm(): String

Print a tracing tensor, assuming there are no reused nodes.

simplify
fun <TData : Any> simplify(data: TData): TData

topologicalSort
fun <TNode> topologicalSort(roots: List<TNode>, successors: (TNode) -> List<TNode>, skip: (TNode) -> Boolean = { false }): List<TNode>?

A topological sort, which processes an acyclic graph and returns a topologically sorted list of its nodes, in which each node precedes any appearance of its successors. Returns null if the input graph is found to have a cycle.
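One way to meet this contract is a depth-first search that emits nodes in reverse post-order. The sketch below is not the diffkt source: `topologicalSortSketch` is a hypothetical name, and the treatment of `skip` (a skipped node is neither visited nor included in the output) is an assumption, since the page does not describe that parameter.

```kotlin
// A minimal sketch with the same shape as the documented signature.
// Assumption: a node for which skip returns true is ignored entirely.
fun <TNode> topologicalSortSketch(
    roots: List<TNode>,
    successors: (TNode) -> List<TNode>,
    skip: (TNode) -> Boolean = { false }
): List<TNode>? {
    val inProgress = HashSet<TNode>() // nodes on the current DFS path
    val done = HashSet<TNode>()       // nodes whose successors are all emitted
    val postOrder = ArrayList<TNode>()

    // Returns false if a cycle is reachable from node.
    fun visit(node: TNode): Boolean {
        if (skip(node) || node in done) return true
        if (!inProgress.add(node)) return false // back edge, so a cycle
        for (next in successors(node)) {
            if (!visit(next)) return false
        }
        inProgress.remove(node)
        done.add(node)
        postOrder.add(node)
        return true
    }

    for (root in roots) {
        if (!visit(root)) return null
    }
    // Reverse post-order places every node before its successors.
    return postOrder.reversed()
}
```

The recursion depth equals the longest path in the graph, so a production version for very deep graphs would likely use an explicit stack instead.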

useCounts
fun useCounts(roots: List<Traceable>): HashMap<Traceable, Int>
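The page gives no description for useCounts. A plausible reading, offered here only as an assumption, is that it counts how many times each node reachable from the roots is referenced, so that nodes with a count above one can be treated as reused. The sketch below shows that idea on a hypothetical `Node` type; `useCountsSketch` is likewise an invented name and does not call diffkt.

```kotlin
// Hypothetical node type. It is a plain class, so equality is by identity and
// structurally equal but distinct nodes are counted separately.
class Node(val name: String, val children: List<Node> = emptyList())

// Counts references to each node: one per appearance in roots and one per
// appearance in the children list of a reachable node.
fun useCountsSketch(roots: List<Node>): HashMap<Node, Int> {
    val counts = HashMap<Node, Int>()
    val pending = ArrayDeque<Node>() // nodes whose children are not yet counted

    fun use(node: Node) {
        val seen = counts[node] ?: 0
        counts[node] = seen + 1
        // Expand children only on the first reference to avoid double counting.
        if (seen == 0) pending.addLast(node)
    }

    roots.forEach { use(it) }
    while (pending.isNotEmpty()) {
        pending.removeLast().children.forEach { use(it) }
    }
    return counts
}
```

For a root `add(x, neg(x))` in which both occurrences of `x` are the same object, the count for `x` is 2 and the counts for `add` and `neg` are 1.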