Add the count of connections
authorPeng Li <seudut@gmail.com>
Wed, 27 Jun 2018 16:14:11 +0000 (00:14 +0800)
committerPeng Li <seudut@gmail.com>
Wed, 27 Jun 2018 16:14:11 +0000 (00:14 +0800)
sztool

diff --git a/sztool b/sztool
index ba59b45..2dd723b 100755 (executable)
--- a/sztool
+++ b/sztool
@@ -1,17 +1,24 @@
 #!/usr/bin/env python2
 
-from subprocess import call, Popen, PIPE, STDOUT
+from subprocess import call, Popen, PIPE, STDOUT, check_output
 import paho.mqtt.client as mqtt
 import argparse
 import sys
 from uuid import getnode as get_mac
 import ast
 import threading
+import shlex
 
 MQTT_SERVER = 'mqtt.suanzi.ai'
 MQTT_PORT = 1883
 URI = 'debug@autossh.suanzi.ai'
 
+def get_all_connected(host, port, user, password):
+    command = 'sshpass -p' + password + ' ssh -q -o "StrictHostKeyChecking=no"  -o "UserKnownHostsFile /dev/null" ' + user + '@' + host + '  netstat -tn | grep ' + str(port) + ' | grep ESTABLISHED | wc -l'
+    print command
+    output = check_output(shlex.split(command), shell=False)
+    return int(output)
+
 def on_connect(client, userdata, flags, rc):
     client.subscribe(userdata['id'])
     print "Connected with result code " + str(rc)
@@ -19,17 +26,16 @@ def on_connect(client, userdata, flags, rc):
         payload = {'from':userdata['id'], 'type':'request', 'command':'ssh'}
         client.publish(userdata['to'], str(payload))
     elif userdata['action'] == 'list':
-        timer = threading.Timer(5, timeout)
-        timer.start()
+        #timer = threading.Timer(20, timeout)
+        #timer.start()
         payload = {'from':userdata['id'], 'type':'request', 'command':'list'}
         client.publish(userdata['to'], str(payload))
 
 def on_message(client, userdata, msg):
-    print('Receive topic:' + msg.topic + ' payload: ' +str(msg.payload))
+#    print('Receive topic:' + msg.topic + ' payload: ' +str(msg.payload))
     payload = ast.literal_eval(str(msg.payload))
     from_id = payload['from']
     if payload['type'] == 'response':
-        print payload['command']
         if payload['command'] == 'ssh':
             port = payload['data']
             print 'run "ssh ' + URI + ' -p ' + str(port) + '" to connect to device (%s)' % from_id
@@ -37,19 +43,26 @@ def on_message(client, userdata, msg):
             sys.exit(0)
         elif payload['command'] == 'list':
             print payload['from'], payload['data']
+            client.total = client.total - 1
+            if client.total == 0:
+                client.disconnect()
+                sys.exit(0)
 
 def get_mac_str():
     mac = hex(get_mac())
     return '{:0>12}'.format(mac[2:-1])
 
-def timeout():
-    print 'timeout'
-    client.disconnect()
-    sys.exit(0)
-
+#def timeout():
+#    print 'timeout'
+#    client.disconnect()
+#    sys.exit(0)
+#
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Choose an avaiable port and run ssh.')
+    parser = argparse.ArgumentParser(description='Choose an avaiable port and run ssh.', add_help=False)
+    parser.add_argument('--help', action='help')
+    parser.add_argument('-h', '--host', help='the mqtt server', default=MQTT_SERVER)
+    parser.add_argument('-p', '--port', type=int, help='the mqtt port', default=MQTT_PORT)
     subparsers = parser.add_subparsers(title='subcommands', dest='action', help='xxx')
     parser_conn = subparsers.add_parser('connect', help='Connect to the specific device')
     parser_conn.add_argument('device', help='the MAC address of remote device, in the form of 080027a6f8dc')
@@ -57,12 +70,15 @@ if __name__ == '__main__':
     args = parser.parse_args()
     id = get_mac_str()
     print 'Mac: ', id
-
     device = args.device if args.action == 'connect' else 'all'
 
+    total = get_all_connected(MQTT_SERVER, MQTT_PORT, 'autossh', 'hard2guess')
+    print '\nFound %s connected devices: ' % total
+
     client = mqtt.Client(userdata={'id':id, 'action':args.action, 'to': device})
+    client.total = total
     client.on_connect = on_connect
     client.on_message = on_message
-    client.connect(MQTT_SERVER, MQTT_PORT, 60)
+    client.connect(args.host, args.port, 60)
     client.loop_forever()