#!/usr/bin/python2 -s

# (c) 2014, Will Thames <will@thames.id.au>
#
# ansible-inventory-grapher is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ansible-inventory-grapher is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ansible-inventory-grapher.  If not, see <http://www.gnu.org/licenses/>.

import jinja2
import optparse
import os
import sys
import ansible.inventory
import ansible.constants as C
# Ansible compatibility shim: try the imports used for ANSIBLE_VERSION 2
# first and fall back to the ``ansible.utils`` helpers (ANSIBLE_VERSION 1)
# when any of them is missing.  Either way the rest of the module can rely
# on ``read_vault_file``, ``ask_vault_pass()`` and ``ANSIBLE_VERSION``.
try:
    from ansible.parsing.dataloader import DataLoader
    from ansible.vars import VariableManager
    from ansible.cli import CLI
    read_vault_file = CLI.read_vault_password_file
    def ask_vault_pass():
        # Only the first element returned by CLI.ask_vault_passwords() is used.
        return CLI.ask_vault_passwords()[0]
    ANSIBLE_VERSION = 2
except ImportError:
    from ansible import utils
    read_vault_file = utils.read_vault_file
    def ask_vault_pass():
        # Only the element at index 2 of utils.ask_passwords() is used
        # (presumably the vault password -- confirm against ansible.utils).
        return utils.ask_passwords(ask_vault_pass=True)[2]
    ANSIBLE_VERSION = 1

import ansibleinventorygrapher
from ansibleinventorygrapher.version import __version__


# Jinja2 source used when no template file is supplied with -t (and printed
# verbatim by -T).  It emits one Graphviz ``digraph`` and reads these
# variables from the render context:
#   pattern    - used (through the ``labelescape`` filter) as the graph name
#   attributes - inserted as-is at the top of the graph body
#   nodes      - objects with ``name``, ``leaf`` and ``vars``; leaf nodes get
#                a rounded record shape, all others a plain record
#   edges      - objects with ``source`` and ``target``
#   showvars   - when false, the per-node variable list is omitted
# Node and edge identifiers go through ``labelescape``; the visible label
# text uses the unescaped ``node.name``.
DEFAULT_TEMPLATE = """digraph {{pattern|labelescape}} {
  {{ attributes }}

{% for node in nodes|sort(attribute='name') %}
{% if node.leaf %}
  {{ node.name|labelescape }} [shape=record style=rounded label=<
<table border="0" cellborder="0">
  <tr><td><b>
  <font face="Times New Roman, Bold" point-size="16">{{ node.name}}</font>
  </b></td></tr>
{% if node.vars and showvars %}<hr/><tr><td><font face="Times New Roman, Bold" point-size="14">{% for var in node.vars|sort %}{{var}}<br/>{%endfor %}</font></td></tr>{% endif %}
</table>
>]
{% else %}
  {{ node.name|labelescape }} [shape=record label=<
<table border="0" cellborder="0">
  <tr><td><b>
  <font face="Times New Roman, Bold" point-size="16">{{ node.name}}</font>
  </b></td></tr>
{% if node.vars and showvars %}<hr/><tr><td><font face="Times New Roman, Bold" point-size="14">{% for var in node.vars|sort %}{{var}}<br/>{%endfor %}</font></td></tr>{% endif %}
</table>
>]
{% endif %}{% endfor %}

{% for edge in edges|sort(attribute='source') %}
  {{ edge.source|labelescape }} -> {{ edge.target|labelescape }};
{% endfor %}
}

"""


def options_parser():
    """Build the command-line parser for the grapher.

    Returns an ``optparse.OptionParser`` with every supported option
    registered; nothing is parsed here.  Positional arguments (the host
    patterns) are left for the caller to handle.
    """
    cli = optparse.OptionParser(
        usage="%prog [options] pattern1 [pattern2 ...]",
        version="%prog " + __version__)

    # (flags, attributes) pairs, registered in this order so that the
    # generated --help output lists them the same way.
    option_specs = [
        (('-i',), dict(
            dest='inventory',
            default=C.DEFAULT_HOST_LIST,
            help="specify inventory host file [%default]")),
        (('-d',), dict(
            dest='directory',
            default=os.getcwd(),
            help="Location to output resulting files [current directory]")),
        (('-o', '--format'), dict(
            dest='format',
            default="-",
            help="python format string to name output files "
                 "(e.g. {}.dot) [defaults to stdout]")),
        (('-q', '--no-variables'), dict(
            dest='showvars',
            action="store_false",
            default=True,
            help="Turn off variable display in default template")),
        (('-t',), dict(
            dest='template',
            help='path to jinja2 template used for creating output')),
        (('-T',), dict(
            dest='print_template',
            action="store_true",
            help='print default template')),
        (('-a',), dict(
            dest='attributes',
            default='rankdir=TB;',
            help='include top-level graphviz attributes from '
                 'http://www.graphviz.org/doc/info/attrs.html [%default]')),
        (('--ask-vault-pass',), dict(
            dest='ask_vault_pass',
            action="store_true",
            default=False,
            help="prompt for vault password")),
        (('--vault-password-file',), dict(
            dest='vault_password_file',
            default=False,
            help="Location of file with cleartext vault password")),
    ]
    for flags, attrs in option_specs:
        cli.add_option(*flags, **attrs)
    return cli


