Interfaces#
- class t3w.Interface#
This is a special abstract class indicating that its direct subclasses are interfaces and should be implemented by the user.
Note
Every direct subclass of Interface has a name starting with the letter “I”.
- class t3w.IDatum#
The interface for a single datum. Your dataset is meant to contain many objects of this type.
Note
TrainLoop or EvalLoop can directly make use of the attributes and methods defined here through the IDataset.datum_type link.
- uuid: Hashable | None#
universally unique identifier for tracking datum information, optional.
- train_batch_size: int = 2#
TrainLoop will use this value, accessed through IDataset.datum_type, as the default batch size.
- val_batch_size: int = 2#
EvalLoop will use this value, accessed through IDataset.datum_type, as the default batch size.
- num_workers: int = 0#
TrainLoop or EvalLoop will use this value, accessed through IDataset.datum_type, for DataLoader.num_workers.
- static collate(data: Sequence[IDatum]) IMiniBatch #
collate a mini batch of data into an IMiniBatch consumable by the user's model and IMiniBatchMetric.
This function is called by DataLoader in the TrainLoop or EvalLoop to collate independently sampled data into a mini batch, before it is passed to TopLevelModule.forward().
Note
It is strongly suggested not to use pytorch's default_collate() function, but to override this method instead. This is because default_collate() does not support types other than numerical tensors, which makes it hard for users to manage their batch data in an object-oriented programming paradigm. Be explicit, and write your own collate function to produce your IMiniBatch type, as in the sketch below.
- Parameters:
data (list(IDatum)) – sampled mini batch of data as a sequence.
- Returns:
the collated mini batch.
- Return type:
IMiniBatch
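For illustration, here is a minimal sketch of a custom datum with an explicit collate function. The ImageDatum and ImageBatch names and their fields are hypothetical, not part of t3w:

from typing import Sequence

import torch

import t3w


class ImageBatch(t3w.IMiniBatch):  # hypothetical IMiniBatch subclass
    def __init__(self, images: torch.Tensor, labels: torch.Tensor):
        super().__init__()
        self.images = images  # (B, C, H, W) float tensor
        self.labels = labels  # (B,) long tensor


class ImageDatum(t3w.IDatum):  # hypothetical IDatum subclass
    train_batch_size = 32
    val_batch_size = 64
    num_workers = 4

    def __init__(self, image: torch.Tensor, label: int):
        self.image = image
        self.label = label

    @staticmethod
    def collate(data: Sequence["ImageDatum"]) -> ImageBatch:
        # Stack the per-datum tensors explicitly instead of relying
        # on pytorch's default_collate().
        return ImageBatch(
            images=torch.stack([d.image for d in data]),
            labels=torch.tensor([d.label for d in data]),
        )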
- classmethod from_uuid(uuid: Hashable) IDatum #
build the exact datum from a given uuid.
Ignore this factory method for a basic training workflow. It is useful for workflows that mine, track, and analyse specific interesting data. One may embed dataset_type, split, and index information in the uuid to fully recover a datum's identity, as sketched below.
- Parameters:
uuid (Hashable) – universally unique identifier of the desired datum.
- Return type:
IDatum
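Continuing the hypothetical ImageDatum sketch above, the uuid could encode the (split, index) pair needed to rebuild the datum. MyImageDataset and the root path are hypothetical (the dataset class is sketched further below under IDataset):

from typing import Hashable


class ImageDatum(t3w.IDatum):
    # ... fields and collate() as sketched above ...

    @classmethod
    def from_uuid(cls, uuid: Hashable) -> "ImageDatum":
        # Assumes the dataset assigned datum.uuid = (split, index)
        # when constructing each datum.
        split, index = uuid
        dataset = MyImageDataset(root="/data/images", split=split)
        return dataset[index]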
- class t3w.IMiniBatch#
The interface for a mini-batch of data.
It plays the central role of a data context passed around models, metrics, losses, and callbacks. All interfaces in t3w have a strict typing requirement; in particular, functions are required to return specific types of data instead of dynamic types. The flexibility of t3w is not harmed, though, because users are free to modify the definition of their subclasses of the IMiniBatch interface, and all functions that receive an IMiniBatch as input are allowed (or supposed) to modify it in place. Therefore, the standard behaviors of the t3w core library and the ISideEffect based plugin system can rely on the typing system, while users can write flexible code in their own namespace.
See also
- model: TopLevelModule = None#
The TopLevelModule which is consuming and modifying this instance of IMiniBatch.
The TopLevelModule will fill this attribute right before calling the forward() method of its internal user_model. Therefore, user_model.forward(), ISideEffect.on_eval_step_finished(), and ISideEffect.on_train_step_finished() can make use of it. Before that, this attribute defaults to None.
- to(device: device)#
defines how to move this mini-batch type to the specified device.
The default implementation in the base class IMiniBatch moves all direct attributes which are instances of torch.Tensor to the target device. This is good enough for common usage, but nothing stops you from customizing it.
Note
TopLevelModule.forward() will call this function right before calling user_model.forward(), so don't bother to do it yourself. You only need to specify the target device at the TopLevelModule level using its TopLevelModule.to() method.
See also
- Parameters:
device (torch.device) – target device.
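For example, a batch holding tensors nested inside a plain container can extend to() as in this minimal sketch; SegmentationBatch and its fields are hypothetical:

import torch

import t3w


class SegmentationBatch(t3w.IMiniBatch):  # hypothetical subclass
    def __init__(self, images: torch.Tensor, masks: list):
        super().__init__()
        self.images = images  # direct tensor attribute, moved by the default to()
        self.masks = masks    # list of variable-sized tensors, not a direct tensor

    def to(self, device: torch.device):
        super().to(device)  # moves direct torch.Tensor attributes
        # The default implementation only sees direct tensor attributes,
        # so tensors nested in containers must be moved explicitly.
        self.masks = [m.to(device) for m in self.masks]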
- class t3w.IDataset(root: str, split: str | None = None)#
The interface for the entire dataset (and its split).
This is very similar to the torch.Dataset (sized dataset) interface, with the subtle modification that __getitem__() is required to return an instance of datum_type.
- datum_type: Type[IDatum]#
The user-implemented IDatum subclass (the type itself, not an instance).
Note
Subclasses of IDataset must specify this attribute so that class attributes such as IDatum.train_batch_size, IDatum.val_batch_size, and IDatum.num_workers can be fetched through it.
- __init__(root: str, split: str | None = None) None #
- Parameters:
root (str) – path of the root directory of the specified dataset.
split (str, optional) – identifier for a subset. Defaults to None.
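Putting the pieces together, a dataset subclass might look like this minimal sketch; MyImageDataset, load_index_file, and the storage layout are hypothetical, and ImageDatum is carried over from the earlier sketch:

import t3w


class MyImageDataset(t3w.IDataset):  # hypothetical IDataset subclass
    datum_type = ImageDatum  # exposes batch sizes, num_workers, and collate()

    def __init__(self, root: str, split: str | None = None):
        super().__init__(root, split)
        self.split = split  # kept for building uuids below
        self.samples = load_index_file(root, split)  # hypothetical helper

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> ImageDatum:
        image, label = self.samples[index]
        datum = ImageDatum(image, label)
        datum.uuid = (self.split, index)  # enables ImageDatum.from_uuid()
        return datum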
- class t3w.IMiniBatchMetric(*args, **kwargs)#
The interface for computing a metric on a mini-batch of data.
Note
We differentiate between IMiniBatchMetric and IDatasetMetric: the former computes a metric value for a batch of data, while the latter aggregates the per-batch metrics over an entire dataset (typically an “average meter”). The dataset level metric mainly focuses on correct computation and synchronization across variable batch sizes and across devices.
See also
- higher_better: bool#
specifies whether a higher value of the metric implies better performance. This can be useful, e.g., for metric-based best model saving. Always explicitly specify this class variable in your subclass definition.
- forward(mb: IMiniBatch) MiniBatchFloats | FloatScalarTensor #
Calling of this method is delegated to TrainLoop or EvalLoop at their construction time. The mb argument must already have been through the user_model's forward method; the loops pass it on to metrics to calculate the metric value, which must be either a sequence of per-example values or a float scalar tensor.
- Parameters:
mb (IMiniBatch) – the mini-batch of data which has been processed by user_model.
- Returns:
a sequence of per-example metric values, or a single value (regarded as the same for every example in the mini-batch).
- Return type:
MiniBatchFloats | FloatScalarTensor
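As an illustration, an accuracy metric could be sketched as follows. It assumes the user's model stores a logits field on the hypothetical ImageBatch in place, as the IMiniBatch contract allows:

import torch

import t3w


class Accuracy(t3w.IMiniBatchMetric):  # hypothetical metric subclass
    higher_better = True  # always specify this explicitly

    def forward(self, mb: "ImageBatch") -> torch.Tensor:
        # Assumes user_model.forward() already stored mb.logits in place;
        # mb.labels came from ImageDatum.collate().
        pred = mb.logits.argmax(dim=-1)
        return (pred == mb.labels).float()  # per-example values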
- class t3w.ILoss(*args, **kwargs)#
The interface for computing a loss on a mini-batch of data.
Note
In t3w, we adopt the view that a loss function is a special type of metric that supports backpropagation. Therefore ILoss inherits from IMiniBatchMetric, and you can use an ILoss wherever an IMiniBatchMetric is suited.
- loss_reweight: float = 1.0#
Losses often come in a raw version and a reweighted version. Use this standard attribute to specify the weight of the current loss.
Note
In the “losses” subdict of StepReturn, the value of each loss will be reported as a pair of float values (loss_reweight, loss_raw_value) when the on_train_step_finished() event is emitted. This is standard behavior that extension code can rely on. Non-loss metrics are reported as plain floats.
- higher_better: bool = False#
A higher loss value always implies worse performance. Don't bother to specify this in subclasses.
- forward(mb: IMiniBatch) MiniBatchFloats | FloatScalarTensor #
Calling of this method is delegated to TrainLoop at its construction time. The mb argument must already have been through the user_model's forward method; the loop passes it on to the various losses to calculate the loss values, each of which must be a float scalar tensor. The losses are then reweighted and summed together for a backward autodiff pass.
- Parameters:
mb (IMiniBatch) – the mini-batch of data which has been processed by user_model.
- Returns:
the scalar loss value.
- Return type:
FloatScalarTensor
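A cross-entropy loss under this interface could be sketched as below; the logits and labels fields on the hypothetical ImageBatch are assumptions carried over from the earlier sketches:

import torch
import torch.nn.functional as F

import t3w


class CrossEntropyLoss(t3w.ILoss):  # hypothetical loss subclass
    loss_reweight = 0.5  # contributes to the total loss with weight 0.5

    def forward(self, mb: "ImageBatch") -> torch.Tensor:
        # Returns a 0-dim float tensor suitable for backpropagation.
        return F.cross_entropy(mb.logits, mb.labels)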
- class t3w.IDatasetMetric(minibatch_metric: IMiniBatchMetric)#
The interface of a dataset-level metric (an aggregation algorithm across multiple devices).
Note
We differentiate between IMiniBatchMetric and IDatasetMetric: the former computes a metric value for a batch of data, while the latter aggregates the per-batch metrics over an entire dataset (typically an “average meter”). The dataset level metric mainly focuses on correct computation and synchronization across variable batch sizes and across devices.
Warning
Typically, TrainLoop only accepts mini-batch metrics because the parameters keep changing, while EvalLoop only accepts dataset metrics because it wants statistics over the entire dataset, and it calls IDatasetMetric.synchronize() only once, right before emitting the on_eval_finished() event. Other use cases are still possible, and synchronize() should remain correct after an arbitrary number of calls (e.g. a running-average dataset metric can be used in a train loop), but this is considered non-standard behavior in t3w and should preferably be implemented and maintained in user space using the side effects system.
See also
- __init__(minibatch_metric: IMiniBatchMetric) None #
store minibatch_metric and reset the statistics.
- Parameters:
minibatch_metric (IMiniBatchMetric) – internal minibatch_metric instance to embed.
- minibatch_metric: IMiniBatchMetric#
The dataset metric's standard behavior is to composite a mini-batch metric instance; calling of that metric is delegated to update().
- forward(mb: IMiniBatch) MiniBatchFloats #
allows using a dataset metric like a mini-batch metric.
Warning
This is non-standard behavior, so please be careful when implementing an IDatasetMetric that is supposed to be used as a mini-batch metric; in particular, ensure that multiple synchronization calls across multiple devices are unambiguous, for example by storing the synchronized and local versions of the state separately.
This will be changed in a future version by a breaking API with stateless synchronization.
- Parameters:
mb (IMiniBatch) – a mini-batch of data.
- Returns:
current dataset metric value.
- Return type:
MiniBatchFloats
- eval() float #
get the target metric value based on the internal statistics.
- reset() None #
clear internal statistics.
- update(mb: IMiniBatch) None #
update the internal statistics with a new mini-batch of data.
- synchronize() None #
synchronize the local statistics across devices.
- float()#
Casts all floating point parameters and buffers to float datatype.
Note
This method modifies the module in-place.
- Returns:
self
- Return type:
Module
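For illustration, an average-meter style dataset metric might be sketched as follows. The torch.distributed reduction is an assumption about how synchronize() can be realized, and note that this sketch is not safe to synchronize repeatedly (exactly the pitfall the Warning above describes):

import torch
import torch.distributed as dist

import t3w


class AverageMeter(t3w.IDatasetMetric):  # hypothetical dataset metric
    def reset(self) -> None:
        self.value_sum = 0.0
        self.count = 0

    def update(self, mb: t3w.IMiniBatch) -> None:
        # Delegate to the embedded mini-batch metric, then accumulate
        # (assumes it returns a per-example float tensor).
        values = self.minibatch_metric(mb)
        self.value_sum += float(values.sum())
        self.count += values.numel()

    def synchronize(self) -> None:
        # Sum local statistics across devices; assumes a default
        # process group is initialized, otherwise do nothing.
        if dist.is_available() and dist.is_initialized():
            stats = torch.tensor([self.value_sum, float(self.count)])
            dist.all_reduce(stats)
            self.value_sum, self.count = stats[0].item(), int(stats[1].item())

    def eval(self) -> float:
        return self.value_sum / max(self.count, 1)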
- class t3w.ISideEffect#
- is_distributed: bool = False#
if False, the event will only be invoked on the rank 0 sub-process; otherwise on all sub-processes.
- on_eval_step_started(loop: EvalLoop, step: int, mb: IMiniBatch)#
- on_eval_step_finished(loop: EvalLoop, step: int, mb: IMiniBatch)#
- on_train_step_started(loop: TrainLoop, step: int, mb: IMiniBatch)#
- on_train_step_finished(loop: TrainLoop, step: int, mb: IMiniBatch, step_return: StepReturnDict)#
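Finally, a minimal side effect sketch that logs the reported losses at each training step. The dict-style access into StepReturnDict, beyond the documented “losses” subdict of (loss_reweight, loss_raw_value) pairs, is an assumption:

import t3w


class PrintLosses(t3w.ISideEffect):  # hypothetical side effect
    is_distributed = False  # only invoked on the rank 0 sub-process

    def on_train_step_finished(self, loop, step, mb, step_return):
        # Each loss is reported as a (loss_reweight, loss_raw_value) pair.
        weighted = {
            name: weight * raw
            for name, (weight, raw) in step_return["losses"].items()
        }
        print(f"step {step}: {weighted}")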