module core.graph contains Node and ExecutableGraph.
Global Variables
- TYPE_CHECKING
class InvalidURIError
An exception raised when a valid node URI is expected.
class StopTask
An exception raised to exit the current task.
class StopAllTasks
An exception raised to exit the current run.
class Node
This class defines an executable node.
An executable graph is defined by a collection of executable nodes and their dependency relationships.
A node is executable if it implements at least the following phases of execution: forward, backward, and update. Different subclasses of nodes may implement them differently.
This class is designed to be executed easily in batch mode (see ExecutableGraph.apply() for details), so that a group of nodes can execute together, respecting several synchronization points between phases.
The dependency relationships are determined at runtime by how the user accesses the graph argument of the Node.forward() function. The graph argument is actually a cache (a GraphOutputCache instance) of the graph nodes' outputs. The results of preceding nodes are saved in the cache, so dependents can retrieve them easily.
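To illustrate how dependencies can be discovered purely from accesses to the graph argument, here is a minimal, self-contained sketch. ToyGraph and its add method are hypothetical stand-ins invented for this example, not the library's GraphOutputCache API: each lookup runs the requested node's function at most once, memoizes its output, and records an edge from the requested node to whichever node is currently being computed.

```python
from typing import Callable, Dict, List, Set, Tuple


class ToyGraph:
    """Hypothetical stand-in for an executable graph plus its output cache."""

    def __init__(self) -> None:
        self.forwards: Dict[str, Callable[["ToyGraph"], object]] = {}
        self.outputs: Dict[str, object] = {}
        self.edges: Set[Tuple[str, str]] = set()  # (dependency, dependent)
        self._stack: List[str] = []  # nodes currently being computed

    def add(self, name: str, fn: Callable[["ToyGraph"], object]) -> None:
        self.forwards[name] = fn

    def __getitem__(self, name: str) -> object:
        # The node on top of the stack is the one asking for `name`,
        # so this access reveals a dependency edge at runtime.
        if self._stack:
            self.edges.add((name, self._stack[-1]))
        if name not in self.outputs:
            self._stack.append(name)
            try:
                self.outputs[name] = self.forwards[name](self)
            finally:
                self._stack.pop()
        return self.outputs[name]
```

For example, after registering x (returns 3), y (returns graph["x"] * 2) and z (returns graph["x"] + graph["y"]), evaluating graph["z"] yields 9, records the edges (x, y), (x, z) and (y, z), and computes x only once. No edge list is declared up front; the structure emerges from which outputs each node reads.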
method __init__
property board
the board writer of the current task.
property device
the device assigned by the current launcher.
property epoch_size
the size of the current epoch.
property epoch_steps
the steps of the current epoch.
property global_auto_steps
the global steps of the current task.
property global_train_epochs
the global train epochs of the current task.
property global_train_steps
the global train steps of the current task.
property grad_acc_steps
the gradient accumulation steps of the current task.
property grad_scaler
the grad scaler of the current task.
property launcher
the current launcher.
property name
the node name in the currently activated ExecutableGraph.
property out_dir
the output directory of the current task.
property run_id
the run id of the current task.
property step_mode
whether the current task runs by step (True) or by epoch (False).
property task
the current task.
property training
whether the current task is training.
method backward
calculates gradients.
method clean_up
an event hook for cleaning up all resources when switching executable graphs.
method dry_run
only updates the progress state.
method epoch_end
an event hook for the end of an epoch (epoch mode only).
method epoch_start
an event hook for the start of an epoch (epoch mode only).
method forward
retrieves the forward output from the cache, or calculates it using forward_impl and saves the output to the cache. Subclasses should not override this method.
method forward_impl
the forward pass of the node; inputs of the current executable graph can be retrieved directly from the graph argument.
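The split between forward and forward_impl follows the template-method pattern: forward owns the caching, forward_impl only computes. A minimal sketch, assuming a plain dict keyed by node name stands in for the GraphOutputCache; SketchNode and Doubler are hypothetical names used only for illustration.

```python
from typing import Any, Dict


class SketchNode:
    """Hypothetical base class: forward() handles caching, forward_impl() computes."""

    def __init__(self, name: str) -> None:
        self.name = name

    def forward(self, cache: Dict[str, Any]) -> Any:
        # Not meant to be overridden: reuse the cached output when present.
        if self.name not in cache:
            cache[self.name] = self.forward_impl(cache)
        return cache[self.name]

    def forward_impl(self, cache: Dict[str, Any]) -> Any:
        # Subclasses put the actual computation here.
        raise NotImplementedError


class Doubler(SketchNode):
    """Example subclass that reads its input from the cache."""

    def __init__(self, name: str, source: str) -> None:
        super().__init__(name)
        self.source = source
        self.calls = 0  # counts how often the real computation runs

    def forward_impl(self, cache: Dict[str, Any]) -> Any:
        self.calls += 1
        return cache[self.source] * 2
```

With cache = {"x": 5}, calling Doubler("y", "x").forward(cache) twice returns 10 both times while forward_impl runs once; clearing the cache forces the next call to recompute, mirroring GraphOutputCache.clear.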
method load_state_dict
resumes the node state from state_dict.
method move
Moves data to the CPU or the GPU.
If device is None and the node has a device attribute, then that is the device the data is moved to. Otherwise, the data is moved according to device.type. If data is a tuple or list, the function is applied recursively to each of its elements.
Args:
- data (torch.Tensor or torch.nn.Module or list or dict): the data to move.
- device (str or torch.device): the string or torch.device instance to move the data to.
Returns:
torch.Tensor or torch.nn.Module: the data on the requested device.
Raises:
RuntimeError: data is not one of torch.Tensor, torch.nn.Module, list or dict.
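The recursive dispatch described above can be sketched without depending on torch, assuming that anything exposing a .to(device) method (as torch.Tensor and torch.nn.Module both do) counts as movable. The standalone move function and its default_device parameter, which plays the role of the node's device attribute, are illustrative assumptions rather than the library's code; it also recurses into dicts, matching the dict type listed under Args.

```python
from typing import Any, Optional


def move(data: Any, device: Optional[Any] = None,
         default_device: Optional[Any] = None) -> Any:
    """Recursively move `data` to a device (hypothetical standalone sketch)."""
    # Fall back to the node-level device when none is passed explicitly.
    target = device if device is not None else default_device
    if target is None:
        # Assumption of this sketch: some target device must be known.
        raise ValueError("no target device given")
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (plain list/tuple) with moved elements.
        return type(data)(move(x, target) for x in data)
    if isinstance(data, dict):
        return {k: move(v, target) for k, v in data.items()}
    if hasattr(data, "to"):
        # torch.Tensor and torch.nn.Module both provide .to(device).
        return data.to(target)
    raise RuntimeError(f"cannot move object of type {type(data).__name__}")
```

An explicit device always wins over default_device, and unsupported leaf types such as plain integers raise RuntimeError, as the Raises section describes.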
method prepare
an event hook for preparing all resources when switching executable graphs.
method state_dict
returns a serialization of the current node.
method update
updates parameters or buffers, e.g. using an SGD-based optimizer to update parameters.
class GraphOutputCache
a cache for storing and searching forward outputs of nodes.
This class is used to store and search forward outputs of nodes.
Attributes:
- cache (dict): a dict for storing forward outputs.
- egraph (ExecutableGraph): the executable graph.
- data (Dict[str, torch.Tensor]): the cache.
method __init__
method __getitem__
Executes the node named name if it has not been executed yet; otherwise returns its cached output.
method clear
Clears the cache; subsequent calls to __getitem__ will recalculate.
class ExecutableGraph
an executable graph.
This class is used to execute nodes in a graph.
Attributes:
- hypergraph (HyperGraph): the hypergraph.
- nodes (Dict[str, Node]): a dict for storing nodes.
- nodes_tags (Dict[str, str]): a dict for storing tags of nodes.
- nodes_names (Dict[str, str]): a dict for storing names of nodes.
- cache (GraphOutputCache): a cache for storing and searching forward outputs of nodes.
- task: the task of the graph.
- losses: the losses of the graph.
- total_loss: the total loss of the graph.
method __init__
property grad_scaler
method add_node
adds a node to the graph.
Args:
- node_name (str): the name of the node.
- node (Node): the node.
- tags (List[str]): the tags of the node.
Raises:
RuntimeError: the node is not a node of the hypergraph.
method apply
applies method to all nodes in the graph.
Args:
- method (str): the method name.
- *args: the positional arguments of the method.
- filter (Callable[[Node], bool]): the filter function.
- **kwds: the keyword arguments of the method.
Returns:
List[Any]: the return values of the method.
Raises:
RuntimeError: the method is not found.
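A minimal sketch of the batch-mode dispatch, assuming nodes are held in a plain dict and that apply simply calls the named method on each selected node in order. This standalone apply is a hypothetical illustration, not the library's implementation; its parameter is named filter to mirror the documented signature, which shadows the Python builtin inside the function.

```python
from typing import Any, Callable, Dict, List, Optional


def apply(nodes: Dict[str, Any], method: str, *args: Any,
          filter: Optional[Callable[[Any], bool]] = None,
          **kwds: Any) -> List[Any]:
    """Call `method` on every node accepted by `filter` (hypothetical sketch)."""
    results: List[Any] = []
    for node in nodes.values():
        # Skip nodes rejected by the optional predicate.
        if filter is not None and not filter(node):
            continue
        fn = getattr(node, method, None)
        if fn is None:
            raise RuntimeError(f"method {method!r} not found on {type(node).__name__}")
        results.append(fn(*args, **kwds))
    return results
```

Calling such a function once per phase (forward, then backward, then update) gives the synchronization points mentioned for Node: every selected node finishes one phase before any node starts the next.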
method clean_up_nodes
clean up all nodes in the graph.
This method is called after the graph is executed.
method items
method iterate
iterates over all nodes in the graph.
method prepare_nodes
prepare all nodes in the graph.
This method is called before the graph is executed.