토큰 데이터를 앱에 저장할 때 기존에는 UserDefaults로 구현되어 있었는데, 보안이 취약하고 앱 외부에서 변경이 가능하기 때문에 개선이 필요했습니다.

KeychainSwift 라이브러리를 참고해서 Keychain에 토큰을 set, get, delete, clear 할 수 있는 함수를 구현했습니다.

set할 때는 TokenType을 열거형으로 만들어서 accessToken, refreshToken 케이스를 나누어서 저장했습니다.

또한, NSLock을 통해서 동시 접근 시 데이터 동기화가 제대로 되지 않을 수 있는 문제를 해결할 수 있었습니다.

구현 코드

//  Keychain.swift
import Security
import Foundation

/// Identifies which token a keychain entry holds.
///
/// The raw value of each case is its own name and is what gets used as the
/// keychain account attribute.
enum TokenType: String {
  case accessToken, refreshToken
}

/// Lock-protected wrapper around generic-password keychain items, keyed by `TokenType`.
///
/// Every Security call and every write to `lastResultCode` happens while `lock`
/// is held, so the shared instance can be used from several threads at once.
final class Keychain {
  /// Prevents concurrent access to this object's members in a multi-threaded environment.
  private let lock: NSLock = NSLock()
  /// Status of the most recent keychain operation (or the encoding failure recorded by `get(_:)`).
  private var lastResultCode: OSStatus = noErr
  /// Numeric value of `errSecInvalidEncoding`, recorded when stored bytes are not valid UTF-8.
  private static let invalidEncodingStatus: OSStatus = -67853

  static let shared = Keychain()

  private init() { }

  /// Stores `value` as UTF-8 under `key`, replacing any existing entry.
  ///
  /// - Returns: `true` when the item was added; `false` when encoding or the keychain call failed.
  @discardableResult
  func set(_ value: String, forKey key: TokenType) -> Bool {
    guard let encoded = value.data(using: .utf8) else { return false }
    return set(encoded, forKey: key)
  }

  /// Stores raw `value` bytes under `key`, replacing any existing entry.
  ///
  /// - Returns: `true` when `SecItemAdd` reported `noErr`.
  @discardableResult
  func set(_ value: Data, forKey key: TokenType) -> Bool {
    lock.lock()
    defer { lock.unlock() }

    let keyChainQuery: NSDictionary = [
      kSecClass: kSecClassGenericPassword,
      kSecAttrAccount: key.rawValue,
      kSecValueData: value
    ]

    // Remove any previous entry first; its result is overwritten by the add below.
    deleteNolock(key)
    lastResultCode = SecItemAdd(keyChainQuery as CFDictionary, nil)

    return lastResultCode == noErr
  }

  /// Reads the entry for `key` and decodes it as UTF-8.
  ///
  /// - Returns: The stored string, or `nil` when nothing is stored, the lookup
  ///   failed, or the stored bytes are not valid UTF-8.
  func get(_ key: TokenType) -> String? {
    guard let data = getData(key) else { return nil }

    if let currentString = String(data: data, encoding: .utf8) {
      return currentString
    }

    // `getData` has already released the lock (NSLock is not recursive), so take
    // it again here: writing `lastResultCode` unguarded would race with other threads.
    lock.lock()
    defer { lock.unlock() }
    lastResultCode = Keychain.invalidEncodingStatus

    return nil
  }

  /// Reads the raw bytes stored for `key`.
  ///
  /// - Returns: The stored data, or `nil` when the lookup did not report `noErr`.
  func getData(_ key: TokenType) -> Data? {
    lock.lock()
    defer { lock.unlock() }

    let query: NSDictionary = [
      kSecClass: kSecClassGenericPassword,
      kSecAttrAccount: key.rawValue,
      kSecReturnData: kCFBooleanTrue as Any,
      kSecMatchLimit: kSecMatchLimitOne
    ]

    var result: AnyObject?
    lastResultCode = SecItemCopyMatching(query as CFDictionary, &result)

    guard lastResultCode == noErr else { return nil }
    return result as? Data
  }