def labelescape(name):
    """Return ``name`` with every "-" and "." replaced by "_".

    Registered as the ``labelescape`` Jinja2 filter and applied to the
    graph name and to node/edge identifiers in the default template.
    """
    for unsafe in ("-", "."):
        name = name.replace(unsafe, "_")
    return name


def load_template(options):
    """Return the Jinja2 template used to render a graph.

    ``options.template`` is the path given with ``-t``; when it is unset the
    built-in ``DEFAULT_TEMPLATE`` is compiled instead.  The environment uses
    ``trim_blocks`` and exposes the ``labelescape`` filter.

    Jinja2's ``FileSystemLoader`` resolves template names relative to its
    search path and rejects names containing a ``..`` segment, so a loader
    rooted only at the current directory cannot open an absolute path or a
    path that climbs out of the working directory.  For those two cases the
    template's own directory is put at the front of the search path and the
    template is loaded by its base name.  Plain relative names keep being
    resolved against the current directory exactly as before, and the
    current directory always stays on the search path.

    Raises whatever ``Environment.get_template`` raises (for example
    ``jinja2.TemplateNotFound``) when the file cannot be loaded.
    """
    search_dirs = [os.getcwd()]
    template_name = options.template
    if template_name:
        segments = template_name.replace(os.sep, "/").split("/")
        if os.path.isabs(template_name) or os.pardir in segments:
            template_dir, template_name = os.path.split(
                os.path.abspath(template_name))
            search_dirs.insert(0, template_dir)

    env = jinja2.Environment(trim_blocks=True,
                             loader=jinja2.FileSystemLoader(search_dirs))
    env.filters['labelescape'] = labelescape

    if options.template:
        return env.get_template(template_name)
    return env.from_string(DEFAULT_TEMPLATE)


def _load_inventory(options):
    """Build the Ansible inventory for ``options.inventory``.

    The vault password comes from an interactive prompt when
    ``options.ask_vault_pass`` is set; a password read from
    ``options.vault_password_file`` replaces it when that option is given.
    The construction differs between the two supported Ansible API
    generations, selected by ``ANSIBLE_VERSION``.
    """
    vault_pass = None
    if options.ask_vault_pass:
        vault_pass = ask_vault_pass()

    if ANSIBLE_VERSION > 1:
        variable_manager = VariableManager()
        loader = DataLoader()
        if options.vault_password_file:
            vault_pass = read_vault_file(options.vault_password_file,
                                         loader=loader)
        if vault_pass:
            loader.set_vault_password(vault_pass)
        inventory = ansible.inventory.Inventory(
            loader=loader, variable_manager=variable_manager,
            host_list=options.inventory)
        variable_manager.set_inventory(inventory)
        return inventory

    if options.vault_password_file:
        vault_pass = read_vault_file(options.vault_password_file)
    return ansible.inventory.Inventory(options.inventory,
                                       vault_password=vault_pass)


def render_graph(pattern, options):
    """Render the inventory graph for every host matching ``pattern``.

    ``options`` is the parsed option object produced by the parser from
    ``options_parser()``.  The rendered text is written to stdout when
    ``options.format`` is ``'-'``; otherwise ``options.format`` is used as a
    ``str.format`` string (receiving ``pattern``) to name a file inside
    ``options.directory``, which is created if it does not exist.

    When no host matches, a message is written to stderr and nothing is
    rendered.  stderr is used because stdout is the default channel for the
    generated graph, so a diagnostic there would end up inside the output
    consumed by downstream tools.

    Returns ``None``.
    """
    inventory = _load_inventory(options)
    hosts = inventory.list_hosts(pattern)
    template = load_template(options)
    if not hosts:
        sys.stderr.write("No hosts matched for pattern %s\n" % pattern)
        return

    edges = set()
    nodes = set()
    for host in hosts:
        # inventory.list_hosts changed from list of hostnames to list
        # of hosts between Ansible 1.9 and 2.0
        if isinstance(host, basestring):
            host = inventory.get_host(host)
        host_edges, host_nodes = \
            ansibleinventorygrapher.generate_graph_for_host(host, inventory)
        edges |= host_edges
        nodes |= host_nodes

    output = template.render(edges=edges, nodes=nodes, pattern=pattern,
                             attributes=options.attributes,
                             showvars=options.showvars)

    if options.format == '-':
        sys.stdout.write(output)
        return

    # Only touch the filesystem when a file is actually going to be written.
    if not os.path.exists(options.directory):
        os.makedirs(options.directory)
    filename = options.format.format(pattern)
    fullpath = os.path.join(options.directory, filename)
    with open(fullpath, 'w') as f:
        f.write(output)


def main():
    """Command-line entry point.

    Parses the command line, then either prints the default template (-T),
    prints the help text when no pattern was given, or renders one graph per
    pattern argument.  The first two cases end with ``sys.exit()``.
    """
    cli = options_parser()
    opts, patterns = cli.parse_args()

    if opts.print_template:
        print(DEFAULT_TEMPLATE)
        sys.exit()

    if len(patterns) == 0:
        cli.print_help()
        sys.exit()

    for host_pattern in patterns:
        render_graph(host_pattern, opts)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
