module core.graph

contains Node and ExecutableGraph.

Global Variables

  • TYPE_CHECKING

class InvalidURIError

An exception raised when a valid node URI is expected.


class StopTask

An exception raised to exit the current task.


class StopAllTasks

An exception raised to exit the current run.


class Node

This class defines the executable node.

An executable graph is defined by a collection of executable nodes and their dependency relationships.

A node is executable if it has at least the following phases of execution: forward, backward, update. Different subclasses of Node may implement them differently.

This class is designed to be executed easily in batch mode (see ExecutableGraph.apply() for details), so that a bunch of nodes can execute together, respecting several synchronization points between phases.

The dependency relationship is determined at runtime by how the user accesses the graph argument of Node.forward_impl() (the cache parameter in its signature). The graph argument is actually a cache (a GraphOutputCache instance) of the graph nodes' outputs. The results of preceding nodes are saved in the cache, so dependents can retrieve them easily.
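
The batch-mode execution described above can be pictured with a minimal sketch. It assumes that each phase is run on every node before the next phase begins. ToyNode and run_step are hypothetical stand-ins written for illustration, not the library's Node or ExecutableGraph.apply().

```python
from typing import List


class ToyNode:
    """Hypothetical stand-in for a node with three execution phases."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log  # shared log recording the order of phase calls

    def forward(self) -> None:
        self.log.append(f"{self.name}.forward")

    def backward(self) -> None:
        self.log.append(f"{self.name}.backward")

    def update(self) -> None:
        self.log.append(f"{self.name}.update")


def run_step(nodes: List[ToyNode]) -> None:
    """Run one step: finish each phase on all nodes before starting the next."""
    for phase in ("forward", "backward", "update"):
        for node in nodes:
            getattr(node, phase)()


log: List[str] = []
run_step([ToyNode("a", log), ToyNode("b", log)])
print(log)
```

The end of each inner loop plays the role of a synchronization point: no node starts backward until every node has finished forward.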

method __init__

__init__(*args, **kwds) -> None

property board

the board writer of current task.


property device

the device assigned by the current launcher.


property epoch_size

the size of current epoch.


property epoch_steps

the steps of current epoch.


property global_auto_steps

the global steps of current task.


property global_train_epochs

the global train epochs of current task.


property global_train_steps

the global train steps of current task.


property grad_acc_steps

the gradient accumulation steps of the current task.


property grad_scaler

the grad scaler of current task.


property launcher

the current launcher.


property name

the node name in the current activated ExecutableGraph.


property out_dir

the output directory of current task.


property run_id

the run id of current task.


property step_mode

whether current task is running by step (True) or by epoch (False).


property task

the current task.


property training

whether current task is training.


method backward

backward()

calculates gradients.


method clean_up

clean_up()

an event hook for cleaning up all resources when switching executable graphs.


method dry_run

dry_run()

updates only the state that tracks progress.


method epoch_end

epoch_end()

an event hook for epoch end. (only for epoch mode)


method epoch_start

epoch_start()

an event hook for epoch start. (only for epoch mode)


method forward

forward()

retrieves the forward output from the cache, or calculates it using forward_impl and saves the output to the cache. Subclasses should not override this method.


method forward_impl

forward_impl(cache: 'GraphOutputCache')

forward pass of the node; inputs from the current executable graph can be retrieved directly from the cache argument.
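
The split between forward and forward_impl, and the way dependencies arise from cache lookups, can be sketched as follows. ToyNode and ToyCache are hypothetical simplifications: the real forward() takes no arguments and the internals of Node and GraphOutputCache are assumed here, not taken from the source.

```python
from typing import Any, Dict


class ToyNode:
    """Hypothetical node: forward() handles caching, forward_impl() does the work."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.calls = 0  # how many times forward_impl actually ran

    def forward(self, cache: "ToyCache") -> Any:
        # Reuse the cached output, or compute it once and store it.
        if self.name not in cache.data:
            cache.data[self.name] = self.forward_impl(cache)
        return cache.data[self.name]

    def forward_impl(self, cache: "ToyCache") -> Any:
        raise NotImplementedError


class ToyCache:
    """Hypothetical cache: looking up a name runs that node if needed."""

    def __init__(self, nodes: Dict[str, ToyNode]) -> None:
        self.nodes = nodes
        self.data: Dict[str, Any] = {}

    def __getitem__(self, name: str) -> Any:
        return self.nodes[name].forward(self)

    def clear(self) -> None:
        self.data.clear()


class Source(ToyNode):
    def forward_impl(self, cache: ToyCache) -> int:
        self.calls += 1
        return 3


class Double(ToyNode):
    def forward_impl(self, cache: ToyCache) -> int:
        self.calls += 1
        # Reading cache["source"] is what makes this node depend on "source".
        return 2 * cache["source"]


nodes = {"source": Source("source"), "double": Double("double")}
cache = ToyCache(nodes)
first = cache["double"]   # runs "double", which pulls in "source"
second = cache["double"]  # served from the cache, nothing is recomputed
cache.clear()
third = cache["double"]   # recomputed because the cache was cleared
print(first, second, third, nodes["double"].calls)
```

Only forward_impl is overridden in the subclasses, so the caching logic stays in one place.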


method load_state_dict

load_state_dict(_state_dict: 'Dict', strict: 'bool')

resumes node state from state_dict.


method move

move(data, device=None)

Moves data to the CPU or the GPU.

If device is None and the node has a device attribute, the data is moved to that device. Otherwise, the data is moved according to device.type. If data is a tuple or list, the function is applied recursively to each of the elements.

Args:

  • data (torch.Tensor or torch.nn.Module or list or dict): the data to move.

  • device (str or torch.device): the string or instance of torch.device in which to move the data.

Returns:

  • torch.Tensor or torch.nn.Module: the data in the requested device.

Raises:

  • RuntimeError: data is not one of torch.Tensor, torch.nn.Module, list or dict.
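
The recursive behaviour described for move can be sketched without PyTorch. This is an illustration only: FakeTensor is a hypothetical stand-in for any object with a .to(device) method, and move_sketch is not the library's implementation. The handling of tuples and dicts is an assumption.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor or torch.nn.Module."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a new object on the requested device.
        return FakeTensor(device)


def move_sketch(data: Any, device: str) -> Any:
    """Move data to device, recursing into lists, tuples and dicts."""
    if isinstance(data, FakeTensor):
        return data.to(device)
    if isinstance(data, (list, tuple)):
        return type(data)(move_sketch(item, device) for item in data)
    if isinstance(data, dict):
        return {key: move_sketch(value, device) for key, value in data.items()}
    raise RuntimeError(f"cannot move object of type {type(data).__name__}")


batch = {"x": FakeTensor(), "ys": [FakeTensor(), FakeTensor()]}
moved = move_sketch(batch, "cuda:0")
print(moved["x"].device, [t.device for t in moved["ys"]])
```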

method prepare

prepare()

an event hook for preparing all resources when switching executable graphs.


method state_dict

state_dict() -> Dict

returns serialization of current node.


method update

update()

updates parameters or buffers, e.g. using an SGD-based optimizer to update parameters.


class GraphOutputCache

a cache for storing and looking up the forward outputs of nodes.

Attributes:

  • cache (dict): a dict for storing forward outputs.

  • egraph (ExecutableGraph): the executable graph.

  • data (Dict[str, torch.Tensor]): the cache.

method __init__

__init__(egraph: 'ExecutableGraph') -> None

method __getitem__

__getitem__(name)

Executes the node named name if it has not been executed yet; otherwise returns its cached output.


method clear

clear()

Clears the cache; subsequent calls to __getitem__ will recompute the outputs.


class ExecutableGraph

an executable graph.

This class is used to execute nodes in a graph.

Attributes:

  • hypergraph (HyperGraph): the hypergraph.

  • nodes (Dict[str, Node]): a dict for storing nodes.

  • nodes_tags (Dict[str, str]): a dict for storing tags of nodes.

  • nodes_names (Dict[str, str]): a dict for storing names of nodes.

  • cache (GraphOutputCache): a cache for storing and searching forward outputs of nodes.

  • task: the task of the graph.

  • losses: the losses of the graph.

  • total_loss: the total loss of the graph.

method __init__

__init__(hypergraph) -> None

property grad_scaler


method add_node

add_node(node_name, node, tags)

add a node to the graph.

Args:

  • node_name (str): the name of the node.

  • node (Node): the node.

  • tags (List[str]): the tags of the node.

Raises:

  • RuntimeError: the node is not a node of the hypergraph.

method apply

apply(
    method: 'str',
    *args,
    filter: Callable[[Node], bool] = lambda _: True,
    **kwds
)

applies the named method to every node in the graph that is accepted by filter.

Args:

  • method (str): the method name.

  • *args: the arguments of the method.

  • filter (Callable[[Node], bool]): the filter function.

  • **kwds: the keyword arguments of the method.

Returns:

  • List[Any]: the return values of the method.

Raises:

  • RuntimeError: the method is not found.
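
Name-based dispatch with a filter, in the spirit of apply, might look like the sketch below. apply_sketch, Trainable and Frozen are hypothetical names. How the real method treats a node that lacks the requested method is an assumption here: this sketch raises RuntimeError as soon as it meets one.

```python
from typing import Any, Callable, Dict, List


def apply_sketch(
    nodes: Dict[str, Any],
    method: str,
    *args: Any,
    filter: Callable[[Any], bool] = lambda _: True,
    **kwds: Any,
) -> List[Any]:
    """Call `method` on every node accepted by `filter` and collect the results."""
    results = []
    for node in nodes.values():
        if not filter(node):
            continue
        if not hasattr(node, method):
            raise RuntimeError(f"method {method!r} is not found")
        results.append(getattr(node, method)(*args, **kwds))
    return results


class Trainable:
    training = True

    def update(self, lr: float) -> str:
        return f"updated with lr={lr}"


class Frozen:
    training = False

    def update(self, lr: float) -> str:
        return "skipped"


nodes = {"net": Trainable(), "backbone": Frozen()}
print(apply_sketch(nodes, "update", 0.1, filter=lambda n: n.training))
```

With the default filter every node is visited, so the same call without filter would return one result per node.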

method clean_up_nodes

clean_up_nodes()

clean up all nodes in the graph.

This method is called after the graph is executed.


method items

items()

method iterate

iterate()

iterates over all nodes in the graph.


method prepare_nodes

prepare_nodes()

prepare all nodes in the graph.

This method is called before the graph is executed.