Imported Upstream version 0.3.3
[anytun.git] / src / anytun.cpp
index de8429f..1ac9397 100644 (file)
 #include "authAlgoFactory.h"
 #include "keyDerivationFactory.h"
 #include "signalController.h"
-#ifdef WIN_SERVICE
-#include "win32/winService.h"
+#ifndef _MSC_VER
+# include "daemonService.h"
+#else
+# ifdef WIN_SERVICE
+#  include "win32/winService.h"
+# else
+#  include "nullDaemon.h"
+# endif
 #endif
 #include "packetSource.h"
 #include "tunDevice.h"
@@ -63,7 +69,6 @@
 #include "networkAddress.h"
 #endif
 
-
 #ifndef ANYTUN_NOSYNC
 #include "syncQueue.h"
 #include "syncCommand.h"
 #include "syncOnConnect.hpp"
 #endif
 
-#define MAX_PACKET_LENGTH 1600
-
 #include "cryptinit.hpp"
-#include "daemon.hpp"
 #include "sysExec.h"
 
 bool disableRouting = false;
@@ -97,6 +99,11 @@ void createConnection(const PacketSourceEndpoint& remote_end, window_size_t seqS
 #endif
 }
 
+void createConnectionResolver(PacketSourceResolverIt& it, window_size_t seqSize, mux_t mux)
+{
+  createConnection(*it, seqSize, mux);
+}
+
 void createConnectionError(const std::exception& e)
 {
   gSignalController.inject(SIGERROR, e.what());
@@ -227,7 +234,8 @@ void receiver(TunDevice* dev, PacketSource* src)
     std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher(), KD_INBOUND));
     std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(gOpt.getAuthAlgo(), KD_INBOUND));
     
-    EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH, gOpt.getAuthTagLength());
+    u_int32_t auth_tag_length = gOpt.getAuthTagLength();
+    EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH, auth_tag_length);
     PlainPacket plain_packet(MAX_PACKET_LENGTH);
     
     while(1) {
@@ -249,7 +257,7 @@ void receiver(TunDevice* dev, PacketSource* src)
       if(len < 0)
         continue; // silently ignore socket recv errors, this is probably no good idea...
 
-      if(static_cast<u_int32_t>(len) < EncryptedPacket::getHeaderLength())
+      if(static_cast<u_int32_t>(len) < (EncryptedPacket::getHeaderLength() + auth_tag_length))
         continue; // ignore short packets
       encrypted_packet.setLength(len);
       
@@ -267,7 +275,7 @@ void receiver(TunDevice* dev, PacketSource* src)
       
           // check whether auth tag is ok or not
       if(!a->checkTag(conn.kd_, encrypted_packet)) {
-        cLog.msg(Log::PRIO_NOTICE) << "wrong Authentication Tag!" << std::endl;
+        cLog.msg(Log::PRIO_NOTICE) << "wrong Authentication Tag!";
         continue;
       }        
 
@@ -313,34 +321,15 @@ void receiver(TunDevice* dev, PacketSource* src)
   }
 }
 
-#ifndef NO_DAEMON
-void startSendRecvThreads(PrivInfo& privs, TunDevice* dev, PacketSource* src)
-#else
 void startSendRecvThreads(TunDevice* dev, PacketSource* src)
-#endif
 {
   src->waitUntilReady();
   
-#ifndef NO_DAEMON
-  if(gOpt.getChrootDir() != "") {
-    try {
-      do_chroot(gOpt.getChrootDir());
-    }
-    catch(const std::runtime_error& e) {
-      cLog.msg(Log::PRIO_WARNING) << "ignoring chroot error: " << e.what();
-    }
-  }
-#ifndef NO_PRIVDROP
-  privs.drop();
-#endif
-#endif
-  
   boost::thread(boost::bind(sender, dev, src));
   boost::thread(boost::bind(receiver, dev, src)); 
 }
 
 
-
 #ifdef WIN_SERVICE
 int main(int argc, char* argv[])
 {
@@ -368,42 +357,23 @@ int main(int argc, char* argv[])
   }
 }
 