  /// Deletes the entry for `key`.
  ///
  /// - Returns: `true` when `SecItemDelete` reported `noErr`.
  @discardableResult
  func delete(_ key: TokenType) -> Bool {
    lock.lock()
    defer { lock.unlock() }

    return deleteNolock(key)
  }

  /// Deletes the entry for `key` without taking `lock`.
  ///
  /// The caller must already hold `lock`; prefer `delete(_:)` everywhere else.
  ///
  /// - Returns: `true` when `SecItemDelete` reported `noErr`.
  @discardableResult
  func deleteNolock(_ key: TokenType) -> Bool {
    let keyChainQuery: NSDictionary = [
      kSecClass: kSecClassGenericPassword,
      kSecAttrAccount: key.rawValue
    ]

    lastResultCode = SecItemDelete(keyChainQuery as CFDictionary)
    return lastResultCode == noErr
  }

  /// Deletes every generic-password item that matches a class-only query,
  /// not just the entries for `TokenType` keys.
  ///
  /// - Returns: `true` when `SecItemDelete` reported `noErr`.
  @discardableResult
  func clear() -> Bool {
    lock.lock()
    defer { lock.unlock() }

    let keyChainQuery: NSDictionary = [
      kSecClass: kSecClassGenericPassword
    ]

    lastResultCode = SecItemDelete(keyChainQuery as CFDictionary)
    return lastResultCode == noErr
  }
}

Test Code

//  KeychainTest.swift
import XCTest
@testable import Falling

/// Exercises the shared `Keychain`: set/get/delete/clear round trips plus a
/// multi-queue stress run.
class KeychainTest: XCTestCase {

  var object: Keychain!

  override func setUp() {
    super.setUp()

    object = Keychain.shared
    // Start each test from a cleared keychain.
    object.clear()
  }

  // MARK: - Set
  func testSet() {
    XCTAssertTrue(object.set("Hello", forKey: .accessToken))
    // Compare against the optional directly: a missing value fails the test instead of crashing it.
    XCTAssertEqual(object.get(.accessToken), "Hello")
  }

  // MARK: - Get
  func testGet_returnNilWhenValueNotSet() {
    XCTAssertNil(object.get(.accessToken))
  }

  // MARK: - Delete
  func testDelete() {
    object.set("Hello", forKey: .accessToken)
    object.delete(.accessToken)

    XCTAssertNil(object.get(.accessToken))
  }

  func testDelete_deleteOnSingleKey() {
    object.set("Hello", forKey: .accessToken)
    object.set("Hello!!", forKey: .refreshToken)

    object.delete(.accessToken)

    XCTAssertEqual(object.get(.refreshToken), "Hello!!")
  }

  // MARK: - Clear
  func testClear() {
    object.set("Hello", forKey: .accessToken)
    object.set("Hello!!", forKey: .refreshToken)

    object.clear()

    XCTAssertNil(object.get(.accessToken))
    XCTAssertNil(object.get(.refreshToken))
  }

