Tuesday, June 10, 2014

Factoring Kernels

Distribute control across complex kernels
Break down complex transformations into a linear sequence of smaller transformations

A technique that I use repeatedly to make the structure of a program more flexible and easier to understand is to distribute control flow across a complex kernel.  The control structure is usually some kind of iteration; the kernel is the body of the loop.  The transformation can be applied when the steps contained in the kernel are independent of one another and need not be completed in any particular order.
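As a minimal illustration of the transformation (a sketch using a made-up Item record, not code from any real project), here is a loop whose kernel performs two independent steps, followed by the same work with the control structure distributed so that each step gets its own loop:

```cpp
#include <cassert>
#include <string>
#include <vector>

// A hypothetical record type, used only for illustration.
struct Item {
  std::string name;
  int count;
};

bool operator==(const Item& a, const Item& b) {
  return a.name == b.name && a.count == b.count;
}

// Before: one loop whose body (the kernel) performs two independent steps.
void normalizeFused(std::vector<Item>& items) {
  for (Item& item : items) {
    // Step 1: clamp negative counts to zero.
    if (item.count < 0)
      item.count = 0;
    // Step 2: strip one trailing '?' from the name.
    if (!item.name.empty() && item.name.back() == '?')
      item.name.pop_back();
  }
}

// After: each step gets its own copy of the control structure.
void clampCounts(std::vector<Item>& items) {
  for (Item& item : items)
    if (item.count < 0)
      item.count = 0;
}

void stripQuestionMarks(std::vector<Item>& items) {
  for (Item& item : items)
    if (!item.name.empty() && item.name.back() == '?')
      item.name.pop_back();
}

void normalizeDistributed(std::vector<Item>& items) {
  clampCounts(items);
  stripQuestionMarks(items);
}
```

Because neither step reads anything the other writes, both versions leave the data in the same final state; the distributed version simply reads as two named operations.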

It is a hallmark of premature optimization to mash several unrelated operations together under a single copy of the control structure, on the mistaken assumption that it is more efficient to amortize the loop overhead across them.  That may be true at the outset, but coalescing loops that have slightly different filter functions forces additional conditions into the loop body, and the disruption of control flow caused by those conditions can quickly negate the efficiency that loop coalescing was intended to provide.  Separating the kernels, in turn, often exposes further simplifications.
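The filter problem can be seen in a small sketch (all names are hypothetical). Two kernels with different filters share one loop, so each filter turns into a branch in the shared body. Split apart, each kernel keeps its filter next to the only computation that uses it, and here each one reduces to a single standard-library call. Whether either form is faster depends on the compiler and the data; the example only shows the structural difference.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

struct Stats {
  long positiveSum;
  int evenCount;
};

// Coalesced: both kernels share one loop, so each kernel's filter
// becomes an extra branch inside the shared body.
Stats statsFused(const std::vector<int>& values) {
  Stats s{0, 0};
  for (int x : values) {
    if (x > 0)       // filter for kernel 1
      s.positiveSum += x;
    if (x % 2 == 0)  // filter for kernel 2
      ++s.evenCount;
  }
  return s;
}

// Distributed: each kernel owns its loop and its filter.
long positiveSum(const std::vector<int>& values) {
  return std::accumulate(values.begin(), values.end(), 0L,
                         [](long acc, int x) { return x > 0 ? acc + x : acc; });
}

int evenCount(const std::vector<int>& values) {
  return static_cast<int>(std::count_if(values.begin(), values.end(),
                                        [](int x) { return x % 2 == 0; }));
}
```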

The factoring is straightforward and can be performed incrementally using the following steps:
  1. Insert a comment to identify the start (and end) of each kernel;
  2. Remove the control structure and insert a copy around each kernel;
  3. Factor the resulting code into separate functions, if appropriate.
It should be stressed that this refactoring assumes each kernel can be performed independently over the entire data structure being traversed: after the transformation, every element passes through the first kernel before any element reaches the second.  This assumption must be verified, either through detailed knowledge of the effects of each kernel or by verification testing after the refactoring is performed.  As a rule of thumb, each function or method in a program should fit on a single screen; to put a number on it, any function longer than about 50 source lines should be considered for refactoring.
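The independence assumption is not a formality. In the following sketch (hypothetical functions), the second kernel reads a running total that the first kernel updates, so distributing the loop changes the result: the fused version computes prefix sums, while the distributed version writes the grand total into every slot. Running both versions on the same input and comparing the outputs is the kind of verification testing that catches such a dependence.

```cpp
#include <cassert>
#include <vector>

// Fused: kernel 1 updates a running total, kernel 2 reads it.
std::vector<int> runningFused(std::vector<int> v) {
  int total = 0;
  for (int& x : v) {
    total += x;  // kernel 1
    x = total;   // kernel 2 depends on kernel 1's state so far
  }
  return v;
}

// Distributed: kernel 1 runs over the whole vector before kernel 2 starts,
// so kernel 2 only ever sees the final total.
std::vector<int> runningDistributed(std::vector<int> v) {
  int total = 0;
  for (int x : v)
    total += x;  // kernel 1
  for (int& x : v)
    x = total;   // kernel 2
  return v;
}
```

For the input {1, 2, 3}, the fused version yields {1, 3, 6} and the distributed version yields {6, 6, 6}, so these two kernels must stay together.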

Here is a function I recently refactored using this technique [courtesy of the Chapel project, https://sourceforge.net/projects/chapel/23535/tree/trunk/compiler/resolution/functionResolution.cpp:7134]:
static void removeUnusedFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
        // Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
        // Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE) &&
             !fn->hasFlag(FLAG_EXTERN))) {
          SET_LINENO(formal);
          formal->defPoint->remove();
          VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
          fn->insertAtHead(new DefExpr(tmp));
          if (symExprs.n == 0)
            collectSymExprs(fn->body, symExprs);
          forv_Vec(SymExpr, se, symExprs) {
            if (se->var == formal) {
              if (CallExpr* call = toCallExpr(se->parentExpr))
                if (call->isPrimitive(PRIM_DEREF))
                  se->getStmtExpr()->remove();
              se->var = tmp;
            }
          }
        }
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
            Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        fn->retType : runtimeTypeMap.get(fn->retType);
            INT_ASSERT(rt);
            formal->type =  rt;
            formal->removeFlag(FLAG_TYPE_VARIABLE);
          }
        }
      }
      if (fn->where)
        fn->where->remove();
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
The control structure is the loop over all FnSymbols in gFnSymbols, restricted to those that have a defPoint with a non-null parentSymbol.  Everything inside that filter is the kernel.
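If the same loop header and filter end up copied around several kernels, the control structure itself can be captured once. The sketch below is hypothetical: forEachLiveFn is not part of the Chapel sources, and Symbol, DefExpr and FnSymbol are minimal stand-ins that model only the fields the filter reads (the real compiler iterates its own Vec type with the forv_Vec macro rather than a std::vector).

```cpp
#include <cassert>
#include <vector>

// Hypothetical, minimal stand-ins for the compiler's Symbol, DefExpr and
// FnSymbol classes; only the fields the filter reads are modeled.
struct Symbol {};
struct DefExpr {
  Symbol* parentSymbol;
};
struct FnSymbol {
  DefExpr* defPoint;
  int visits;  // bookkeeping for the example only
};

// The shared control structure: apply `kernel` to every function that has
// a definition point with a non-null parent symbol.
template <typename Kernel>
void forEachLiveFn(const std::vector<FnSymbol*>& fns, Kernel kernel) {
  for (FnSymbol* fn : fns)
    if (fn->defPoint && fn->defPoint->parentSymbol)
      kernel(fn);
}
```

Each kernel then becomes a lambda passed to the helper, for example forEachLiveFn(fns, [](FnSymbol* fn) { /* one kernel */ });. Whether this reads better than repeating a two-line loop header is a judgment call.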

Step 1. Add kernel comments:
static void removeUnusedFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Kernel 1: Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
 

        // Kernel 2: Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
 

        // Kernel 3: Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();


        // Kernel 4: Convert type variable formals.
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE) &&
             !fn->hasFlag(FLAG_EXTERN))) {
          SET_LINENO(formal);
          formal->defPoint->remove();
          VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
          fn->insertAtHead(new DefExpr(tmp));
          if (symExprs.n == 0)
            collectSymExprs(fn->body, symExprs);
          forv_Vec(SymExpr, se, symExprs) {
            if (se->var == formal) {
              if (CallExpr* call = toCallExpr(se->parentExpr))
                if (call->isPrimitive(PRIM_DEREF))
                  se->getStmtExpr()->remove();
              se->var = tmp;
            }
          }
        }
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
            Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        fn->retType : runtimeTypeMap.get(fn->retType);
            INT_ASSERT(rt);
            formal->type =  rt;
            formal->removeFlag(FLAG_TYPE_VARIABLE);
          }
        }
      }


      // Kernel 5: Remove where clauses
      if (fn->where)
        fn->where->remove();


      // Kernel 6: Convert runtime return types.
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
Step 2. Distribute control structure.
static void removeUnusedFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 1: Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
      }
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 2: Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
      }
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 3: Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();

      }
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Kernel 4: Convert type variable formals.
        if (formal->hasFlag(FLAG_TYPE_VARIABLE)) {
          if (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
            if (!fn->hasFlag(FLAG_EXTERN)) {
              SET_LINENO(formal);
              formal->defPoint->remove();
              VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
              fn->insertAtHead(new DefExpr(tmp));
              if (symExprs.n == 0)
                collectSymExprs(fn->body, symExprs);
              forv_Vec(SymExpr, se, symExprs) {
                if (se->var == formal) {
                  if (CallExpr* call = toCallExpr(se->parentExpr))
                    if (call->isPrimitive(PRIM_DEREF))
                      se->getStmtExpr()->remove();
                  se->var = tmp;
                }
              }

            }
          } else {
            if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
              Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                          fn->retType : runtimeTypeMap.get(fn->retType);
              INT_ASSERT(rt);
              formal->type =  rt;
              formal->removeFlag(FLAG_TYPE_VARIABLE);
            }

          }
        }
      }

    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 5: Remove where clauses (once per function, not per formal)
      if (fn->where)
        fn->where->remove();
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 6: Convert runtime return types (once per function).
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
Notice that the factoring makes it clear that the vector symExprs is used only in Kernel 4, so I removed it from the other loops.  Also, in Kernel 4 the test for the FLAG_TYPE_VARIABLE flag on a formal can be factored out of the two cases.  Similarly, the two mutually exclusive cases that test FLAG_HAS_RUNTIME_TYPE can be converted into an if ... else ... construct.  The FLAG_EXTERN test appears only in the case where FLAG_HAS_RUNTIME_TYPE is false; if, as seems to be the case, an extern function never has a type-variable formal whose type carries FLAG_HAS_RUNTIME_TYPE, this test could be moved into an outer scope.
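The flag reasoning in Kernel 4 can be checked mechanically. The sketch below models each test as a plain bool (typeVar, hasRuntimeType and isExtern are illustrative names, not compiler identifiers) and compares the original pair of if statements with the nested if ... else form over all eight combinations. It assumes, as the code suggests, that the first branch does not change the flags or type that the second test reads (it only removes the formal's definition and redirects its uses), so both tests can be evaluated on the same inputs.

```cpp
#include <cassert>

// Return value bits:
//   bit 0: the "replace with a temporary" branch fired
//   bit 1: the "use the runtime type" branch fired

// Original form of Kernel 4: two separate if statements.
unsigned originalForm(bool typeVar, bool hasRuntimeType, bool isExtern) {
  unsigned fired = 0;
  if (typeVar && (!hasRuntimeType && !isExtern))
    fired |= 1u;
  if (typeVar && hasRuntimeType)
    fired |= 2u;
  return fired;
}

// Restructured form: common test hoisted, cases joined by if ... else.
unsigned nestedForm(bool typeVar, bool hasRuntimeType, bool isExtern) {
  unsigned fired = 0;
  if (typeVar) {
    if (!hasRuntimeType) {
      if (!isExtern)
        fired |= 1u;
    } else {
      fired |= 2u;
    }
  }
  return fired;
}

// Check all 8 flag combinations: the forms must agree, and the original
// form must never fire both branches (the cases are mutually exclusive).
bool formsAreEquivalent() {
  for (unsigned bits = 0; bits < 8; ++bits) {
    const bool typeVar = (bits & 1u) != 0;
    const bool hasRuntimeType = (bits & 2u) != 0;
    const bool isExtern = (bits & 4u) != 0;
    const unsigned before = originalForm(typeVar, hasRuntimeType, isExtern);
    if (before == 3u)
      return false;
    if (before != nestedForm(typeVar, hasRuntimeType, isExtern))
      return false;
  }
  return true;
}
```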

Step 3. Break into separate functions.
static void removeDefaultExprs() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 1: Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
      }
    }
  }
}

static void removeTypeExprs() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 2: Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
      }
    }
  }
}

static void removeMethodAndTokenFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 3: Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();
      }
    }
  }
}

static void convertTypeVariableFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Kernel 4: Convert type variable formals.
        if (formal->hasFlag(FLAG_TYPE_VARIABLE)) {
          if (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
            if (!fn->hasFlag(FLAG_EXTERN)) {
              SET_LINENO(formal);
              formal->defPoint->remove();
              VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
              fn->insertAtHead(new DefExpr(tmp));
              if (symExprs.n == 0)
                collectSymExprs(fn->body, symExprs);
              forv_Vec(SymExpr, se, symExprs) {
                if (se->var == formal) {
                  if (CallExpr* call = toCallExpr(se->parentExpr))
                    if (call->isPrimitive(PRIM_DEREF))
                      se->getStmtExpr()->remove();
                  se->var = tmp;
                }
              }

            }
          } else {
            if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
              Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                          fn->retType : runtimeTypeMap.get(fn->retType);
              INT_ASSERT(rt);
              formal->type =  rt;
              formal->removeFlag(FLAG_TYPE_VARIABLE);
            }

          }
        }
      }
    }
  }
}

static void removeWhereClauses() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 5: Remove where clauses (once per function, not per formal)
      if (fn->where)
        fn->where->remove();
    }
  }
}

static void convertRuntimeReturnTypes() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 6: Convert runtime return types (once per function).
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
I used the function-naming convention described in Renaming Functions [TBS] to choose names for the new functions.

The original function must then be replaced with one that calls each of the new functions in turn:
static void removeUnusedFormals() {
  removeDefaultExprs();
  removeTypeExprs();
  removeMethodAndTokenFormals();
  convertTypeVariableFormals();
  removeWhereClauses();
  convertRuntimeReturnTypes();
}
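One way to perform the verification testing mentioned earlier on this example is to run the fused and distributed versions on identical inputs and compare the results. The sketch below does that on a heavily simplified, hypothetical model (not Chapel's real AST) in which each removable expression is reduced to a bool. It covers only Kernels 1, 2, 3 and 5, and its test data includes a live function with no formals, since a per-function kernel must still run for such a function.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, heavily simplified model (not Chapel's AST): each removable
// piece of a declaration is reduced to a flag saying whether it is present.
struct Formal {
  bool hasDefaultExpr;
  bool hasTypeExpr;
  bool isMethodToken;
  bool removed;
};

struct Fn {
  bool live;  // stands in for: defPoint && defPoint->parentSymbol
  bool hasWhere;
  std::vector<Formal> formals;
};

bool operator==(const Formal& a, const Formal& b) {
  return a.hasDefaultExpr == b.hasDefaultExpr && a.hasTypeExpr == b.hasTypeExpr &&
         a.isMethodToken == b.isMethodToken && a.removed == b.removed;
}

bool operator==(const Fn& a, const Fn& b) {
  return a.live == b.live && a.hasWhere == b.hasWhere && a.formals == b.formals;
}

// Fused version: per-formal kernels inside the formal loop,
// the per-function kernel after it.
void removeUnusedFormalsFused(std::vector<Fn>& fns) {
  for (Fn& fn : fns) {
    if (!fn.live)
      continue;
    for (Formal& f : fn.formals) {
      f.hasDefaultExpr = false;  // Kernel 1
      f.hasTypeExpr = false;     // Kernel 2
      if (f.isMethodToken)       // Kernel 3
        f.removed = true;
    }
    fn.hasWhere = false;         // Kernel 5
  }
}

// Distributed version: one loop per kernel.
void removeDefaultExprs(std::vector<Fn>& fns) {
  for (Fn& fn : fns)
    if (fn.live)
      for (Formal& f : fn.formals)
        f.hasDefaultExpr = false;
}

void removeTypeExprs(std::vector<Fn>& fns) {
  for (Fn& fn : fns)
    if (fn.live)
      for (Formal& f : fn.formals)
        f.hasTypeExpr = false;
}

void removeMethodTokenFormals(std::vector<Fn>& fns) {
  for (Fn& fn : fns)
    if (fn.live)
      for (Formal& f : fn.formals)
        if (f.isMethodToken)
          f.removed = true;
}

void removeWhereClauses(std::vector<Fn>& fns) {
  for (Fn& fn : fns)
    if (fn.live)
      fn.hasWhere = false;  // once per function, even if it has no formals
}

void removeUnusedFormalsDistributed(std::vector<Fn>& fns) {
  removeDefaultExprs(fns);
  removeTypeExprs(fns);
  removeMethodTokenFormals(fns);
  removeWhereClauses(fns);
}
```

In practice the stronger check is to rerun the compiler's regression tests before and after the change; a model like this only confirms the control-flow reasoning, not the behavior of remove() on the real AST.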

This makes it readily apparent that the original function was performing (at least) six separate actions, some related and some unrelated.  It is conceivable that the conversions of runtime-type formals and return types were added after the original remove<>() actions were already in place, and that the programmer put them into the existing loop to save the effort of copying the control structure (thereby also blurring the boundaries between the actions being performed).  Those conversions should have been placed in their own routine (e.g., convertRuntimeTypeFormalsAndReturnValues()) to make the intent of the added code obvious.

Making the code run efficiently is the compiler's job.  Making the code obvious, and consequently easy to understand and maintain, is the programmer's job.
