View Javadoc
1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *  
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *  
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License. 
18   *  
19   */
20  package org.apache.directory.server.kerberos.shared.crypto.encryption;
21  
22  
23  import java.security.GeneralSecurityException;
24  import java.security.MessageDigest;
25  import java.security.spec.AlgorithmParameterSpec;
26  
27  import javax.crypto.Cipher;
28  import javax.crypto.Mac;
29  import javax.crypto.SecretKey;
30  import javax.crypto.spec.IvParameterSpec;
31  import javax.crypto.spec.SecretKeySpec;
32  
33  import org.apache.directory.api.util.Strings;
34  import org.apache.directory.server.kerberos.shared.crypto.checksum.ChecksumEngine;
35  import org.apache.directory.shared.kerberos.codec.types.EncryptionType;
36  import org.apache.directory.shared.kerberos.components.EncryptedData;
37  import org.apache.directory.shared.kerberos.components.EncryptionKey;
38  import org.apache.directory.shared.kerberos.crypto.checksum.ChecksumType;
39  import org.apache.directory.shared.kerberos.exceptions.ErrorType;
40  import org.apache.directory.shared.kerberos.exceptions.KerberosException;
41  
42  
43  /**
44   * @author <a href="mailto:dev@directory.apache.org">Apache Directory Project</a>
45   */
46  public class Des3CbcSha1KdEncryption extends EncryptionEngine implements ChecksumEngine
47  {
48      private static final byte[] iv = new byte[]
49          { ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00, ( byte ) 0x00,
50              ( byte ) 0x00 };
51  
52  
53      public EncryptionType getEncryptionType()
54      {
55          return EncryptionType.DES3_CBC_SHA1_KD;
56      }
57  
58  
59      public int getConfounderLength()
60      {
61          return 8;
62      }
63  
64  
65      public int getChecksumLength()
66      {
67          return 20;
68      }
69  
70  
71      public ChecksumType checksumType()
72      {
73          return ChecksumType.HMAC_SHA1_DES3_KD;
74      }
75  
76  
77      public byte[] calculateChecksum( byte[] data, byte[] key, KeyUsage usage )
78      {
79          byte[] kc = deriveKey( key, getUsageKc( usage ), 64, 168 );
80  
81          return processChecksum( data, kc );
82      }
83  
84  
85      public byte[] calculateIntegrity( byte[] data, byte[] key, KeyUsage usage )
86      {
87          byte[] ki = deriveKey( key, getUsageKi( usage ), 64, 168 );
88  
89          return processChecksum( data, ki );
90      }
91  
92  
93      public byte[] getDecryptedData( EncryptionKey key, EncryptedData data, KeyUsage usage ) throws KerberosException
94      {
95          byte[] ke = deriveKey( key.getKeyValue(), getUsageKe( usage ), 64, 168 );
96  
97          byte[] encryptedData = data.getCipher();
98  
99          // extract the old checksum
100         byte[] oldChecksum = new byte[getChecksumLength()];
101         System
102             .arraycopy( encryptedData, encryptedData.length - getChecksumLength(), oldChecksum, 0, oldChecksum.length );
103 
104         // remove trailing checksum
105         encryptedData = removeTrailingBytes( encryptedData, 0, getChecksumLength() );
106 
107         // decrypt the data
108         byte[] decryptedData = decrypt( encryptedData, ke );
109 
110         // remove leading confounder
111         byte[] withoutConfounder = removeLeadingBytes( decryptedData, getConfounderLength(), 0 );
112 
113         // calculate a new checksum
114         byte[] newChecksum = calculateIntegrity( decryptedData, key.getKeyValue(), usage );
115 
116         // compare checksums
117         if ( !MessageDigest.isEqual( oldChecksum, newChecksum ) )
118         {
119             throw new KerberosException( ErrorType.KRB_AP_ERR_BAD_INTEGRITY );
120         }
121 
122         return withoutConfounder;
123     }
124 
125 
126     public EncryptedData getEncryptedData( EncryptionKey key, byte[] plainText, KeyUsage usage )
127     {
128         byte[] ke = deriveKey( key.getKeyValue(), getUsageKe( usage ), 64, 168 );
129 
130         // build the ciphertext structure
131         byte[] conFounder = getRandomBytes( getConfounderLength() );
132         byte[] paddedPlainText = padString( plainText );
133         byte[] dataBytes = concatenateBytes( conFounder, paddedPlainText );
134         byte[] checksumBytes = calculateIntegrity( dataBytes, key.getKeyValue(), usage );
135         byte[] encryptedData = encrypt( dataBytes, ke );
136         byte[] cipherText = concatenateBytes( encryptedData, checksumBytes );
137 
138         return new EncryptedData( getEncryptionType(), key.getKeyVersion(), cipherText );
139     }
140 
141 
142     public byte[] encrypt( byte[] plainText, byte[] keyBytes )
143     {
144         return processCipher( true, plainText, keyBytes );
145     }
146 
147 
148     public byte[] decrypt( byte[] cipherText, byte[] keyBytes )
149     {
150         return processCipher( false, cipherText, keyBytes );
151     }
152 
153 
154     /**
155      * Derived Key = DK(Base Key, Well-Known Constant)
156      * DK(Key, Constant) = random-to-key(DR(Key, Constant))
157      * DR(Key, Constant) = k-truncate(E(Key, Constant, initial-cipher-state))
158      * 
159      * @param baseKey The base key to derive
160      * @param usage The key usage
161      * @param n The number of resulting bytes
162      * @param k The number of bytes
163      * @return The derived key
164      */
165     protected byte[] deriveKey( byte[] baseKey, byte[] usage, int n, int k )
166     {
167         byte[] result = deriveRandom( baseKey, usage, n, k );
168         result = randomToKey( result );
169 
170         return result;
171     }
172 
173 
174     protected byte[] randomToKey( byte[] seed )
175     {
176         int kBytes = 24;
177         byte[] result = new byte[kBytes];
178 
179         byte[] fillingKey = Strings.EMPTY_BYTES;
180 
181         int pos = 0;
182 
183         for ( int i = 0; i < kBytes; i++ )
184         {
185             if ( pos < fillingKey.length )
186             {
187                 result[i] = fillingKey[pos];
188                 pos++;
189             }
190             else
191             {
192                 fillingKey = getBitGroup( seed, i / 8 );
193                 fillingKey = setParity( fillingKey );
194                 pos = 0;
195                 result[i] = fillingKey[pos];
196                 pos++;
197             }
198         }
199 
200         return result;
201     }
202 
203 
204     protected byte[] getBitGroup( byte[] seed, int group )
205     {
206         int srcPos = group * 7;
207 
208         byte[] result = new byte[7];
209 
210         System.arraycopy( seed, srcPos, result, 0, 7 );
211 
212         return result;
213     }
214 
215 
216     protected byte[] setParity( byte[] in )
217     {
218         byte[] expandedIn = new byte[8];
219 
220         System.arraycopy( in, 0, expandedIn, 0, in.length );
221 
222         setBit( expandedIn, 62, getBit( in, 7 ) );
223         setBit( expandedIn, 61, getBit( in, 15 ) );
224         setBit( expandedIn, 60, getBit( in, 23 ) );
225         setBit( expandedIn, 59, getBit( in, 31 ) );
226         setBit( expandedIn, 58, getBit( in, 39 ) );
227         setBit( expandedIn, 57, getBit( in, 47 ) );
228         setBit( expandedIn, 56, getBit( in, 55 ) );
229 
230         byte[] out = new byte[8];
231 
232         int bitCount = 0;
233         int index = 0;
234 
235         for ( int i = 0; i < 64; i++ )
236         {
237             if ( ( i + 1 ) % 8 == 0 )
238             {
239                 if ( bitCount % 2 == 0 )
240                 {
241                     setBit( out, i, 1 );
242                 }
243 
244                 index++;
245                 bitCount = 0;
246             }
247             else
248             {
249                 int val = getBit( expandedIn, index );
250                 boolean bit = val > 0;
251 
252                 if ( bit )
253                 {
254                     setBit( out, i, val );
255                     bitCount++;
256                 }
257 
258                 index++;
259             }
260         }
261 
262         return out;
263     }
264 
265 
266     private byte[] processCipher( boolean isEncrypt, byte[] data, byte[] keyBytes )
267     {
268         try
269         {
270             Cipher cipher = Cipher.getInstance( "DESede/CBC/NoPadding" );
271             SecretKey key = new SecretKeySpec( keyBytes, "DESede" );
272 
273             AlgorithmParameterSpec paramSpec = new IvParameterSpec( iv );
274 
275             if ( isEncrypt )
276             {
277                 cipher.init( Cipher.ENCRYPT_MODE, key, paramSpec );
278             }
279             else
280             {
281                 cipher.init( Cipher.DECRYPT_MODE, key, paramSpec );
282             }
283 
284             return cipher.doFinal( data );
285         }
286         catch ( GeneralSecurityException nsae )
287         {
288             nsae.printStackTrace();
289             return null;
290         }
291     }
292 
293 
294     private byte[] processChecksum( byte[] data, byte[] key )
295     {
296         try
297         {
298             SecretKey sk = new SecretKeySpec( key, "DESede" );
299 
300             Mac mac = Mac.getInstance( "HmacSHA1" );
301             mac.init( sk );
302 
303             return mac.doFinal( data );
304         }
305         catch ( GeneralSecurityException nsae )
306         {
307             nsae.printStackTrace();
308             return null;
309         }
310     }
311 }