Tuesday, June 10, 2014

Factoring Kernels

Distribute control across complex kernels
Break down complex transformations into a linear sequence of smaller transformations

A technique that I use repeatedly to make the structure of a program more flexible and easier to understand is to distribute control flow across a complex kernel.  The control structure is usually some kind of iteration, and the kernel is the body of the loop.  This transformation can be applied if the steps contained in the kernel are independent of one another and do not have to be completed in any particular order.

It is a hallmark of premature optimization to mash together several unrelated operations under a single copy of the control structure.  This is done under the mistaken notion that it is more efficient to amortize the loop overhead across them.  The assumption may be true at the outset, but coalescing loops that have slightly different filter functions leads to the insertion of additional conditions within the loop.  The disruption of control flow due to these conditions will very quickly negate the efficiency that loop coalescing was intended to provide.  That being the case, application of this kernel refactoring often enables further simplifications.
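
To make the point about differing filters concrete, here is a minimal sketch.  Everything in it (the Item record, its fields, and the function names) is invented for illustration and has nothing to do with the Chapel sources.  The coalesced loop has to carry one condition per kernel, while the distributed form gives each kernel its own loop and filter.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record type, invented for this sketch.
struct Item {
  int  price;
  int  count;
  bool taxed;
  bool inStock;
};

// Coalesced form: one loop, two unrelated kernels, each with its own filter.
void updateItemsFused(std::vector<Item>& items) {
  for (Item& item : items) {
    if (item.taxed)                    // filter for the first kernel
      item.price += item.price / 10;
    if (!item.inStock)                 // filter for the second kernel
      item.count = 0;
  }
}

// Distributed form: each kernel gets its own copy of the loop.
void addTax(std::vector<Item>& items) {
  for (Item& item : items)
    if (item.taxed)
      item.price += item.price / 10;
}

void clearOutOfStockCounts(std::vector<Item>& items) {
  for (Item& item : items)
    if (!item.inStock)
      item.count = 0;
}

void updateItemsDistributed(std::vector<Item>& items) {
  addTax(items);
  clearOutOfStockCounts(items);
}
```

The two kernels touch different fields, so either form should leave the vector in the same state; the distributed form simply names each action.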

The factoring is straightforward, and can be performed incrementally using the following steps:
  1. Insert a comment to identify the start (and end) of each kernel;
  2. Remove the control structure and insert a copy around each kernel;
  3. Factor the resulting code into separate functions, if appropriate.
It should be stressed that this refactoring is based on the assumption that each kernel can be performed independently over the entire data structure being traversed.  This assumption must be verified, either by using detailed knowledge of the effects of each kernel, or by applying verification testing after the refactoring is performed.  As a general rule of thumb, each function/method in a program should fit on a single screen.  To put numbers to this, a function longer than about 50 source lines should be considered for refactoring.
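
The independence assumption can fail in ways that are easy to miss.  The following sketch uses invented names (runFused, runDistributed, refactoringPreservesResult) and a toy pair of kernels in which the second kernel reads a neighboring element that the first kernel also writes.  Distributing the loop changes the result, and a simple comparison test catches it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Kernel A doubles each element; kernel B adds the *next* element.
// B reads a neighbor that A also writes, so the kernels are not independent.
std::vector<int> runFused(std::vector<int> x) {
  for (std::size_t i = 0; i < x.size(); ++i) {
    x[i] *= 2;                         // kernel A
    if (i + 1 < x.size())
      x[i] += x[i + 1];                // kernel B sees a neighbor not yet doubled
  }
  return x;
}

std::vector<int> runDistributed(std::vector<int> x) {
  for (std::size_t i = 0; i < x.size(); ++i)
    x[i] *= 2;                         // kernel A over the whole vector
  for (std::size_t i = 0; i + 1 < x.size(); ++i)
    x[i] += x[i + 1];                  // kernel B now sees doubled neighbors
  return x;
}

// Verification test: for this input, the refactoring is safe only if both agree.
bool refactoringPreservesResult(const std::vector<int>& input) {
  return runFused(input) == runDistributed(input);
}
```

A test like this only covers the inputs it is given, so it complements, rather than replaces, an understanding of what each kernel reads and writes.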

Here is a function I recently refactored using this technique [courtesy of the Chapel project, https://sourceforge.net/projects/chapel/23535/tree/trunk/compiler/resolution/functionResolution.cpp:7134]:
static void removeUnusedFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
        // Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
        // Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE) &&
             !fn->hasFlag(FLAG_EXTERN))) {
          SET_LINENO(formal);
          formal->defPoint->remove();
          VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
          fn->insertAtHead(new DefExpr(tmp));
          if (symExprs.n == 0)
            collectSymExprs(fn->body, symExprs);
          forv_Vec(SymExpr, se, symExprs) {
            if (se->var == formal) {
              if (CallExpr* call = toCallExpr(se->parentExpr))
                if (call->isPrimitive(PRIM_DEREF))
                  se->getStmtExpr()->remove();
              se->var = tmp;
            }
          }
        }
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
            Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        fn->retType : runtimeTypeMap.get(fn->retType);
            INT_ASSERT(rt);
            formal->type =  rt;
            formal->removeFlag(FLAG_TYPE_VARIABLE);
          }
        }
      }
      if (fn->where)
        fn->where->remove();
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
The control structure is the loop over FnSymbols, filtering out those that do not have a valid defPoint or parentSymbol.

Step 1. Add kernel comments:
static void removeUnusedFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Kernel 1: Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
 

        // Kernel 2: Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
 

        // Kernel 3: Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();


        // Kernel 4: Convert type variable formals.
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE) &&
             !fn->hasFlag(FLAG_EXTERN))) {
          SET_LINENO(formal);
          formal->defPoint->remove();
          VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
          fn->insertAtHead(new DefExpr(tmp));
          if (symExprs.n == 0)
            collectSymExprs(fn->body, symExprs);
          forv_Vec(SymExpr, se, symExprs) {
            if (se->var == formal) {
              if (CallExpr* call = toCallExpr(se->parentExpr))
                if (call->isPrimitive(PRIM_DEREF))
                  se->getStmtExpr()->remove();
              se->var = tmp;
            }
          }
        }
        if (formal->hasFlag(FLAG_TYPE_VARIABLE) &&
            formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
            Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        fn->retType : runtimeTypeMap.get(fn->retType);
            INT_ASSERT(rt);
            formal->type =  rt;
            formal->removeFlag(FLAG_TYPE_VARIABLE);
          }
        }
      }


      // Kernel 5: Remove where clauses
      if (fn->where)
        fn->where->remove();


      // Kernel 6: Convert runtime return types.
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
Step 2. Distribute control structure.
static void removeUnusedFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 1: Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
      }
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 2: Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
      }
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 3: Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();
      }
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Kernel 4: Convert type variable formals.
        if (formal->hasFlag(FLAG_TYPE_VARIABLE)) {
          if (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
            if (!fn->hasFlag(FLAG_EXTERN)) {
              SET_LINENO(formal);
              formal->defPoint->remove();
              VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
              fn->insertAtHead(new DefExpr(tmp));
              if (symExprs.n == 0)
                collectSymExprs(fn->body, symExprs);
              forv_Vec(SymExpr, se, symExprs) {
                if (se->var == formal) {
                  if (CallExpr* call = toCallExpr(se->parentExpr))
                    if (call->isPrimitive(PRIM_DEREF))
                      se->getStmtExpr()->remove();
                  se->var = tmp;
                }
              }

            }
          } else {
            if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
              Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                          fn->retType : runtimeTypeMap.get(fn->retType);
              INT_ASSERT(rt);
              formal->type =  rt;
              formal->removeFlag(FLAG_TYPE_VARIABLE);
            }

          }
        }
      }

    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 5: Remove where clauses
      // (a per-function action, so it stays outside any loop over formals)
      if (fn->where)
        fn->where->remove();
    }
  }

  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 6: Convert runtime return types.
      // (also per-function, as in the original)
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
Notice that the factoring makes it clear that the vector symExprs is used only in Kernel 4, so I removed it elsewhere.  Also, in kernel 4, the test for a formal having the FLAG_TYPE_VARIABLE flag can be factored out of the two cases.  Similarly, the mutually-exclusive cases involving the FLAG_HAS_RUNTIME_TYPE can be converted into an if ... else ... construct.  Apparently, the function cannot carry FLAG_EXTERN if FLAG_HAS_RUNTIME_TYPE is true, so this test should probably be moved into an outer scope.
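
The condition factoring in Kernel 4 can be checked in isolation.  The sketch below is a toy model, not Chapel code: the three flags become plain booleans, and the Action enum and function names are invented.  It assumes, as the kernel appears to, that the first branch does not change the flags tested by the second.

```cpp
#include <cassert>

// Which branch of Kernel 4 would run; names invented for this sketch.
enum class Action { None, ReplaceWithTemp, ConvertToRuntimeType };

// Shape of the original code: two separate tests that repeat typeVar.
Action originalForm(bool typeVar, bool hasRuntimeType, bool isExtern) {
  Action result = Action::None;
  if (typeVar && (!hasRuntimeType && !isExtern))
    result = Action::ReplaceWithTemp;
  if (typeVar && hasRuntimeType)
    result = Action::ConvertToRuntimeType;
  return result;
}

// Shape of the factored code: typeVar hoisted, exclusive cases as if/else.
Action factoredForm(bool typeVar, bool hasRuntimeType, bool isExtern) {
  if (typeVar) {
    if (!hasRuntimeType) {
      if (!isExtern)
        return Action::ReplaceWithTemp;
    } else {
      return Action::ConvertToRuntimeType;
    }
  }
  return Action::None;
}

// Exhaustively compare the two forms over all eight flag combinations.
bool formsAgree() {
  for (int bits = 0; bits < 8; ++bits) {
    bool typeVar        = (bits & 1) != 0;
    bool hasRuntimeType = (bits & 2) != 0;
    bool isExtern       = (bits & 4) != 0;
    if (originalForm(typeVar, hasRuntimeType, isExtern) !=
        factoredForm(typeVar, hasRuntimeType, isExtern))
      return false;
  }
  return true;
}
```

With only three boolean inputs, the exhaustive comparison is cheap and gives more confidence than reading the two forms side by side.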

Step 3. Break into separate functions.
static void removeDefaultExprs() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 1: Remove formal default values
        if (formal->defaultExpr)
          formal->defaultExpr->remove();
      }
    }
  }
}

static void removeTypeExprs() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 2: Remove formal type expressions
        if (formal->typeExpr)
          formal->typeExpr->remove();
      }
    }
  }
}

static void removeMethodAndTokenFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      for_formals(formal, fn) {
        // Kernel 3: Remove method and leader token formals
        if (formal->type == dtMethodToken || formal->hasFlag(FLAG_INSTANTIATED_PARAM))
          formal->defPoint->remove();
      }
    }
  }
}

static void convertTypeVariableFormals() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      Vec<SymExpr*> symExprs;
      for_formals(formal, fn) {
        // Kernel 4: Convert type variable formals.
        if (formal->hasFlag(FLAG_TYPE_VARIABLE)) {
          if (!formal->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
            if (!fn->hasFlag(FLAG_EXTERN)) {
              SET_LINENO(formal);
              formal->defPoint->remove();
              VarSymbol* tmp = newTemp("_formal_type_tmp_", formal->type);
              fn->insertAtHead(new DefExpr(tmp));
              if (symExprs.n == 0)
                collectSymExprs(fn->body, symExprs);
              forv_Vec(SymExpr, se, symExprs) {
                if (se->var == formal) {
                  if (CallExpr* call = toCallExpr(se->parentExpr))
                    if (call->isPrimitive(PRIM_DEREF))
                      se->getStmtExpr()->remove();
                  se->var = tmp;
                }
              }

            }
          } else {
            if (FnSymbol* fn = valueToRuntimeTypeMap.get(formal->type)) {
              Type* rt = (fn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                          fn->retType : runtimeTypeMap.get(fn->retType);
              INT_ASSERT(rt);
              formal->type =  rt;
              formal->removeFlag(FLAG_TYPE_VARIABLE);
            }

          }
        }
      }
    }
  }

}

static void removeWhereClauses() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 5: Remove where clauses
      // (a per-function action, so it stays outside any loop over formals)
      if (fn->where)
        fn->where->remove();
    }
  }
}

static void convertRuntimeReturnTypes() {
  forv_Vec(FnSymbol, fn, gFnSymbols) {
    if (fn->defPoint && fn->defPoint->parentSymbol) {
      // Kernel 6: Convert runtime return types.
      // (also per-function, as in the original)
      if (fn->retTag == RET_TYPE) {
        VarSymbol* ret = toVarSymbol(fn->getReturnSymbol());
        if (ret && ret->type->symbol->hasFlag(FLAG_HAS_RUNTIME_TYPE)) {
          if (FnSymbol* rtfn = valueToRuntimeTypeMap.get(ret->type)) {
            Type* rt = (rtfn->retType->symbol->hasFlag(FLAG_RUNTIME_TYPE_VALUE)) ?
                        rtfn->retType : runtimeTypeMap.get(rtfn->retType);
            INT_ASSERT(rt);
            ret->type = rt;
            fn->retType = ret->type;
            fn->retTag = RET_VALUE;
          }
        }
      }
    }
  }
}
I used the function-naming convention in Renaming Functions [TBS] to choose appropriate names for the new functions.

The original function must be replaced with one that calls each of the new functions:
static void removeUnusedFormals() {
  removeDefaultExprs();
  removeTypeExprs();
  removeMethodAndTokenFormals();
  convertTypeVariableFormals();
  removeWhereClauses();
  convertRuntimeReturnTypes();
}

This makes it readily apparent that the original function was performing (at least) six separate actions -- some related, some unrelated.  It is conceivable that the conversion of runtime type formals and return types was added after the original remove<>() functions were already there, so the programmer could save the effort of cutting-and-pasting the control structure (thereby also avoiding providing a clear delineation of the actions being performed). Those conversions should have been put in their own routine (convertRuntimeTypeFormalsAndReturnValues(), e.g.) to make obvious the intent of the added code.

Making the code run efficiently is the compiler's job.  Making the code obvious, and consequently easy to understand and maintain, is the programmer's job.

Friday, June 6, 2014

Renaming Variables

The simplest refactoring involves renaming a variable.  The goal is to enhance readability by making the code as self-documenting as possible.  The standard rules-of-thumb apply: in particular, if the reader is forced to look elsewhere to resolve such things as scope, type and meaning, then the name isn't very good.

A balancing consideration is that overly verbose names take longer to type.  In addition, longer names must be chosen with care based on their context: they will tend to inhibit cutting-and-pasting or at least require extensive fix-up.  For that reason, I typically choose very short names for local variables and reserve longer ones for where I'm forced to place a variable in an outer scope.  [TBS: Moving globals into singletons]

Some coding guidelines used to include type suffixes, so the type of an object could be determined without referring back to the object's declaration.  But this practice has largely become obsolete for two reasons: the type of a variable can often be inferred from context; and IDEs can quickly display the type of a variable in response to a mouse-over or "Show Definition" command.

Coding guidelines also used to encourage prefixes such as "g_" and "s_" to mark variables as belonging to the global or filescope-static namespaces.  This can still be useful in legacy code, since it prevents conflicts between names that are used in the three main namespaces provided by C.  Again, smarter IDEs and the trend toward encapsulating every state variable in a class make this practice less valuable in modern programming.
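
As a small illustration of what the prefixes buy, here is a sketch with invented variable and function names.  The first half shows three "count" variables coexisting in the global, file-static and local scopes; the second half shows what happens without prefixes, where a local quietly hides a global of the same name.

```cpp
#include <cassert>

// Hypothetical names, invented for this sketch.
int g_count = 100;           // global: visible to every translation unit
static int s_count = 10;     // file-scope static: visible only in this file

// The prefixes let all three variables be used together without shadowing.
int totalCount(int count) {  // parameter: local scope
  return g_count + s_count + count;
}

// Without prefixes, a local silently hides the outer variable of the same name.
int limit = 5;

int shadowedLimit() {
  int limit = 1;             // shadows the global 'limit'
  return limit;              // refers to the local, not the global
}
```

Many compilers can be asked to warn about this kind of shadowing, which is part of why the prefixes matter less than they once did.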

Let's pick a random code sample and see how I might fix it up.  Here's the original, excerpted from functionResolution.cpp:3295ff [Courtesy of the Chapel project -- https://sourceforge.net/projects/chapel/]:

static void resolveSetMember(CallExpr* call) {
  // Get the field name.
  SymExpr* sym = toSymExpr(call->get(2));
  if (!sym)
    INT_FATAL(call, "bad set member primitive");
  VarSymbol* var = toVarSymbol(sym->var);
  if (!var || !var->immediate)
    INT_FATAL(call, "bad set member primitive");
  const char* name = var->immediate->v_string;

  // Special case: An integer field name is actually a tuple member index.
  {
    int64_t i;
    if (get_int(sym, &i)) {
      name = astr("x", istr(i));
      call->get(2)->replace(new SymExpr(new_StringSymbol(name)));
    }
  }

  AggregateType* ct = toAggregateType(call->get(1)->typeInfo());
  if (!ct)
    INT_FATAL(call, "bad set member primitive");

  Symbol* fs = NULL;
  for_fields(field, ct) {
    if (!strcmp(field->name, name)) {
      fs = field; break;
    }
  }

  if (!fs)
    INT_FATAL(call, "bad set member primitive");

  Type* t = call->get(3)->typeInfo();
  // I think this never happens, so can be turned into an assert. <hilde>
  if (t == dtUnknown)
    INT_FATAL(call, "Unable to resolve field type");

  if (t == dtNil && fs->type == dtUnknown)
    USR_FATAL(call->parentSymbol, "unable to determine type of field from nil");
  if (fs->type == dtUnknown)
    fs->type = t;

  if (t != fs->type && t != dtNil && t != dtObject) {
    USR_FATAL(userCall(call),
              "cannot assign expression of type %s to field of type %s",
              toString(t), toString(fs->type));
  }
}
And here is my rework:
static void resolveSetMember(CallExpr* call) {
  // Get the field name.
  SymExpr* fieldUse = toSymExpr(call->get(2));
  if (!fieldUse)
    INT_FATAL(call, "bad set member primitive");
  VarSymbol* fieldSym = toVarSymbol(fieldUse->var);
  if (!fieldSym || !fieldSym->immediate)
    INT_FATAL(call, "bad set member primitive");
  const char* fieldName = fieldSym->immediate->v_string;

  // Special case: An integer field name is actually a tuple member index.
  {
    int64_t tuple_member_idx;
    if (get_int(fieldUse, &tuple_member_idx)) {
      fieldName = astr("x", istr(tuple_member_idx));
      call->get(2)->replace(new SymExpr(new_StringSymbol(fieldName)));
    }
  }

  AggregateType* at = toAggregateType(call->get(1)->typeInfo());
  if (!at)
    INT_FATAL(call, "bad set member primitive");

  Symbol* namedField = NULL;
  for_fields(field, at) {
    if (!strcmp(field->name, fieldName)) {
      namedField = field; break;
    }
  }

  if (!namedField)
    INT_FATAL(call, "bad set member primitive");

  Type* valType = call->get(3)->typeInfo();
  // I think this never happens, so can be turned into an assert. <hilde>
  if (valType == dtUnknown)
    INT_FATAL(call, "Unable to resolve field type");

  if (valType == dtNil && namedField->type == dtUnknown)
    USR_FATAL(call->parentSymbol, "unable to determine type of field from nil");
  if (namedField->type == dtUnknown)
    namedField->type = valType;

  if (valType != namedField->type && valType != dtNil && valType != dtObject) {
    USR_FATAL(userCall(call),
              "cannot assign expression of type %s to field of type %s",
              toString(valType), toString(namedField->type));
  }
}
Here's what I did:
  • Replaced "sym" with "fieldUse"; "sym" tells me nothing, because that's the type of the variable.  In this code variables of type SymExpr appear as arguments in CallExpr expressions, so it is natural to call them uses.  This variable represents a use of the field symbol in the "set member" primitive we're trying to resolve.
  • Replaced "var" with "fieldSym".  The "var" is the symbol referred to in the SymExpr.
  • Replaced "name" with "fieldName".  This routine is complex enough, we might forget what name we were looking for.
  • Replaced "i" with "tuple_member_idx".  Except for the horrid special-case use of a field name that is parsable as an integer to represent a tuple index, this name substitution makes the next little bit of code almost self-documenting. [How to avoid that special-casing may be the subject of a future blog.]
  • Replaced "ct" with "at".  AggregateType used to be called ClassType, so this just updates the local variable to correspond to the new type name.  We could choose a more indicative name here.  But given the context that we're resolving a "set member" primitive, the reader should be able to figure out that the AggregateType is the type of the aggregate whose member is getting set.
  • Replaced "fs" with "namedField".  I use "named" in the sense of "selected".  This is the field we are looking for, and we find it in the AggregateType by name.
  • Replaced "t" with "valType".  The third argument is the value being used to overwrite the contents of the named field.  There are three arguments to the "set member" primitive, so giving the type variable a more elaborate name than "t" helps the reader figure out (or remember) which type this is.

You may notice that we reuse the variable fieldName for the main case as well as the special case of a tuple member index.  Reusing names is generally frowned upon [perhaps the subject of a future blog].  Stay tuned.
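
One possible way to avoid the reuse, sketched here under loud assumptions: the helper effectiveFieldName() is hypothetical, it works on std::string rather than Chapel's symbols, and it treats "all digits" as a stand-in for whatever get_int() actually checks.  The idea is only that the special case moves into a function, so the variable is initialized once and never reassigned.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper, not part of Chapel: maps a raw field designator to the
// name used for lookup.  An all-digit designator is treated as a tuple member
// index and becomes "x<index>"; anything else is returned unchanged.
std::string effectiveFieldName(const std::string& raw) {
  bool allDigits = !raw.empty();
  for (char c : raw) {
    if (c < '0' || c > '9')
      allDigits = false;
  }
  if (allDigits)
    return "x" + raw;
  return raw;
}
```

A caller could then write `const std::string fieldName = effectiveFieldName(raw);` and the const makes it plain that the name has a single meaning for the rest of the routine.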