def __init__()

in twml/twml/trainers/trainer.py [0:0]


  def __init__(self, name, params, build_graph_fn,
               metric_fn=None,
               optimize_loss_fn=None,
               run_config=None,
               save_dir=None,
               init_from_dir=None,
               init_map=None,
               warm_start_from=None,
               profiler_steps=None,
               **kwargs):
    """
    Construct a Trainer and the ``tf.estimator.Estimator`` it wraps.

    Besides storing its arguments, the constructor resolves the save directory, optionally
    deletes an existing one (``params.overwrite_save_dir``), optionally starts the stats
    server (``params.stats_port``) and TensorBoard (``params.tensorboard_port``), and
    registers an export gauge.

    Args:
      name (String):
        string name of this estimator; used as scope names for variables and tensors.
      params (HParams, Namespace, or Dict):
        hyper-parameters to be passed to Estimator constructor.
        Must include params.train_batch_size and params.eval_batch_size.
        Note that params is passed to twml.util.convert_to_hparams() to produce an HParams.
      build_graph_fn:
        A function for building tensorflow graphs.
        This matches TensorFlow Estimator's model_fn signature.
        For example,

        .. code-block:: python

          def build_graph(features, label, mode, params, config=None):
            # Implements a simple binary logistic regression model
            sparse_tf = twml.util.convert_to_sparse(features, params.input_size_bits)

            logits = twml.layers.full_sparse(sparse_tf, 1 << params.input_size_bits, 1)

            if mode == 'infer':
              loss = None
            else:
              loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=logits)
              loss = twml.util.weighted_average(loss, features['weights'])

            output = tf.nn.sigmoid(logits)

            return {'output': output, 'loss': loss}

        Args:
          features (dict of Tensor keyed by a string name):
            input tensors.
          label (Tensor):
            if in ``mode == 'train'`` mode, these contain the corresponding labels for input.
          mode (tf.estimator.ModeKeys / String):
            one of 'train', 'eval', 'infer'.
          params (HParams):
            hyper parameters that control how to build a graph.
          config:
            the RunConfig object passed to Estimator constructor.

        This function is expected to return a dictionary containing the following keys:

        * 'output': a node representing model output; required.
        * 'loss': a loss node used for optimization; required for training and
          evaluation.
        * 'train_op': (optional) an operation that minimizes the loss (as output by
          `tf.train.Optimizer.minimize`). If train_op is specified, train_op is used
          for optimization as opposed to loss. Loss is always logged to tensorboard.

        Notes:

        * any tf.summary written inside build graph are logged to tensorboard during training.
        * the ``build_graph_fn`` is called once or twice per epoch (once per training,
          once per evaluation). All data loading (and preprocessing) logic not required
          for serving should be in the ``input_fn`` passed to ``learn``, ``train``,
          ``evaluate``, etc.

      optimize_loss_fn:
        Defaults to Trainer.get_train_op. A function that takes params and loss as arguments
        and returns a training op. The training op is used to update parameters (that is, to learn).
      metric_fn:
        A function that returns the eval_metric_ops dict given graph_output, labels and weights.
        Defaults to None.
        Use ``twml.metrics.get_binary_class_metric_fn()`` to return a ``metric_fn``
        which implements many binary classification metrics.
      run_config (RunConfig):
        optional configuration to be passed to Estimator constructor. Defaults to None,
        in which case a ``tf.estimator.RunConfig`` is built from ``params``.
        A custom RunConfig is rejected when the ``TWML_HOGWILD_PORTS`` environment
        variable is set.
      save_dir (String):
        optional directory where to save model checkpoints,
        tensorboard event files and trained parameters.
        Overwrites and defaults to run_config.model_dir.
      init_from_dir (String):
        optional directory to load weights from.
        if set to None (the default), do not init from any directory.
      init_map (map from String to String):
        Must be specified if init_from_dir is specified.
        Defines which scopes and variables to load.
        Keys are the variables and scopes to load from the directory.
        Values are the destinations (in the current graph) to load into.
        See tf.init_from_checkpoint for more information.
        Note that the trainer prepends name_scope of the form `name`/model/ to the name_scope
        of any variable defined inside `build_graph_fn` and this should be taken into account when
        defining the values.
      warm_start_from:
        Optional string filepath to a checkpoint to warm-start from,
        or a tf.estimator.WarmStartSettings object to fully configure warm-starting.
        If the string filepath is provided instead of a WarmStartSettings,
        then all variables are warm-started, and it is assumed that
        vocabularies and Tensor names are unchanged.
      profiler_steps (Integer):
        Defaults to None. If set defines the number of steps in the
        `tf.train.ProfilerHook <https://www.tensorflow.org/api_docs/python/tf/train/ProfilerHook>`_.
        Captures CPU/GPU profiling information every ``profiler_steps`` steps or seconds.
        When executing ``learn``, ``train`` or ``predict`` methods,
        with ``profiler_steps`` set to a number,
        a ``timeline_X.json`` file is created in the save_dir. This file contains profiling data
        stored in Chrome trace format. To view stored data, use the Chrome browser to follow
        these steps:

        1) Go to the page chrome://tracing.
        2) In the upper left corner, you will find Load button.
        3) Press it and load our JSON file, which can be found in the ``save_dir``

        *Warning*: This could create too many of these json files which can be a potential
        problem, e.g. for HDFS there is normally a quota for file count, so use with caution.

        Note: this argument is ignored when a non-None ``hooks`` argument is passed to
        ``train``, ``learn``, or ``predict`` methods. The hook can be added manually by
        extending your own hook list before passing it, for example
        ``myhooks.extend(trainer.get_train_hooks())`` followed by
        ``trainer.train(..., hooks=myhooks)``. (``list.extend`` returns None, so its
        result must not be passed as ``hooks`` directly.)
      **kwargs:
        ``gke_state_files`` (defaults to ``['_SUCCESS']``) is read from here and handed to
        ``self._maybe_del_tsd_exit``. No other keyword is consumed by the constructor.

    Raises:
      ValueError:
        if ``run_config`` is neither None nor a ``tf.estimator.RunConfig``;
        if a custom ``run_config`` is given while ``TWML_HOGWILD_PORTS`` is set;
        if neither ``save_dir`` nor ``run_config.model_dir`` is specified;
        if ``init_from_dir`` is given without ``init_map``.
    """

    # The original code built a RuntimeError here without raising it, so the check never
    # had any effect. Raising now could break callers that already run on TF >= 2.0
    # (e.g. through a compat layer), so the condition is surfaced as a warning instead.
    # NOTE(review): this is a lexical string comparison, not a numeric version compare.
    if tensorflow.__version__ >= "2.0":
      logging.warning("Trainer not yet supported for Tensorflow >= 2.0")

    self._name = name
    self._build_graph_fn = build_graph_fn
    self._metric_fn = metric_fn
    self._tensorboard_handle = None
    self._current_estimator_spec = None  # holds the current estimator spec
    self._profiler_steps = profiler_steps
    self._export_output_fn = None
    self._is_early_stopping = False

    # NOTE: Sanitize all HDFS paths first.
    save_dir = sanitize_hdfs_path(save_dir)
    init_from_dir = sanitize_hdfs_path(init_from_dir)

    # warm_start_from can be of type tf.estimator.WarmStartSettings.
    if isinstance(warm_start_from, str):
      warm_start_from = sanitize_hdfs_path(warm_start_from)

    # convert to twitter.deepbird.hparam.hparam.HParams object
    params = twml.util.convert_to_hparams(params)

    # keep a copy of the params because calling self._estimator.params creates a deepcopy
    self._params = params
    self.check_params()

    self._using_hogwild = bool(os.environ.get('TWML_HOGWILD_PORTS'))
    # configure Hogwild (needs to be called before RunConfig is created)
    self._hogwild_setup()

    if not run_config:
      session_config = tf.ConfigProto()
      # By default each process tries to allocate (almost) all of the memory.
      # This option ensures the gpu memory grows dynamically instead.
      session_config.gpu_options.allow_growth = True  # pylint: disable=no-member

      # NOTE(review): the guard tests TWML_NUM_CPUS but the value is read from
      # TWML_MESOS_CPU (falling back to 8). Confirm which variable is intended.
      if 'TWML_NUM_CPUS' in os.environ:
        num_available_cpus = int(os.environ.get("TWML_MESOS_CPU", "8"))
        if params.num_mkl_threads > 1:
          os.environ["OMP_NUM_THREADS"] = str(params.num_mkl_threads)
          os.environ["MKL_NUM_THREADS"] = str(params.num_mkl_threads)
          session_config.inter_op_parallelism_threads = num_available_cpus // params.num_mkl_threads
          session_config.intra_op_parallelism_threads = params.num_mkl_threads

      run_config = tf.estimator.RunConfig(
        session_config=session_config,
        keep_checkpoint_max=self._params.get('keep_checkpoint_max', 20),
        log_step_count_steps=10000,
        save_checkpoints_secs=self._params.get('save_checkpoints_secs', 600),
        tf_random_seed=self._tf_random_seed())
    elif not isinstance(run_config, tf.estimator.RunConfig):
      # The two literals previously ran together as "...RunConfigGot ..."; a separator
      # was added so the message is readable.
      raise ValueError("Expecting run_config argument of type None or tf.estimator.RunConfig. "
                       "Got %s instead." % type(run_config).__name__)
    elif os.environ.get('TWML_HOGWILD_PORTS'):
      raise ValueError("Custom RunConfig not supported with Hogwild")

    # Reconcile save_dir with run_config.model_dir: whichever one is missing is filled in
    # from the other.
    # NOTE(review): when both are given and differ, nothing here reconciles them; both
    # values are forwarded to the Estimator constructor below as-is.
    if run_config.model_dir is None and save_dir is None:
      raise ValueError(
          "Expecting either save_dir or run_config.model_dir to be specified. Got None for each.")
    elif run_config.model_dir is None:
      run_config = run_config.replace(model_dir=save_dir)
    elif save_dir is None:
      save_dir = run_config.model_dir

    self._save_dir = save_dir
    self.experiment_tracker = ExperimentTracker(self._params, run_config, self._save_dir)

    # Check if should delete the tsd running this training job. In certain use case when
    # there are other tf operations following trainer.train_and_evaluate (or trainer.learn),
    # additional state files need to be specified to ensure those steps are executed after job restart.
    kwargs['gke_state_files'] = kwargs.get('gke_state_files', ['_SUCCESS'])
    self._maybe_del_tsd_exit(kwargs['gke_state_files'])
    logging.info("Checkpoint and event files will be saved at save_dir=%s", save_dir)
    self._optimize_loss_fn = self.get_train_op if optimize_loss_fn is None else optimize_loss_fn

    # overwrite the current save_dir
    if self._params.get('overwrite_save_dir') and tf.io.gfile.exists(self._save_dir):
      logging.info("Trainer overwriting existing save directory: %s (params.overwrite_save_dir)",
                   self._save_dir)
      # if distributed or hogwild:
      if self._params.get('distributed', False):
        # sleep for 30 seconds to allow each worker to get to this point.
        time.sleep(30)
        if run_config.is_chief:
          logging.info("Chief deleting the save_dir now")
          delete_file_or_dir(self._save_dir)
        # sleep for 30 seconds to allow each worker to get to this point.
        time.sleep(30)
      else:
        delete_file_or_dir(self._save_dir)

    # Exposing stats to a /vars.json endpoint that will be collected
    # by the absorber
    if self._params.get('stats_port'):
      try:
        stats_server_utils.start_stats_server(self._params.get('stats_port'), self._save_dir)
      except Exception as err:
        # Best effort: a failing stats server is logged and must not abort construction.
        logging.error('Failed to start the stats server. Error: %s', str(err))

    checkpoint = os.path.join(self._save_dir, 'checkpoint')
    if tf.io.gfile.exists(checkpoint):
      logging.info("The provided save_dir directory %s already exists."
                   " Training will be resumed.",
                   checkpoint)

    # Deferred: the restore is only performed when this callable is invoked later.
    self._maybe_restore_checkpoint = lambda: init_from_checkpoint(init_from_dir, init_map)

    # NOTE(review): this validation runs after the save_dir may already have been deleted
    # above; it is kept here to preserve the original ordering of errors.
    if init_from_dir is not None and init_map is None:
      raise ValueError("Need to provide init_map when init_from_dir is provided.")

    if not tf.io.gfile.exists(self._save_dir):
      # so tensorboard can point to a directory that exists
      tf.io.gfile.mkdir(self._save_dir)

    self._estimator = tf.estimator.Estimator(
      model_fn=self._model_fn,
      params=self._params,  # HParams
      config=run_config,  # RunConfig
      warm_start_from=warm_start_from,
      model_dir=self._save_dir,  # expected to match run_config.model_dir (see note above)
    )

    # Log parameters that are used to construct trainer. This allows people to see default values.
    logging.info("Trainer constructed using the following parameters: ")
    pp_params = pp.pformat(self._params.values())
    logging.info(pp_params)

    # Start TensorBoard
    if self._params.get('disable_tensorboard', False):
      logging.info("Skipping launching TensorBoard [--disable_tensorboard is set]")
    elif "tensorboard_port" in self._params.values() and self._params.tensorboard_port is not None:
      self.start_tensorboard(self._params.tensorboard_port)

    # Export gauge that will track whether a model was exported
    self.stats_exporter = StatsExporter("twml.trainer")
    self.export_gauge = AtomicGauge('export_model')
    self.stats_exporter.register_metrics(self.export_gauge)