
module core.hypergraph

Global Variables

  • TYPE_CHECKING
  • global_shared_events

function LoadCheckpointTask

LoadCheckpointTask(resume_from, strict=False, tags='*')

Create a task that loads a checkpoint from a file.

Args:

  • resume_from (str): Path to the checkpoint file.

  • strict (bool): If True, raise an exception if the checkpoint file does not exist.

  • tags (str): Tags to load.

Returns:

  • Task: Task to load the checkpoint.
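The strict flag decides what happens when the checkpoint file is missing. The following is a minimal sketch of that behavior in plain Python. It assumes the checkpoint is a pickled dict, and load_checkpoint_sketch is a hypothetical helper, not part of core.hypergraph:

```python
import os
import pickle
from typing import Optional


def load_checkpoint_sketch(resume_from: str, strict: bool = False) -> Optional[dict]:
    """Hypothetical helper: load a pickled checkpoint dict from disk.

    Mirrors the documented `strict` semantics: a missing file raises only
    when strict=True; otherwise it is skipped and None is returned.
    Only unpickle files you trust.
    """
    if not os.path.exists(resume_from):
        if strict:
            raise FileNotFoundError(f"checkpoint not found: {resume_from}")
        return None  # non-strict: nothing to resume from
    with open(resume_from, "rb") as f:
        return pickle.load(f)
```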

function SaveCheckpointTask

SaveCheckpointTask(save_to=None, tags='*')

Create a task that saves a checkpoint to a file.

Args:

  • save_to (str): Path to the checkpoint file.

  • tags (str): Tags to save.

Returns:

  • Task: Task to save the checkpoint.
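Both checkpoint tasks take a tags argument that defaults to '*'. This reference does not spell out how tags are matched. The sketch below assumes they are glob patterns checked with fnmatch, so '*' selects every entry. filter_by_tags and the (tag, value) layout are hypothetical:

```python
from fnmatch import fnmatch


def filter_by_tags(tagged_state: dict, tags: str = "*") -> dict:
    """Hypothetical helper: keep entries whose tag matches a glob pattern.

    `tagged_state` maps a key to a (tag, value) pair; the default "*"
    keeps everything.
    """
    return {
        key: value
        for key, (tag, value) in tagged_state.items()
        if fnmatch(tag, tags)
    }


state = {
    "model": ("weights", [0.1, 0.2]),
    "optim": ("optimizer", {"lr": 1e-3}),
}
print(filter_by_tags(state))             # keeps both entries
print(filter_by_tags(state, "weights"))  # keeps only "model"
```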

class ResumeTaskFailed

Raised when the task structure does not match during resuming.
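The check behind this exception can be pictured as a comparison of the nested key structure of the saved state and the current state. The following is a hypothetical sketch. ResumeTaskFailed is redefined locally as a stand-in with an assumed Exception base class, and check_structure is illustrative only:

```python
class ResumeTaskFailed(Exception):
    """Local stand-in for core.hypergraph.ResumeTaskFailed (base class assumed)."""


def check_structure(saved: dict, current: dict, path: str = "") -> None:
    """Hypothetical check: compare the nested key structure of two state dicts.

    Raises ResumeTaskFailed at the first place the structures diverge.
    """
    where = path or "<root>"
    if set(saved) != set(current):
        missing = sorted(set(current) - set(saved))
        unexpected = sorted(set(saved) - set(current))
        raise ResumeTaskFailed(
            f"structure mismatch at {where}: missing={missing}, unexpected={unexpected}"
        )
    for key in saved:
        a, b = saved[key], current[key]
        if isinstance(a, dict) != isinstance(b, dict):
            raise ResumeTaskFailed(f"type mismatch at {where}/{key}")
        if isinstance(a, dict):
            check_structure(a, b, f"{path}/{key}")


# Same keys, different values: passes.
check_structure({"task": {"step": 10}}, {"task": {"step": 0}})
```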


class Task

A task is a unit of computation.

It can be a single node or a graph. A task can be executed by a worker.

Args:

  • node: a node or a graph.

  • name: the name of the task.

  • total_steps: the total number of steps to run.

  • total_epochs: the total number of epochs to run.

  • config: a dict of configs.

method __init__

__init__(*args, **kwds) → None

property global_auto_epochs


property global_auto_steps


method load_state_dict

load_state_dict(_state_dict, dry_run=False)

method state_dict

state_dict()
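The state_dict/load_state_dict pair follows the familiar save/restore pattern. The meaning of dry_run is not documented here. The sketch below assumes it validates a state dict without applying it. TaskSketch is a hypothetical stand-in, not the real Task:

```python
class TaskSketch:
    """Hypothetical task holding a progress counter (not the real Task)."""

    def __init__(self, name: str, total_steps: int) -> None:
        self.name = name
        self.total_steps = total_steps
        self.step = 0

    def state_dict(self) -> dict:
        return {"name": self.name, "total_steps": self.total_steps, "step": self.step}

    def load_state_dict(self, _state_dict: dict, dry_run: bool = False) -> None:
        # Validate first so a mismatched checkpoint never half-applies.
        if _state_dict.get("name") != self.name:
            raise ValueError(
                f"task name mismatch: {_state_dict.get('name')!r} != {self.name!r}"
            )
        if dry_run:
            return  # assumed meaning: check compatibility only, change nothing
        self.step = _state_dict["step"]
```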

class Repeat

Repeat a task a fixed number of times.

Attributes:

  • task (Task): Task to repeat.

  • repeat (int): Number of times to repeat the task.

  • epoch_size (int): Number of steps per epoch.

  • total_steps (int): Total number of steps.

  • total_epochs (int): Total number of epochs.

  • launcher (Launcher): Launcher object.

  • hypergraph (HyperGraph): HyperGraph object.

  • events (Events): Events object.
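This reference does not state how repeat, epoch_size, total_steps, and total_epochs relate. One plausible reading is that each repetition runs one epoch of epoch_size steps. The following sketches that bookkeeping under this assumption. RepeatSketch is hypothetical:

```python
from dataclasses import dataclass


@dataclass
class RepeatSketch:
    """Hypothetical stand-in for Repeat's bookkeeping (not the real class).

    Assumes each repetition runs one epoch of `epoch_size` steps.
    """
    repeat: int
    epoch_size: int

    @property
    def total_epochs(self) -> int:
        return self.repeat

    @property
    def total_steps(self) -> int:
        return self.repeat * self.epoch_size


r = RepeatSketch(repeat=3, epoch_size=500)
print(r.total_epochs, r.total_steps)  # 3 1500
```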

method __init__

__init__(*args, **kwds) → None

method load_state_dict

load_state_dict(_state_dict, dry_run=False)

method state_dict

state_dict()

class Counter

Counter object.

Attributes:

  • epochs (int): Number of epochs.

  • steps (int): Number of steps.

method __init__

__init__() → None

property total


method __getitem__

__getitem__(key)

Get the value of the counter.

Args:

  • key (str): Name of the counter.

Returns:

  • int: Value of the counter.

method __setitem__

__setitem__(key, value)

Set the value of the counter.

Args:

  • key (str): Name of the counter.

  • value (int): Value of the counter.

Raises:

  • KeyError: If the key is not valid.
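The item-access interface above can be re-created in a few lines. This sketch assumes the only valid keys are the two documented attributes, epochs and steps. CounterSketch is hypothetical, and raising KeyError on reads is an added assumption (only __setitem__ documents it):

```python
class CounterSketch:
    """Hypothetical re-creation of the documented Counter interface."""

    _KEYS = ("epochs", "steps")

    def __init__(self) -> None:
        self.epochs = 0
        self.steps = 0

    def __getitem__(self, key: str) -> int:
        if key not in self._KEYS:
            raise KeyError(key)
        return getattr(self, key)

    def __setitem__(self, key: str, value: int) -> None:
        # Reject anything that is not a known counter name.
        if key not in self._KEYS:
            raise KeyError(f"invalid counter name: {key!r}")
        setattr(self, key, value)


c = CounterSketch()
c["steps"] = 10
print(c["steps"])  # 10
```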

class GlobalCounters

Global counters object.

Attributes:

  • epochs (int): Number of epochs.

  • steps (int): Number of steps.

method __init__

__init__(
    steps: 'Counter' = <core.hypergraph.Counter object>,
    epochs: 'Counter' = <core.hypergraph.Counter object>
) → None

class HyperGraph

HyperGraph is the container for all nodes.

Attributes:

  • nodes (dict): Nodes.

  • edges (dict): Edges.

  • tasks (dict): Tasks.

  • launchers (dict): Launchers.

  • global_counters (GlobalCounters): Global counters.

  • resume_from (str): Path of the checkpoint file to resume from.

  • resume_tags (str): Tags to load.

  • save_to (str): Path the checkpoint file is saved to.

  • save_tags (str): Tags to save.

  • strict (bool): If True, raise an exception if the checkpoint file does not exist.

  • dry_run (bool): If True, do not save the checkpoint.

  • verbose (bool): If True, print the progress.

  • logger (Logger): Logger.

Raises:

  • ValueError: If the tags are not valid.

method __init__

__init__(
    autocast_enabled=False,
    autocast_dtype=None,
    grad_scaler: 'Union[bool, GradScaler]' = None
) → None

property launcher

Get the launcher.

Returns:

  • ElasticLauncher: Launcher.

Raises:

  • ValueError: If the launcher is not valid.

method __getitem__

__getitem__(uid) → Node

Get a node by uid.

Args:

  • uid (str): Unique identifier of the node.

Returns:

  • Node: Node.

Raises:

  • ValueError: If the uid is not valid.

method add

add(name, node: 'Node', tags='*')

Add a node.

Args:

  • name (str): Name of the node.

  • node (Node): Node to add.

  • tags (str): Tags attached to the node.

Returns:

  • Node: The added node.

Raises:

  • ValueError: If the name is not valid.

method backup_source_files

backup_source_files(entrypoint: 'str')

Back up the source files.

Args:

  • entrypoint (str): Entry point of the program.

Raises:

  • ValueError: If the entrypoint is not valid.

method exec_tasks

exec_tasks(tasks, launcher: 'ElasticLauncher')

Execute the tasks.

Args:

  • tasks (List[Task]): Tasks to execute.

  • launcher (ElasticLauncher): Launcher.

Returns:

  • List[Task]: Tasks executed.

Raises:

  • ValueError: If the tasks are not valid.

method init_autocast

init_autocast(
    autocast_enabled=True,
    autocast_dtype=None,
    grad_scaler: 'Union[bool, GradScaler]' = None
)

Initialize autocast.

Args:

  • autocast_enabled (bool): If True, enable autocast.

  • autocast_dtype (str): Data type that eligible operations are cast to inside autocast regions.

  • grad_scaler (Union[bool, GradScaler]): Gradient scaler.

Raises:

  • ValueError: If autocast_dtype is not valid.

method init_grad_scaler

init_grad_scaler(
    grad_scaler: Union[bool, GradScaler] = False,
    *,
    init_scale=2.0 ** 16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000,
    enabled=True
)

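The keyword arguments of init_grad_scaler match the constructor defaults of PyTorch's GradScaler. The assumption here is that they are forwarded to it. Those parameters drive dynamic loss scaling. The following pure-Python sketch of that update rule (LossScaleSketch is an illustrative name) shows what each one controls:

```python
class LossScaleSketch:
    """Dynamic loss-scale update rule in the style of torch GradScaler.

    Pure-Python sketch; defaults copied from init_grad_scaler's signature.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000, enabled=True):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.enabled = enabled
        self._good_steps = 0  # consecutive steps with finite gradients

    def update(self, found_inf: bool) -> None:
        if not self.enabled:
            return
        if found_inf:
            # Overflow: the optimizer step is skipped, the scale shrinks,
            # and the growth count restarts.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # A long run of finite gradients: try a larger scale.
                self.scale *= self.growth_factor
                self._good_steps = 0


s = LossScaleSketch(growth_interval=3)
s.update(found_inf=True)       # scale: 65536.0 -> 32768.0
for _ in range(3):
    s.update(found_inf=False)  # third clean step grows it back to 65536.0
```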


method is_autocast_enabled

is_autocast_enabled() → bool

Check if autocast is enabled.

Returns:

  • bool: True if autocast is enabled.

method is_grad_scaler_enabled

is_grad_scaler_enabled() → bool

Check if the gradient scaler is enabled.

Returns:

  • bool: True if the gradient scaler is enabled.

Raises:

  • ValueError: If grad_scaler is not valid.

method load_checkpoint

load_checkpoint(resume_from, strict=False, tags='*')

Load the checkpoint.

Args:

  • resume_from (str): Path to the checkpoint.

  • strict (bool): Whether to check the keys.

  • tags (str): Tags to load.

Raises:

  • ValueError: If resume_from is not valid.

method print_forward_output

print_forward_output(
    *nodenames,
    every=1,
    total=None,
    tags: 'List[str]' = '*',
    train_only=True,
    localrank0_only=True
)

Print forward output.

Args:

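  • nodenames (str): Names of the nodes whose forward output is printed.

  • every (int): Print interval, in iterations.

  • total (int): Maximum number of times to print; None means no limit.

  • tags (List[str]): Tags of the nodes to print.

  • train_only (bool): If True, print only during training.

  • localrank0_only (bool): If True, print only on the process with local rank 0.

Raises:

  • ValueError: If any of the node names is not valid.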
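One way every, total, train_only, and localrank0_only could combine to gate printing is shown below. This is a hypothetical sketch of that gating logic. PrintThrottle and should_print are illustrative names, not the library's implementation:

```python
from typing import Optional


class PrintThrottle:
    """Hypothetical gate deciding whether a forward output is printed."""

    def __init__(self, every: int = 1, total: Optional[int] = None,
                 train_only: bool = True, localrank0_only: bool = True) -> None:
        self.every = every
        self.total = total
        self.train_only = train_only
        self.localrank0_only = localrank0_only
        self.calls = 0    # how many times the gate was consulted
        self.printed = 0  # how many times it allowed printing

    def should_print(self, training: bool, local_rank: int) -> bool:
        self.calls += 1
        if self.train_only and not training:
            return False
        if self.localrank0_only and local_rank != 0:
            return False
        if self.total is not None and self.printed >= self.total:
            return False
        if (self.calls - 1) % self.every != 0:
            return False
        self.printed += 1
        return True


gate = PrintThrottle(every=2, total=2)
print([gate.should_print(training=True, local_rank=0) for _ in range(6)])
# [True, False, True, False, False, False]
```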

method remove

remove(query)

Remove a node.

Args:

  • query (str): Query identifying the node(s) to remove.

Raises:

  • ValueError: If the query is not valid.

method run

run(
    tasks,
    devices='auto',
    run_id: str = 'none',
    out_dir: str = None,
    resume_from: str = None,
    seed=0
)



method run

run(
    tasks,
    launcher: ElasticLauncher = None,
    run_id: str = 'none',
    out_dir: str = None,
    resume_from: str = None,
    seed=0
)



method run

run(
    tasks,
    devices='auto',
    run_id='none',
    nnodes='1:1',
    dist_backend='auto',
    monitor_interval=5,
    node_rank=0,
    master_addr='127.0.0.1',
    master_port=None,
    redirects='2',
    tee='1',
    out_dir=None,
    resume_from=None,
    seed=0,
    role='default',
    max_restarts=0,
    omp_num_threads=1,
    start_method='spawn'
)



method run

run(
    tasks,
    devices='auto',
    run_id='none',
    nnodes='1:1',
    dist_backend='auto',
    monitor_interval=5,
    rdzv_endpoint='',
    rdzv_backend='static',
    rdzv_configs='',
    standalone=False,
    redirects='2',
    tee='1',
    out_dir=None,
    resume_from=None,
    seed=0,
    role='default',
    max_restarts=0,
    omp_num_threads=1,
    start_method='spawn'
)



method save_checkpoint

save_checkpoint(save_to=None, tags='*')

Save the checkpoint.

Args:

  • save_to (str): Path to save the checkpoint.

  • tags (str): Tags to save.

Returns:

  • str: Path to the checkpoint.

Raises:

  • ValueError: If save_to is not valid.

method select_egraph

select_egraph(query) → ExecutableGraph

Select an executable graph.

Args:

  • query (str): Query identifying the executable graph.

Returns:

  • ExecutableGraph: Executable graph.

Raises:

  • ValueError: If the query is not valid.

method select_nodes

select_nodes(*query)

Select nodes.

Args:

  • query (str): One or more queries.

Returns:

  • list: Matching nodes.

Raises:

  • ValueError: If the query is not valid.

method set_gradient_accumulate

set_gradient_accumulate(every=1)

Set the number of gradient accumulation steps.

Args:

  • every (int): Number of iterations over which gradients are accumulated before each optimizer step.

Raises:

  • ValueError: If every is not valid.
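Gradient accumulation averages the gradients of several consecutive micro-batches and applies a single optimizer update. The following is a minimal sketch of the technique on one scalar weight with plain SGD. train_with_accumulation and its arguments are illustrative, not library code. Rejecting every < 1 is an assumption about what "not valid" means:

```python
def train_with_accumulation(micro_batch_grads, every=1, lr=0.1, w=0.0):
    """Minimal gradient-accumulation sketch on a single scalar weight.

    `micro_batch_grads` stands in for per-micro-batch gradients. Every
    `every` micro-batches, the averaged gradient is applied in one SGD
    step; a trailing partial group is not applied.
    """
    if every < 1:
        raise ValueError("every must be >= 1")
    acc = 0.0
    updates = 0
    for i, g in enumerate(micro_batch_grads):
        acc += g / every  # equivalent to scaling each micro-batch loss by 1/every
        if (i + 1) % every == 0:
            w -= lr * acc  # optimizer step
            acc = 0.0      # zero the accumulated gradient
            updates += 1
    return w, updates


w, n = train_with_accumulation([1.0, 1.0, 1.0, 1.0], every=2, lr=0.1)
print(n)  # 2 optimizer steps for 4 micro-batches
```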