# -*- coding: utf-8 -*-
# File: model_desc.py
import tensorflow as tf
from ..utils.argtools import memoized_method
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..compat import _backport_tensor_spec, tfv1
# TensorSpec is obtained through the compat layer rather than directly from
# ``tf``; presumably this supplies a substitute on TensorFlow versions that lack
# ``tf.TensorSpec`` (NOTE(review): confirm against ..compat._backport_tensor_spec).
TensorSpec = _backport_tensor_spec()

# Names exported by ``from model_desc import *``.
__all__ = [
    "ModelDesc",
    "ModelDescBase",
]
class ModelDescBase(object):
    """
    Base class for a model description.

    It is used for the simple training interface described in
    `Training Interface Tutorial <https://tensorpack.readthedocs.io/tutorial/training-interface.html>`_.

    Subclass is expected to implement :meth:`inputs` and :meth:`build_graph`, as they
    together define a tower function.
    """

    def get_input_signature(self):
        """
        Describe the inputs of this model.

        The signature is computed from :meth:`inputs` on the first successful call
        and then cached on this instance. Later calls, including every access of
        :attr:`input_names`, return the same list without calling :meth:`inputs` again.

        Returns:
            list[TensorSpec]: one spec per input, in the order given by :meth:`inputs`.
        """
        # Compare against None (not truthiness) so that an empty signature is cached too.
        signature = getattr(self, '_input_signature_cache', None)
        if signature is None:
            signature = self._build_input_signature()
            self._input_signature_cache = signature
        return signature

    def _build_input_signature(self):
        """Call :meth:`inputs` and convert its result into a list of TensorSpec."""
        # Build inside a throw-away graph so that placeholders created by inputs()
        # are not added to the caller's default graph.
        with tf.Graph().as_default() as graph:
            inputs = self.inputs()
            assert isinstance(inputs, (list, tuple)), \
                "ModelDesc.inputs() should return a list of tf.TensorSpec objects! Got {} instead.".format(str(inputs))
            # Validate every element (not just the first), which also makes an
            # empty list safe instead of raising IndexError.
            for p in inputs:
                if isinstance(p, tf.Tensor):
                    assert "Placeholder" in p.op.type, \
                        "inputs() have to return TensorSpec or placeholders! Found {} instead.".format(p)
                    assert p.graph == graph, \
                        "Placeholders returned by inputs() should be created inside inputs()!"
            # Placeholders and TensorSpec both expose shape/dtype/name.
            # get_op_tensor_name(...)[0] is taken as the op name. It is assumed to drop
            # a ':N' tensor suffix; TODO confirm against ..tfutils.common.
            return [TensorSpec(shape=p.shape, dtype=p.dtype, name=get_op_tensor_name(p.name)[0])
                    for p in inputs]

    @property
    def input_names(self):
        """
        list[str]: the names of all the inputs, taken from :meth:`get_input_signature`.
        """
        return [spec.name for spec in self.get_input_signature()]

    def inputs(self):
        """
        A subclass is expected to implement this method.

        Define the inputs of the model. It may return TensorSpec objects or
        placeholders. Placeholders must be created inside this method, because
        :meth:`get_input_signature` calls it inside a temporary graph and checks
        that the placeholders belong to that graph.

        Returns:
            list: TensorSpec objects or placeholder tensors, in the order in which
            :meth:`build_graph` receives the corresponding tensors.
        """
        raise NotImplementedError()

    def build_graph(self, *args):
        """
        A subclass is expected to implement this method.

        Build the whole symbolic graph.
        This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.

        Args:
            args ([tf.Tensor]): tensors that match the list of inputs defined by ``inputs()``.

        Returns:
            In general it returns nothing, but a subclass
            may require it to return necessary information to build the trainer.
            For example, `SingleCostTrainer` expects this method to return the cost tensor.
        """
        raise NotImplementedError()

    @property
    def training(self):
        """
        bool: whether the caller is under a training context or not.

        Read from ``get_current_tower_context().is_training``.
        """
        return get_current_tower_context().is_training
class ModelDesc(ModelDescBase):
    """
    One subclass of :class:`ModelDescBase` with the assumption of
    **single cost** and **single optimizer** training.

    It has the following constraints in addition to :class:`ModelDescBase`:

    1. `build_graph(...)` method should return a cost tensor when called under a training context.
       The cost will be the final cost to be optimized by the optimizer.
       Therefore it should include necessary regularization.

    2. Subclass is expected to implement :meth:`optimizer()` method.
    """

    @memoized_method
    def get_optimizer(self):
        """
        Return the memoized optimizer returned by `optimizer()`.

        Users of :class:`ModelDesc` will need to implement `optimizer()`,
        which will only be called once per model.

        Returns:
            a :class:`tf.train.Optimizer` instance (checked against ``tfv1.train.Optimizer``).

        Raises:
            AssertionError: if `optimizer()` returns anything else. The check is an
                ``assert`` statement, so it is skipped when Python runs with ``-O``.
        """
        ret = self.optimizer()
        assert isinstance(ret, tfv1.train.Optimizer), \
            "ModelDesc.optimizer() must return an instance of tf.train.Optimizer! Got {} instead.".format(str(ret))
        return ret

    def optimizer(self):
        """
        A subclass is expected to implement this method.

        Returns:
            a `tf.train.Optimizer` instance.
        """
        raise NotImplementedError()