From 56cfdec3473d4702dfb5708c290aa7ae7e881d68 Mon Sep 17 00:00:00 2001
From: Petteri Aimonen <jpa@git.mail.kapsi.fi>
Date: Fri, 15 Sep 2017 07:39:29 +0300
Subject: [PATCH] Make pb_decode_varint32 overflow checks exact (issue #258)

---
 pb_decode.c                               | 23 +++++++++++++++++++----
 tests/decode_unittests/decode_unittests.c |  3 +++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/pb_decode.c b/pb_decode.c
index b4563cf..d4f71e2 100644
--- a/pb_decode.c
+++ b/pb_decode.c
@@ -191,15 +191,30 @@ bool checkreturn pb_decode_varint32(pb_istream_t *stream, uint32_t *dest)
         
         do
         {
-            if (bitpos >= 32)
-                PB_RETURN_ERROR(stream, "varint overflow");
-            
             if (!pb_readbyte(stream, &byte))
                 return false;
             
-            result |= (uint32_t)(byte & 0x7F) << bitpos;
+            if (bitpos >= 32)
+            {
+                /* Note: Technically, the varint could have trailing 0x80 bytes, even
+                 * though I haven't seen any implementation do that yet. */
+                if ((byte & 0x7F) != 0)
+                {
+                    PB_RETURN_ERROR(stream, "varint overflow");
+                }
+            }
+            else
+            {
+                result |= (uint32_t)(byte & 0x7F) << bitpos;
+            }
             bitpos = (uint_fast8_t)(bitpos + 7);
         } while (byte & 0x80);
+        
+        if (bitpos >= 32 && (byte & 0x70) != 0)
+        {
+            /* The last byte was at bitpos=28, so only bottom 4 bits fit. */
+            PB_RETURN_ERROR(stream, "varint overflow");
+        }
    }
    
    *dest = result;
diff --git a/tests/decode_unittests/decode_unittests.c b/tests/decode_unittests/decode_unittests.c
index a6f5c17..e791783 100644
--- a/tests/decode_unittests/decode_unittests.c
+++ b/tests/decode_unittests/decode_unittests.c
@@ -100,6 +100,9 @@ int main()
         TEST((s = S("\x01"), pb_decode_varint32(&s, &u) && u == 1));
         TEST((s = S("\xAC\x02"), pb_decode_varint32(&s, &u) && u == 300));
         TEST((s = S("\xFF\xFF\xFF\xFF\x0F"), pb_decode_varint32(&s, &u) && u == UINT32_MAX));
+        TEST((s = S("\xFF\xFF\xFF\xFF\x8F\x00"), pb_decode_varint32(&s, &u) && u == UINT32_MAX));
+        TEST((s = S("\xFF\xFF\xFF\xFF\x10"), !pb_decode_varint32(&s, &u)));
+        TEST((s = S("\xFF\xFF\xFF\xFF\x40"), !pb_decode_varint32(&s, &u)));
         TEST((s = S("\xFF\xFF\xFF\xFF\xFF\x01"), !pb_decode_varint32(&s, &u)));
     }
     
-- 
GitLab