inferno.extensions.containers package

Submodules

inferno.extensions.containers.graph module

class inferno.extensions.containers.graph.NNGraph(data=None, **attr)[source]

Bases: networkx.classes.digraph.DiGraph

A NetworkX DiGraph, except that node and edge ordering matters.

ATTRIBUTES_TO_NOT_COPY = {'payload'}
adjlist_dict_factory

alias of OrderedDict

copy(**init_kwargs)[source]
node_dict_factory

alias of OrderedDict
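The effect of the two `OrderedDict` factories can be illustrated without NetworkX. The following is a minimal sketch, not inferno's implementation: the class `OrderedDiGraphSketch` and its attributes are hypothetical names introduced only to show that nodes and outgoing edges are iterated in the order in which they were added.

```python
from collections import OrderedDict


class OrderedDiGraphSketch:
    """Toy directed graph; NOT inferno's NNGraph, only an illustration."""

    def __init__(self):
        self.nodes = OrderedDict()       # node name -> attribute dict
        self.successors = OrderedDict()  # node name -> OrderedDict of successors

    def add_node(self, name, **attr):
        self.nodes.setdefault(name, {}).update(attr)
        self.successors.setdefault(name, OrderedDict())

    def add_edge(self, source, target):
        # Nodes are created on demand, in the order they are first seen.
        self.add_node(source)
        self.add_node(target)
        self.successors[source][target] = {}


graph = OrderedDiGraphSketch()
graph.add_edge('input', 'conv3x3')
graph.add_edge('input', 'conv1x1')

node_order = list(graph.nodes)                # insertion order of nodes
edge_order = list(graph.successors['input'])  # insertion order of edges
```

Here `node_order` lists `'input'` first and `edge_order` lists `'conv3x3'` before `'conv1x1'`, mirroring the order of the `add_edge` calls.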

class inferno.extensions.containers.graph.Graph(graph=None)[source]

Bases: torch.nn.modules.module.Module

A graph structure to build networks with complex architectures. The resulting graph model can be used like any other torch.nn.Module. The graph structure used behind the scenes is a networkx.DiGraph. This internal graph is exposed by the apply_on_graph method, which can be used with any NetworkX function (e.g. for plotting with matplotlib or GraphViz).

Examples

The naive inception module (without the max-pooling, for simplicity) with ELU layers of 64 units can be built as follows, assuming 64 input channels:

>>> from inferno.extensions.layers.reshape import Concatenate
>>> from inferno.extensions.layers.convolutional import ConvELU2D
>>> import torch
>>> from torch.autograd import Variable
>>> # Build the model
>>> inception_module = Graph()
>>> inception_module.add_input_node('input')
>>> inception_module.add_node('conv1x1', ConvELU2D(64, 64, 1), previous='input')
>>> inception_module.add_node('conv3x3', ConvELU2D(64, 64, 3), previous='input')
>>> inception_module.add_node('conv5x5', ConvELU2D(64, 64, 5), previous='input')
>>> inception_module.add_node('cat', Concatenate(),
...                           previous=['conv1x1', 'conv3x3', 'conv5x5'])
>>> inception_module.add_output_node('output', 'cat')
>>> # Build dummy variable
>>> input = Variable(torch.rand(1, 64, 100, 100))
>>> # Get output
>>> output = inception_module(input)
add_edge(from_node, to_node)[source]

Add an edge between two nodes.

Parameters:
  • from_node (str) – Name of the source node.
  • to_node (str) – Name of the target node.
Returns:

self

Return type:

Graph

Raises:

AssertionError – if either of the two nodes is not in the graph, or if the edge is not ‘legal’.

add_input_node(name)[source]

Add an input to the graph. The order in which input nodes are added is the order in which the forward method accepts its inputs.

Parameters:name (str) – Name of the input node.
Returns:self
Return type:Graph
add_node(name, module, previous=None)[source]

Add a node to the graph.

Parameters:
  • name (str) – Name of the node. Nodes are identified by their names.
  • module (torch.nn.Module) – Torch module for this node.
  • previous (str or list of str) – (List of) name(s) of the previous node(s).
Returns:

self

Return type:

Graph

add_output_node(name, previous=None)[source]

Add an output to the graph. The order in which output nodes are added is the order in which the forward method returns its outputs.

Parameters:
  • name (str) – Name of the output node.
  • previous (str) – Name of the node that feeds this output node (optional).
Returns:self
Return type:Graph
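The ordering rules for input and output nodes can be sketched with a small pure-Python stand-in. This is an illustration under stated assumptions, not inferno's code: `ToyGraph` is a hypothetical class that uses plain callables instead of torch modules, assumes every node is added after its predecessors, and returns a tuple when there are several outputs (the container type returned by the real `forward` is not stated here).

```python
from collections import OrderedDict


class ToyGraph:
    """Hypothetical stand-in for Graph using plain callables instead of modules."""

    def __init__(self):
        self.modules = OrderedDict()   # node name -> callable (None for I/O nodes)
        self.previous = OrderedDict()  # node name -> list of predecessor names
        self.input_nodes = []
        self.output_nodes = []

    def add_input_node(self, name):
        self.modules[name] = None
        self.previous[name] = []
        self.input_nodes.append(name)
        return self  # returning self allows chained calls

    def add_node(self, name, module, previous=None):
        if isinstance(previous, str):
            previous = [previous]
        self.modules[name] = module
        self.previous[name] = list(previous or [])
        return self

    def add_output_node(self, name, previous=None):
        self.modules[name] = None
        self.previous[name] = [previous] if previous is not None else []
        self.output_nodes.append(name)
        return self

    def forward(self, *inputs):
        assert len(inputs) == len(self.input_nodes)
        # Positional inputs are matched to input nodes in insertion order.
        payloads = dict(zip(self.input_nodes, inputs))
        # This toy relies on nodes having been added after their predecessors.
        for name, module in self.modules.items():
            if name in payloads:
                continue
            args = [payloads[p] for p in self.previous[name]]
            payloads[name] = module(*args) if module is not None else args[0]
        # Outputs are returned in the order the output nodes were added.
        outputs = tuple(payloads[name] for name in self.output_nodes)
        return outputs[0] if len(outputs) == 1 else outputs


toy = (ToyGraph()
       .add_input_node('a')
       .add_input_node('b')
       .add_node('diff', lambda x, y: x - y, previous=['a', 'b'])
       .add_node('sum', lambda x, y: x + y, previous=['a', 'b'])
       .add_output_node('out_sum', previous='sum')
       .add_output_node('out_diff', previous='diff'))

result = toy.forward(10, 3)  # 'a' receives 10, 'b' receives 3
```

Because `'out_sum'` was registered before `'out_diff'`, the sum comes first in `result` even though the `'diff'` node was added to the graph earlier.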
apply_on_graph(function, *args, **kwargs)[source]

Applies a function on the internal graph.
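The calling pattern can be sketched as follows. This is a toy illustration, not inferno's implementation: `GraphHolder`, its `_graph` attribute, and the helper `count_edges` are hypothetical, a plain dict stands in for the internal NetworkX graph, and the sketch assumes that the function's return value is passed back to the caller.

```python
class GraphHolder:
    """Hypothetical stand-in that keeps an internal graph private."""

    def __init__(self):
        # Adjacency mapping: node name -> list of successor names.
        self._graph = {'input': ['conv'], 'conv': ['output'], 'output': []}

    def apply_on_graph(self, function, *args, **kwargs):
        # Assumption: the internal graph is passed as the first argument
        # and the function's result is returned unchanged.
        return function(self._graph, *args, **kwargs)


def count_edges(graph, skip=()):
    """Count edges, ignoring those that start at a node listed in `skip`."""
    return sum(len(targets) for node, targets in graph.items()
               if node not in skip)


holder = GraphHolder()
num_edges = holder.apply_on_graph(count_edges)
num_edges_without_input = holder.apply_on_graph(count_edges, skip=('input',))
```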

assert_graph_is_valid()[source]

Asserts that the graph is valid.

clear_payloads(graph=None)[source]
forward(*inputs)[source]
forward_through_node(name, input=None)[source]
get_module_for_nodes(names)[source]

Gets the torch.nn.Module object for nodes corresponding to names.

Parameters:names (str or list of str) – Names of the nodes to fetch the modules of.
Returns:Module or a list of modules corresponding to names.
Return type:list or torch.nn.Module
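The documented return type depends on whether a single name or a list of names is passed. A minimal sketch of that dispatch, using a hypothetical `modules` dict with placeholder strings in place of real torch modules:

```python
# Hypothetical registry; strings stand in for torch.nn.Module instances.
modules = {'conv1x1': 'module-a', 'conv3x3': 'module-b'}


def get_module_for_nodes(names):
    """Return one module for a str, or a list of modules for a list of str."""
    if isinstance(names, str):
        return modules[names]
    return [modules[name] for name in names]


single = get_module_for_nodes('conv1x1')
several = get_module_for_nodes(['conv3x3', 'conv1x1'])
```

In this sketch the list result follows the order of the requested names; whether the real method guarantees that ordering is an assumption.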
get_parameters_for_nodes(names, named=False)[source]

Get parameters of all nodes listed in names.

graph
graph_is_valid

Checks if the graph is valid.

input_nodes

Gets a list of input nodes. The order is relevant and is the same as that in which the forward method accepts its inputs.

Returns:A list of names (str) of the input nodes.
Return type:list
is_node_in_graph(name)[source]

Checks whether a node is in the graph.

Parameters:name (str) – Name of the node.
Returns:True if the node is in the graph, False otherwise.
Return type:bool
is_sink_node(name)[source]

Checks whether a given node (by name) is a sink node. A sink node has no outgoing edges.

Parameters:name (str) – Name of the node.
Returns:True if the node is a sink node, False otherwise.
Return type:bool
Raises:AssertionError – if node is not found in the graph.
is_source_node(name)[source]

Checks whether a given node (by name) is a source node. A source node has no incoming edges.

Parameters:name (str) – Name of the node.
Returns:True if the node is a source node, False otherwise.
Return type:bool
Raises:AssertionError – if node is not found in the graph.
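The source and sink definitions above can be checked on a plain edge list. The sketch below is illustrative only: the node names and the module-level `nodes` and `edges` variables are hypothetical, and the functions merely mirror the documented behaviour (including the `AssertionError` for unknown nodes) rather than reproduce inferno's code.

```python
# Hypothetical graph described as a node list and an edge list.
nodes = ['input', 'conv1x1', 'conv3x3', 'cat', 'output']
edges = [('input', 'conv1x1'), ('input', 'conv3x3'),
         ('conv1x1', 'cat'), ('conv3x3', 'cat'), ('cat', 'output')]


def is_source_node(name):
    """A source node has no incoming edges."""
    assert name in nodes, "node not found in the graph"
    return all(target != name for _, target in edges)


def is_sink_node(name):
    """A sink node has no outgoing edges."""
    assert name in nodes, "node not found in the graph"
    return all(source != name for source, _ in edges)


sources = [name for name in nodes if is_source_node(name)]
sinks = [name for name in nodes if is_sink_node(name)]
```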
output_nodes

Gets a list of output nodes. The order is relevant and is the same as that in which the forward method returns its outputs.

Returns:A list of names (str) of the output nodes.
Return type:list
to_device(names, target_device, device_ordinal=None, async=False)[source]

Transfer nodes in the network to a specified device.

inferno.extensions.containers.sequential module

class inferno.extensions.containers.sequential.Sequential1(*args)[source]

Bases: torch.nn.modules.container.Sequential

Like torch.nn.Sequential, but with a few extra methods.

class inferno.extensions.containers.sequential.Sequential2(*args)[source]

Bases: inferno.extensions.containers.sequential.Sequential1

Another sequential container. Identical to torch.nn.Sequential, except that modules may return multiple outputs and accept multiple inputs.

forward(*input)[source]
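The idea of passing multiple outputs of one stage as multiple inputs to the next can be sketched in pure Python. This is a toy under stated assumptions, not the Sequential2 implementation: `ToySequential`, `split`, and `combine` are hypothetical, and the sketch assumes that multiple values are represented as tuples.

```python
class ToySequential:
    """Hypothetical stand-in: chains callables, unpacking tuple outputs."""

    def __init__(self, *stages):
        self.stages = stages

    def forward(self, *inputs):
        outputs = inputs
        for stage in self.stages:
            outputs = stage(*outputs)
            # Normalise a single return value to a 1-tuple for the next stage.
            if not isinstance(outputs, tuple):
                outputs = (outputs,)
        return outputs[0] if len(outputs) == 1 else outputs


def split(x):
    return x, x * 2  # one input, two outputs


def combine(a, b):
    return a + b     # two inputs, one output


pipeline = ToySequential(split, combine)
result = pipeline.forward(5)  # split(5) -> (5, 10); combine(5, 10) -> 15
```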

Module contents