  // MARK: - Concurrency
  func testConcurrencyDoesntCrash() {
    let dataToWrite = "{ asdf ñlk BNALSKDJFÑLAKSJDFÑLKJ ZÑCLXKJ ÑALSKDFJÑLKASJDFÑLKJASDÑFLKJAÑSDLKFJÑLKJ}"
    object.set(dataToWrite, forKey: .accessToken)

    // Both writer queues bump this counter, so it is guarded by its own lock.
    let writesLock = NSLock()
    var writes = 0
    let recordWrite = {
      writesLock.lock()
      defer { writesLock.unlock() }
      writes += 1
    }

    for label in ["ReadQueue", "ReadQueue2", "ReadQueue3"] {
      hammer(queueLabel: label, timeoutWith: nil) { self.object.get(.accessToken) }
    }
    for label in ["deleteQueue", "deleteQueue2"] {
      hammer(queueLabel: label, timeoutWith: false) { self.object.delete(.accessToken) }
    }
    for label in ["clearQueue", "clearQueue2"] {
      hammer(queueLabel: label, timeoutWith: false) { self.object.clear() }
    }

    let writers = [
      ("WriteQueue", expectation(description: "Wait for write loop")),
      ("WriteQueue2", expectation(description: "Wait for write loop"))
    ]
    for (label, finished) in writers {
      startWriter(queueLabel: label, payload: dataToWrite, finished: finished, recordWrite: recordWrite)
    }

    // Add traffic from the test's own thread while the background queues run.
    for _ in 0..<1000 {
      object.set(dataToWrite, forKey: .accessToken)
      _ = object.get(.accessToken)
    }
    waitForExpectations(timeout: 30, handler: nil)

    writesLock.lock()
    let totalWrites = writes
    writesLock.unlock()

    // Two writers, 500 writes each.
    XCTAssertEqual(1000, totalWrites)
  }

  // MARK: - Helpers

  /// Runs `operation` 400 times on a new serial queue; each iteration blocks
  /// until the main queue hands the result back about 5 ms later.
  private func hammer<Value>(queueLabel: String,
                             timeoutWith fallback: Value,
                             operation: @escaping () -> Value) {
    DispatchQueue(label: queueLabel, attributes: []).async {
      for _ in 0..<400 {
        let _: Value = synchronize({ completion in
          // The keychain call runs here, on the serial queue's thread.
          let result = operation()
          DispatchQueue.global(qos: .background).async {
            DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
              completion(result)
            }
          }
        }, timeoutWith: fallback)
      }
    }
  }

  /// Performs 500 writes of `payload`, each executed on the main queue, calling
  /// `recordWrite` for every successful one and fulfilling `finished` at the end.
  private func startWriter(queueLabel: String,
                           payload: String,
                           finished: XCTestExpectation,
                           recordWrite: @escaping () -> Void) {
    DispatchQueue(label: queueLabel, attributes: []).async {
      for _ in 0..<500 {
        let written: Bool = synchronize({ completion in
          DispatchQueue.global(qos: .background).async {
            DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(5)) {
              // Unlike `hammer`, the keychain call itself happens on the main queue.
              let result = self.object.set(payload, forKey: .accessToken)
              completion(result)
            }
          }
        }, timeoutWith: false)
        if written {
          recordWrite()
        }
      }
      finished.fulfill()
    }
  }
}

/// Turns an asynchronous, completion-based closure into a blocking call.
///
/// Ref: <https://forums.developer.apple.com/thread/11519>
///
/// - Parameters:
///   - asynchClosure: Work that must eventually invoke the completion handler it is given.
///   - timeout: Deadline for the wait; defaults to waiting indefinitely.
///   - timeoutWith: Fallback value, evaluated only when no result was delivered.
/// - Returns: The value passed to the completion handler, or the fallback.
func synchronize<ResultType>(_ asynchClosure: (_ completion: @escaping (ResultType) -> ()) -> Void,
                             timeout: DispatchTime = DispatchTime.distantFuture,
                             timeoutWith: @autoclosure @escaping () -> ResultType) -> ResultType {
  let done = DispatchSemaphore(value: 0)
  var delivered: ResultType?

  asynchClosure { value in
    delivered = value
    done.signal()
  }

  // The wait outcome is deliberately ignored; only the presence of a value decides the result.
  _ = done.wait(timeout: timeout)

  guard let value = delivered else { return timeoutWith() }
  return value
}

평소에 구현할 때 보안에 대해서 크게 생각하지 않았었는데, 이번 기회를 통해서 클라이언트에서도 보안이 중요하다는 것을 알게 되었습니다. 또한, UserDefaults와 Keychain의 차이를 알게 되었습니다.