-int real_main(int argc, char* argv[])
+int real_main(int argc, char* argv[], WinService& service)
+{
 #else
 int main(int argc, char* argv[])
-#endif
 {
-#ifdef WIN_SERVICE
-  bool daemonized=true;
-#else
-  bool daemonized=false;
+  DaemonService service;
 #endif  
   try 
   {
     try 
     {
-      bool result = gOpt.parse(argc, argv);
-      if(!result) {
-        gOpt.printUsage();
+      if(!gOpt.parse(argc, argv))
         exit(0);
-      }
+
       StringList targets = gOpt.getLogTargets();
-      if(targets.empty()) {
-#ifndef _MSC_VER
-        cLog.addTarget("syslog:3,anytun,daemon");
-#else
- #ifdef WIN_SERVICE
-        cLog.addTarget("eventlog:3,anytun");
- #else
-        cLog.addTarget("stdout:3");
- #endif
-#endif
-      }
-      else {
-        StringList::const_iterator it;
-        for(it = targets.begin();it != targets.end(); ++it)
-          cLog.addTarget(*it);
-      }
+      for(StringList::const_iterator it = targets.begin();it != targets.end(); ++it)
+        cLog.addTarget(*it);
     }
     catch(syntax_error& e)
     {
@@ -416,44 +386,45 @@ int main(int argc, char* argv[])
     gOpt.parse_post(); // print warnings
 
         // daemonizing has to done before any thread gets started
-#ifndef NO_DAEMON
-#ifndef NO_PRIVDROP
-               PrivInfo privs(gOpt.getUsername(), gOpt.getGroupname());
-#endif
-    if(gOpt.getDaemonize()) {
-      daemonize();
-      daemonized = true;
-    }
-#endif
-
-        // this has to be called before the first thread is started
-    gSignalController.init();
-    gResolver.init();
-   
-#ifndef NO_CRYPT
-#ifndef USE_SSL_CRYPTO
-// this must be called before any other libgcrypt call
-    if(!initLibGCrypt())
-      return -1;
-#endif
-#endif
+    service.initPrivs(gOpt.getUsername(), gOpt.getGroupname());
+    if(gOpt.getDaemonize())
+      service.daemonize();
 
     OptionNetwork net = gOpt.getIfconfigParam();
     TunDevice dev(gOpt.getDevName(), gOpt.getDevType(), net.net_addr, net.prefix_length);
     cLog.msg(Log::PRIO_NOTICE) << "dev opened - name '" << dev.getActualName() << "', node '" << dev.getActualNode() << "'";
     cLog.msg(Log::PRIO_NOTICE) << "dev type is '" << dev.getTypeString() << "'";
-#ifndef NO_EXEC
+
+    SysExec * postup_script = NULL;
     if(gOpt.getPostUpScript() != "") {
       cLog.msg(Log::PRIO_NOTICE) << "executing post-up script '" << gOpt.getPostUpScript() << "'";
       StringVector args = boost::assign::list_of(dev.getActualName())(dev.getActualNode());
-      anytun_exec(gOpt.getPostUpScript(), args);
+      postup_script = new SysExec(gOpt.getPostUpScript(), args);
     }
-#endif
-    
+
+    if(gOpt.getChrootDir() != "") {
+      try {
+        service.chroot(gOpt.getChrootDir());
+      }
+      catch(const std::runtime_error& e) {
+        cLog.msg(Log::PRIO_WARNING) << "ignoring chroot error: " << e.what();
+      }
+    }
+    service.dropPrivs();
+
+    // this has to be called before the first thread is started
+    gSignalController.init(service);
+    gResolver.init();
+    boost::thread(boost::bind(&TunDevice::waitUntilReady,&dev));
+    if (postup_script)
+      boost::thread(boost::bind(&SysExec::waitAndDestroy,postup_script));
+
+    initCrypto();   
     PacketSource* src = new UDPPacketSource(gOpt.getLocalAddr(), gOpt.getLocalPort());
 
     if(gOpt.getRemoteAddr() != "")
-      gResolver.resolveUdp(gOpt.getRemoteAddr(), gOpt.getRemotePort(), boost::bind(createConnection, _1, gOpt.getSeqWindowSize(), gOpt.getMux()), boost::bind(createConnectionError, _1), gOpt.getResolvAddrType());
+      gResolver.resolveUdp(gOpt.getRemoteAddr(), gOpt.getRemotePort(), boost::bind(createConnectionResolver, _1, gOpt.getSeqWindowSize(), gOpt.getMux()), boost::bind(createConnectionError, _1), gOpt.getResolvAddrType());
 
     HostList connect_to = gOpt.getRemoteSyncHosts();
 #ifndef NO_ROUTING
@@ -480,20 +451,11 @@ int main(int argc, char* argv[])
       connectThreads.create_thread(boost::bind(syncConnector, *it));
 #endif
 
-        // wait for packet source to finish in a seperate thread in order
-        // to be still able to process signals while waiting
-#ifndef NO_DAEMON
-    boost::thread(boost::bind(startSendRecvThreads, privs, &dev, src));
-#else
+    // wait for packet source to finish in a seperate thread in order
+    // to be still able to process signals while waiting
     boost::thread(boost::bind(startSendRecvThreads, &dev, src));
-#endif
 
-#if defined(WIN_SERVICE)
-    int ret = 0;
-    gWinService.waitForStop();
-#else
     int ret = gSignalController.run();  
-#endif
 
 // TODO: stop all threads and cleanup
 // 
@@ -501,27 +463,20 @@ int main(int argc, char* argv[])
 //       delete src;
 //     if(connTo)
 //       delete connTo;
-
-#if defined(WIN_SERVICE)
-    gWinService.stop();
-#endif
     return ret; 
   }
   catch(std::runtime_error& e)
   {
     cLog.msg(Log::PRIO_ERROR) << "uncaught runtime error, exiting: " << e.what();
-    if(!daemonized)
+    if(!service.isDaemonized())
       std::cout << "uncaught runtime error, exiting: " << e.what() << std::endl;
   }
   catch(std::exception& e)
   {
     cLog.msg(Log::PRIO_ERROR) << "uncaught exception, exiting: " << e.what();
-    if(!daemonized)
+    if(!service.isDaemonized())
       std::cout << "uncaught exception, exiting: " << e.what() << std::endl;
   }
-#if defined(WIN_SERVICE)
-  gWinService.stop();
-#endif
   return -1;
 }