inferno.extensions.containers package
Submodules
inferno.extensions.containers.graph module

class inferno.extensions.containers.graph.NNGraph(data=None, **attr)
Bases: networkx.classes.digraph.DiGraph
A NetworkX DiGraph that preserves the insertion order of its nodes and edges.

ATTRIBUTES_TO_NOT_COPY = {'payload'}

adjlist_dict_factory
Alias of OrderedDict.

node_dict_factory
Alias of OrderedDict.
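Because both factories are aliases of OrderedDict, iterating over the nodes, or over a node's successors, follows insertion order. A minimal standard-library sketch of that idea follows; OrderedAdjacency is a hypothetical class for illustration, not NNGraph itself (which delegates to networkx).

```python
from collections import OrderedDict


class OrderedAdjacency:
    """Hypothetical minimal directed graph whose node and edge order is insertion order."""

    def __init__(self):
        # Outer dict: node -> inner dict of successors (mirrors the node/adjlist factories).
        self._succ = OrderedDict()

    def add_node(self, name):
        # Re-adding an existing node keeps its original position.
        self._succ.setdefault(name, OrderedDict())

    def add_edge(self, u, v):
        self.add_node(u)
        self.add_node(v)
        self._succ[u][v] = {}  # empty edge-attribute dict, as networkx uses

    def nodes(self):
        return list(self._succ)

    def successors(self, name):
        return list(self._succ[name])


g = OrderedAdjacency()
g.add_edge('input', 'conv5x5')
g.add_edge('input', 'conv1x1')
g.add_edge('input', 'conv3x3')
print(g.successors('input'))  # ['conv5x5', 'conv1x1', 'conv3x3']
print(g.nodes())              # ['input', 'conv5x5', 'conv1x1', 'conv3x3']
```

On Python 3.7 and later a plain dict also preserves insertion order; using OrderedDict makes the requirement explicit.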

class inferno.extensions.containers.graph.Graph(graph=None)
Bases: torch.nn.modules.module.Module
A graph structure for building networks with complex architectures. The resulting graph model can be used like any other torch.nn.Module. Behind the scenes, the graph structure is a networkx.DiGraph. This internal graph is exposed through the apply_on_graph method, which can be used with any NetworkX function (e.g., for plotting with matplotlib or GraphViz).
Examples
The naive inception module (without max-pooling, for simplicity) with ELU convolutional layers of 64 output channels each can be built as follows (assuming 64 input channels):
>>> import torch
>>> from inferno.extensions.containers.graph import Graph
>>> from inferno.extensions.layers.reshape import Concatenate
>>> from inferno.extensions.layers.convolutional import ConvELU2D
>>> # Build the model: three parallel branches with 1x1, 3x3 and 5x5 kernels
>>> inception_module = Graph()
>>> inception_module.add_input_node('input')
>>> inception_module.add_node('conv1x1', ConvELU2D(64, 64, 1), previous='input')
>>> inception_module.add_node('conv3x3', ConvELU2D(64, 64, 3), previous='input')
>>> inception_module.add_node('conv5x5', ConvELU2D(64, 64, 5), previous='input')
>>> # Concatenate the three branch outputs
>>> inception_module.add_node('cat', Concatenate(),
...                           previous=['conv1x1', 'conv3x3', 'conv5x5'])
>>> inception_module.add_output_node('output', 'cat')
>>> # Build a dummy input tensor
>>> input = torch.rand(1, 64, 100, 100)
>>> # Get output
>>> output = inception_module(input)
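The following is a minimal, standard-library-only sketch of what such a container has to do in its forward pass, assuming nodes are evaluated in topological order: inputs are bound positionally in the order add_input_node was called, every other node receives its predecessors' values in the order listed in previous, and outputs come back in the order add_output_node was called. MiniGraph and topological_order are hypothetical names; plain callables replace torch modules so the sketch runs without PyTorch, and it is not inferno's actual implementation.

```python
from collections import OrderedDict, deque


class MiniGraph:
    """Hypothetical sketch of a graph container (not inferno's implementation).

    Plain Python callables stand in for torch.nn.Module objects.
    """

    def __init__(self):
        self.modules = OrderedDict()  # node name -> callable (None for input/output nodes)
        self.preds = OrderedDict()    # node name -> ordered list of predecessor names
        self.input_nodes = []
        self.output_nodes = []

    def _add(self, name, module, previous):
        assert name not in self.modules, f"duplicate node {name!r}"
        if previous is None:
            previous = []
        elif isinstance(previous, str):
            previous = [previous]
        for p in previous:
            assert p in self.modules, f"unknown node {p!r}"
        self.modules[name] = module
        self.preds[name] = list(previous)
        return self  # returning self allows call chaining

    def add_input_node(self, name):
        self._add(name, None, None)
        self.input_nodes.append(name)
        return self

    def add_node(self, name, module, previous=None):
        return self._add(name, module, previous)

    def add_output_node(self, name, previous):
        self._add(name, None, previous)
        self.output_nodes.append(name)
        return self

    def topological_order(self):
        # Kahn's algorithm; ties are broken by insertion order.
        indegree = {n: len(ps) for n, ps in self.preds.items()}
        succs = {n: [] for n in self.preds}
        for n, ps in self.preds.items():
            for p in ps:
                succs[p].append(n)
        ready = deque(n for n, d in indegree.items() if d == 0)
        order = []
        while ready:
            n = ready.popleft()
            order.append(n)
            for s in succs[n]:
                indegree[s] -= 1
                if indegree[s] == 0:
                    ready.append(s)
        assert len(order) == len(self.preds), "graph has a cycle"
        return order

    def __call__(self, *inputs):
        assert len(inputs) == len(self.input_nodes), "wrong number of inputs"
        # Positional inputs are bound to input nodes in the order the nodes were added.
        values = dict(zip(self.input_nodes, inputs))
        for name in self.topological_order():
            if name in values:
                continue  # input node, value already bound
            args = [values[p] for p in self.preds[name]]
            module = self.modules[name]
            # Output nodes forward the value of their (single) predecessor unchanged.
            values[name] = args[0] if module is None else module(*args)
        # Outputs are returned in the order the output nodes were added.
        outputs = tuple(values[n] for n in self.output_nodes)
        return outputs[0] if len(outputs) == 1 else outputs


def concat(*xs):
    return list(xs)


g = MiniGraph()
g.add_input_node('input')
g.add_node('double', lambda x: 2 * x, previous='input')
g.add_node('square', lambda x: x * x, previous='input')
g.add_node('cat', concat, previous=['double', 'square'])
g.add_output_node('output', 'cat')
print(g(3))  # [6, 9]
```

Evaluating in topological order guarantees that every node's inputs are computed before the node itself runs, which is what allows branching and merging architectures like the inception module above.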

add_edge(from_node, to_node)
Add an edge between two nodes.
Parameters:
- from_node (str) – Name of the source node.
- to_node (str) – Name of the target node.
Returns: self
Return type: Graph
Raises: AssertionError – if either of the two nodes is not in the graph, or if the edge is not 'legal'.
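The documentation does not spell out what makes an edge 'legal'. A plausible rule for a network graph, assumed here and not confirmed by the source, is that the edge must not create a cycle, since a forward pass needs a directed acyclic graph. The sketch below uses a hypothetical dict-of-lists representation (node name to list of successor names) and raises AssertionError in both documented cases.

```python
def add_edge(successors, from_node, to_node):
    """Hypothetical sketch; `successors` maps each node name to a list of successor names."""
    assert from_node in successors, f"{from_node!r} is not in the graph"
    assert to_node in successors, f"{to_node!r} is not in the graph"
    # Assumed legality rule: reject the edge if from_node is reachable from to_node
    # (including from_node == to_node), because the new edge would close a cycle.
    stack, seen = [to_node], set()
    while stack:
        node = stack.pop()
        assert node != from_node, f"edge {from_node!r} -> {to_node!r} would create a cycle"
        if node not in seen:
            seen.add(node)
            stack.extend(successors[node])
    successors[from_node].append(to_node)
    return successors


graph = {'input': [], 'conv': [], 'output': []}
add_edge(graph, 'input', 'conv')
add_edge(graph, 'conv', 'output')
try:
    add_edge(graph, 'output', 'input')
    message = None
except AssertionError as err:
    message = str(err)
print(message)  # edge 'output' -> 'input' would create a cycle
```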

add_input_node(name)
Add an input to the graph. The order in which input nodes are added is the order in which the forward method accepts its inputs.
Parameters: name (str) – Name of the input node.
Returns: self
Return type: Graph

add_node(name, module, previous=None)
Add a node to the graph.
Parameters:
- name (str) – Name of the node. Nodes are identified by their names.
- module (torch.nn.Module) – Torch module for this node.
- previous (str or list of str) – Name(s) of the previous node(s).
Returns: self
Return type: Graph

add_output_node(name, previous=None)
Add an output to the graph. The order in which output nodes are added is the order in which the forward method returns its outputs.
Parameters: name (str) – Name of the output node.
Returns: self
Return type: Graph

get_module_for_nodes(names)
Get the torch.nn.Module objects for the nodes corresponding to names.
Parameters: names (str or list of str) – Names of the nodes whose modules to fetch.
Returns: The module, or a list of modules, corresponding to names.
Return type: torch.nn.Module or list
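The "str or list of str" convention means that a single name yields a single module and a list of names yields a list of modules in the same order. A sketch of that dispatch, with a hypothetical plain dict mapping node names to modules (strings stand in for torch.nn.Module objects so it runs without PyTorch):

```python
def get_module_for_nodes(modules, names):
    """Hypothetical stand-in; `modules` maps node names to module objects."""
    if isinstance(names, str):
        return modules[names]  # single name -> single module
    return [modules[name] for name in names]  # list of names -> list, same order


modules = {'conv1x1': 'module_a', 'conv3x3': 'module_b'}
print(get_module_for_nodes(modules, 'conv3x3'))               # module_b
print(get_module_for_nodes(modules, ['conv3x3', 'conv1x1']))  # ['module_b', 'module_a']
```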

graph

graph_is_valid
Checks whether the graph is valid.

input_nodes
Gets a list of input nodes. The order is relevant and is the same as that in which the forward method accepts its inputs.
Returns: A list of names (str) of the input nodes.
Return type: list

is_node_in_graph(name)
Checks whether a node is in the graph.
Parameters: name (str) – Name of the node.
Returns: True if the node is in the graph, False otherwise.
Return type: bool

is_sink_node(name)
Checks whether a given node (by name) is a sink node. A sink node has no outgoing edges.
Parameters: name (str) – Name of the node.
Returns: True if the node is a sink node, False otherwise.
Return type: bool
Raises: AssertionError – if the node is not found in the graph.

is_source_node(name)
Checks whether a given node (by name) is a source node. A source node has no incoming edges.
Parameters: name (str) – Name of the node.
Returns: True if the node is a source node, False otherwise.
Return type: bool
Raises: AssertionError – if the node is not found in the graph.
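Both checks reduce to looking at edge directions; on the internal DiGraph they presumably correspond to in_degree(name) == 0 and out_degree(name) == 0. A standalone sketch over a hypothetical node set and edge list, raising AssertionError for unknown nodes as documented:

```python
def _check_node(nodes, name):
    # Mirrors the documented AssertionError for unknown nodes.
    assert name in nodes, f"node {name!r} not found in the graph"


def is_source_node(nodes, edges, name):
    """A source node has no incoming edges."""
    _check_node(nodes, name)
    return all(dst != name for _, dst in edges)


def is_sink_node(nodes, edges, name):
    """A sink node has no outgoing edges."""
    _check_node(nodes, name)
    return all(src != name for src, _ in edges)


nodes = {'input', 'conv', 'output'}
edges = [('input', 'conv'), ('conv', 'output')]
print(is_source_node(nodes, edges, 'input'))  # True
print(is_source_node(nodes, edges, 'conv'))   # False
print(is_sink_node(nodes, edges, 'output'))   # True
```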

output_nodes
Gets a list of output nodes. The order is relevant and is the same as that in which the forward method returns its outputs.
Returns: A list of names (str) of the output nodes.
Return type: list

inferno.extensions.containers.sequential module

class inferno.extensions.containers.sequential.Sequential1(*args)
Bases: torch.nn.modules.container.Sequential
Like torch.nn.Sequential, but with a few extra methods.

class inferno.extensions.containers.sequential.Sequential2(*args)
Bases: inferno.extensions.containers.sequential.Sequential1
Another sequential container. Identical to torch.nn.Sequential, except that modules may accept multiple inputs and return multiple outputs.
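One way to let modules in a chain accept multiple inputs and return multiple outputs is to unpack a tuple output into the next module's positional arguments. That unpacking convention is an assumption made for this sketch (the source does not say how Sequential2 passes values along), and run_sequential is a hypothetical helper that uses plain callables instead of torch modules:

```python
def run_sequential(modules, *inputs):
    """Hypothetical helper: apply callables in order, allowing multiple inputs/outputs."""
    values = inputs
    for module in modules:
        out = module(*values)
        # Assumed convention: a tuple output is unpacked into the next module's
        # positional arguments; any other output is passed on as a single argument.
        values = out if isinstance(out, tuple) else (out,)
    # Return a bare value for a single output, otherwise the tuple of outputs.
    return values[0] if len(values) == 1 else values


def split(x):
    return x, x + 1  # one input, two outputs


def add(a, b):
    return a + b  # two inputs, one output


print(run_sequential([split, add], 10))  # 21
print(run_sequential([split], 1))        # (1, 2)
```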