From 87bca87ab318ff254fab4b87ac8e469925f44d22 Mon Sep 17 00:00:00 2001
From: Bent Bisballe Nyeng <deva@aasimon.org>
Date: Fri, 30 May 2014 10:53:23 +0200
Subject: Add (more) error checking for SRTP.

---
 src/srtp.cc | 116 +++++++++++++++++++++++-------------------------------------
 src/srtp.h  |   8 -----
 2 files changed, 45 insertions(+), 79 deletions(-)

(limited to 'src')

diff --git a/src/srtp.cc b/src/srtp.cc
index 36f3148..0a02994 100644
--- a/src/srtp.cc
+++ b/src/srtp.cc
@@ -35,6 +35,9 @@
 
 #include "asc2bin.h"
 
+// Global SRTP instance status
+static err_status_t active_srtp_instance_status = err_status_init_fail;
+
 // This macro translates srtp status codes into exceptions:
 #define SRTP_THROW(s)                                                   \
   do {                                                                  \
@@ -80,6 +83,10 @@ struct SRTP::prv {
 SRTP::SRTP(std::string key, unsigned int ssrc)
   _throw(enum lrtp_status_t)
 {
+  if(active_srtp_instance_status != err_status_ok) {
+    SRTP_THROW(active_srtp_instance_status);
+  }
+
   prv = NULL;
 
   err_status_t status;
@@ -89,8 +96,12 @@ SRTP::SRTP(std::string key, unsigned int ssrc)
 
   prv->ssrc = ssrc;
 
-  setupKey(key);
-  setupPolicy(true, true);
+  try {
+    setupKey(key);
+    setupPolicy(true, true);
+  } catch(enum lrtp_status_t s) {
+    throw s;
+  }
 
   status = srtp_create(&prv->session, &prv->policy);
   if(status != err_status_ok) SRTP_THROW(status);
@@ -102,16 +113,10 @@ SRTP::SRTP(std::string key, unsigned int ssrc)
 SRTP::~SRTP()
 {
   err_status_t status = srtp_remove_stream(prv->session, htonl(prv->ssrc));
-  if(status != err_status_ok) {
-    // TODO: Error handling
-    printf("srtp_remove_stream failed %d\n", status);
-  }
+  if(status != err_status_ok) SRTP_THROW(status);
 
   status = srtp_dealloc(prv->session);
-  if(status != err_status_ok) {
-    // TODO: Error handling
-    printf("srtp_dealloc failed %d\n", status);
-  }
+  if(status != err_status_ok) SRTP_THROW(status);
 
   if(prv) {
     free(prv->key);
@@ -125,16 +130,16 @@ void SRTP::setupKey(const std::string &key)
   prv->key = (char *)calloc(MASTER_KEY_LEN, 1);
   prv->key_len = MASTER_KEY_LEN;
 
-  if(key.length() > MASTER_KEY_LEN * 2) printf("KeyTooLong\n"); // TODO
+  if(key.length() > MASTER_KEY_LEN * 2) throw LRTP_KEY_TOO_LONG;
 
   // Read key from hexadecimal on command line into an octet string
   ssize_t len = asc2bin(prv->key, prv->key_len, key.c_str(), key.size());
 
-  if(len == -1) printf("InvalidHexKeyString\n"); // TODO
+  if(len == -1) throw LRTP_INVALID_KEY_STRING;
   prv->key_len = len;
 
   // check that hex string is the right length.
-  if(len < MASTER_KEY_LEN) printf("KeyTooShort\n"); // TODO
+  if(len < MASTER_KEY_LEN) throw LRTP_KEY_TOO_SHORT;
 }
 
 void SRTP::setupPolicy(bool confidentiality, bool authentication)
@@ -142,9 +147,12 @@ void SRTP::setupPolicy(bool confidentiality, bool authentication)
 {
 #ifndef USE_CRYPTO
   confidentiality = authentication = false;
-  printf("No crypto!\n");
+  //  printf("No crypto!\n");
 #endif  
 
+  // Apparently not all fields in prv->policy are set during initialisation.
+  memset(&prv->policy, 0, sizeof(srtp_policy_t));
+
   /* 
    * create policy structure, using the default mechanisms but 
    * with only the security services requested on the command line,
@@ -190,11 +198,9 @@ int SRTP::encrypt(char *packet, size_t size)
   _throw(enum lrtp_status_t)
 {
   int sz = size;
+
   err_status_t status = srtp_protect(prv->session, packet, &sz);
-  if(status != err_status_ok) {
-    // TODO: throw SRTP::UnprotectException();
-    printf("srtp_protect failed %d\n", status);
-  }
+  if(status != err_status_ok) SRTP_THROW(status);
 
   return sz;
 }
@@ -203,69 +209,37 @@ int SRTP::decrypt(char *packet, size_t size)
   _throw(enum lrtp_status_t)
 {
   int sz = size;
-  err_status_t status = srtp_unprotect(prv->session, packet, &sz);
-  switch(status) {
-  case err_status_ok:
-    // No errors.
-    break;
-  case err_status_replay_fail:
-    // TODO: throw SRTP::ReplayException();// (replay check failed)
-    printf("srtp_unprotect failed replay %d\n", status);
-    sz = -1;
-    break;
-  case err_status_replay_old:
-    // TODO: throw SRTP::ReplayOldException();// (replay check failed)
-    printf("srtp_unprotect failed replay_old %d\n", status);
-    sz = -1;
-    break;
-  case err_status_auth_fail:
-    // TODO: throw SRTP::AuthCheckException();// (auth check failed)
-    printf("srtp_unprotect failed auth %d\n", status);
-    sz = -1;
-    break;
-  default:
-    // TODO: throw SRTP::UnprotectException();
-    printf("srtp_unprotect failed %d\n", status);
-    sz = -1;
-    break;
-  }
 
-  /*
-	if(octets_recvd - RTP_HEADER_LEN > (ssize_t)size) {
-    printf("BufferSize %d\n", status);
-  }
-  */
-
-  // TODO: rtp.fromBuffer(packet, size);
-  //memcpy(buf, prv->receiver->message.body, octets_recvd - RTP_HEADER_LEN);
+  err_status_t status = srtp_unprotect(prv->session, packet, &sz);
+  if(status != err_status_ok) SRTP_THROW(status);
 
   return sz;
 }
 
-// Global SRTP instance reference counter
-static int active_srtp_instances = 0;
+class SRTPInstance {
+public:
+  SRTPInstance();
+  ~SRTPInstance();
+};
 
-SRTP::SRTPInstance::SRTPInstance()
-  _throw(enum lrtp_status_t)
+SRTPInstance::SRTPInstance()
 {
-  err_status_t status;
-
-  if(active_srtp_instances == 0) {
-    status = srtp_init();
-    active_srtp_instances++;
-    if(status != err_status_ok) SRTP_THROW(status);
-  }
+  // printf("SRTP init\n");
+  active_srtp_instance_status = srtp_init();
 }
 
-SRTP::SRTPInstance::~SRTPInstance()
-  _throw(enum lrtp_status_t)
+SRTPInstance::~SRTPInstance()
 {
-  err_status_t status;
+  // printf("SRTP shutdown\n");
 
-  active_srtp_instances--;
-
-  if(active_srtp_instances == 0) {
-    status = srtp_shutdown();
-    if(status != err_status_ok) SRTP_THROW(status);
+  if(active_srtp_instance_status == err_status_ok) {
+    active_srtp_instance_status = srtp_shutdown();
+    //if(status != err_status_ok) Nothing we can do here really...
   }
+
+  // Mark SRTP instance as 'shut down'
+  active_srtp_instance_status = err_status_init_fail;
 }
+
+// One global instance.
+SRTPInstance instance;
diff --git a/src/srtp.h b/src/srtp.h
index 66197bb..00161fd 100644
--- a/src/srtp.h
+++ b/src/srtp.h
@@ -46,14 +46,6 @@ public:
   int decrypt(char *packet, size_t size) _throw(enum lrtp_status_t);
 
 private:
-  class SRTPInstance {
-  public:
-    SRTPInstance() _throw(enum lrtp_status_t);
-    ~SRTPInstance() _throw(enum lrtp_status_t);
-  };
-
-  SRTPInstance instance;
-
   struct prv;
   struct prv *prv;
 
-- 
cgit v1.2.3