/*
 *
 *    Copyright (c) 2020-2021 Project CHIP Authors
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */
#pragma once

#include <lib/core/CHIPError.h>
#include <lib/support/CodeUtils.h>
#include <lib/support/Pool.h>
#include <system/TimeSource.h>
#include <transport/SecureSession.h>

namespace chip {
namespace Transport {

// Session ID reserved for unsecured sessions; the session ID allocator in
// this file always treats it as in use.
constexpr uint16_t kUnsecuredSessionId{ 0 };
// Largest value a 16-bit session ID can take.
constexpr uint16_t kMaxSessionID{ UINT16_MAX };

/**
 * Fixed-capacity table of secure sessions.
 *
 * Intended for:
 *   - handling session active time and expiration
 *   - allocating and freeing space for sessions
 *   - choosing local session IDs that do not collide with sessions already
 *     resident in the table
 *
 * @tparam kMaxSessionCount size parameter of the internal session pool
 */
template <size_t kMaxSessionCount>
class SecureSessionTable
{
public:
    /** Releases every session still held by the internal pool. */
    ~SecureSessionTable() { mEntries.ReleaseAll(); }

    /**
     * Seed the session ID allocation hint with a random 16-bit value.
     *
     * If Init() is never called the hint keeps its default of 0.
     */
    void Init() { mNextSessionId = chip::Crypto::GetRandU16(); }

    /**
     * Allocate a new secure session out of the internal resource pool.
     *
     * @param secureSessionType secure session type
     * @param localSessionId unique identifier for the local node's secure unicast session context
     * @param peerNodeId represents peer Node's ID
     * @param peerCATs represents peer CASE Authenticated Tags
     * @param peerSessionId represents the encryption key ID assigned by peer node
     * @param fabric represents the fabric index for the session
     * @param config represents the reliable message protocol configuration
     *
     * @note this overload does not check whether localSessionId is already in
     *       use by another session in the table.
     * @note the newly created session is expected to have its 'active' time
     *       set from the current time source by the SecureSession constructor.
     *
     * @returns a handle to the allocated session, or an empty Optional if the
     *          pool could not provide a session (e.g. the maximum session
     *          count has been reached).
     */
    CHECK_RETURN_VALUE
    Optional<SessionHandle> CreateNewSecureSession(SecureSession::Type secureSessionType, uint16_t localSessionId,
                                                   NodeId peerNodeId, CATValues peerCATs, uint16_t peerSessionId,
                                                   FabricIndex fabric, const ReliableMessageProtocolConfig & config)
    {
        SecureSession * const created =
            mEntries.CreateObject(secureSessionType, localSessionId, peerNodeId, peerCATs, peerSessionId, fabric, config);
        if (created == nullptr)
        {
            return Optional<SessionHandle>::Missing();
        }
        return MakeOptional<SessionHandle>(*created);
    }

    /**
     * Allocate a new secure session out of the internal resource pool with the
     * specified session ID.  The returned secure session will not become active
     * until the call to SecureSession::Activate.  If the ID is the reserved
     * kUnsecuredSessionId, or there is a resident session at the passed ID, an
     * empty Optional will be returned to signal the error.
     *
     * This variant of the interface is primarily useful in testing, where
     * session IDs may need to be predetermined.
     *
     * @param localSessionId unique identifier for the local node's secure unicast session context
     * @returns allocated session, or NullOptional on failure
     */
    CHECK_RETURN_VALUE
    Optional<SessionHandle> CreateNewSecureSession(uint16_t localSessionId)
    {
        // The unsecured session ID is reserved and can never back a secure session.
        if (localSessionId == kUnsecuredSessionId)
        {
            return Optional<SessionHandle>::Missing();
        }

        // Refuse to create a second session with an ID that is already resident.
        if (FindSecureSessionByLocalKey(localSessionId).HasValue())
        {
            return Optional<SessionHandle>::Missing();
        }

        SecureSession * const created = mEntries.CreateObject(localSessionId);
        if (created == nullptr)
        {
            return Optional<SessionHandle>::Missing();
        }
        return MakeOptional<SessionHandle>(*created);
    }

    /**
     * Allocate a new secure session out of the internal resource pool with a
     * non-colliding session ID and advance mNextSessionId to give a clue to
     * the allocator for the next allocation.  The secure session will not
     * become active until the call to SecureSession::Activate.
     *
     * mNextSessionId is only advanced when the allocation succeeds.
     *
     * @returns allocated session, or NullOptional on failure
     */
    CHECK_RETURN_VALUE
    Optional<SessionHandle> CreateNewSecureSession()
    {
        const Optional<uint16_t> freeId = FindUnusedSessionId();
        if (!freeId.HasValue())
        {
            return Optional<SessionHandle>::Missing();
        }

        const uint16_t localSessionId   = freeId.Value();
        SecureSession * const created = mEntries.CreateObject(localSessionId);
        if (created == nullptr)
        {
            return Optional<SessionHandle>::Missing();
        }

        // Start the next search just past the ID handed out, wrapping from
        // kMaxSessionID to the first ID after the reserved unsecured one.
        mNextSessionId = (localSessionId == kMaxSessionID) ? static_cast<uint16_t>(kUnsecuredSessionId + 1)
                                                           : static_cast<uint16_t>(localSessionId + 1);
        return MakeOptional<SessionHandle>(*created);
    }

    /** Release the given session back to the internal pool. */
    void ReleaseSession(SecureSession * session) { mEntries.ReleaseObject(session); }

    /**
     * Invoke a function on the active sessions of the table.
     *
     * @param function callable forwarded to the pool's ForEachActiveObject; as
     *                 used elsewhere in this class it receives a
     *                 SecureSession * and returns Loop::Continue or Loop::Break
     *
     * @return the Loop value reported by the pool's ForEachActiveObject
     */
    template <typename Function>
    Loop ForEachSession(Function && function)
    {
        return mEntries.ForEachActiveObject(std::forward<Function>(function));
    }

    /**
     * Get a secure session given its session ID.
     *
     * @param localSessionId the identifier of a secure unicast session context within the local node
     *
     * @return the session if found, NullOptional if not found
     */
    CHECK_RETURN_VALUE
    Optional<SessionHandle> FindSecureSessionByLocalKey(uint16_t localSessionId)
    {
        SecureSession * match = nullptr;
        mEntries.ForEachActiveObject([&](auto session) {
            if (session->GetLocalSessionId() != localSessionId)
            {
                return Loop::Continue;
            }
            match = session;
            return Loop::Break; // First hit wins; no need to look further.
        });

        if (match == nullptr)
        {
            return Optional<SessionHandle>::Missing();
        }
        return MakeOptional<SessionHandle>(*match);
    }

    /**
     * Iterates through all active sessions and expires any sessions with an idle time
     * larger than the given amount.  Sessions of type SecureSession::Type::kPending
     * are never expired by this call.
     *
     * Expiring a session involves callback execution and then clearing the internal state.
     *
     * The monotonic clock is sampled once per call, so every session is judged
     * against the same cut-off regardless of how long the callbacks take.
     *
     * @param maxIdleTime idle duration after which a session is expired
     * @param callback    callable invoked with a SecureSession & for each expired
     *                    session, immediately before that session is released
     */
    template <typename Callback>
    void ExpireInactiveSessions(System::Clock::Timestamp maxIdleTime, Callback callback)
    {
        const auto now = System::SystemClock().GetMonotonicTimestamp();

        mEntries.ForEachActiveObject([&](auto session) {
            const bool isPending = (session->GetSecureSessionType() == SecureSession::Type::kPending);
            if (!isPending && session->GetLastActivityTime() + maxIdleTime < now)
            {
                callback(*session);
                ReleaseSession(session);
            }
            return Loop::Continue;
        });
    }

private:
    /**
     * Find an available session ID that is unused in the secure session table.
     *
     * The search algorithm iterates over the session ID space in the outer loop
     * and the session table in the inner loop to locate an available session ID
     * from the starting mNextSessionId clue.
     *
     * The outer-loop considers 64 session IDs in each iteration to give a
     * runtime complexity of O(kMaxSessionCount^2/64).  Speed up could be
     * achieved with a sorted session table or additional storage.
     *
     * kUnsecuredSessionId is always treated as in use and is never returned.
     *
     * @return an unused session ID if any is found, else NullOptional
     */
    CHECK_RETURN_VALUE
    Optional<uint16_t> FindUnusedSessionId()
    {
        // Number of consecutive session IDs tracked by one 64-bit mask.
        constexpr uint16_t kIdsPerBucket = 64;

        for (uint32_t searched = 0; searched <= kMaxSessionID; searched += kIdsPerBucket)
        {
            // bucketBase is the first session ID covered by this bucket.  The
            // addition deliberately wraps modulo 2^16 so the search walks the
            // whole ID space starting from the mNextSessionId clue.
            const uint16_t bucketBase = static_cast<uint16_t>(searched + mNextSessionId);

            // Bit k set means session ID (bucketBase + k) is not available.
            uint64_t inUseMask = 0;

            // Sets the bit for sessionId when it falls inside this bucket.  The
            // 16-bit wrapping subtraction maps IDs outside the bucket to a
            // distance of at least kIdsPerBucket.
            auto markInUse = [&](uint16_t sessionId) {
                const uint16_t distance = static_cast<uint16_t>(sessionId - bucketBase);
                if (distance < kIdsPerBucket)
                {
                    inUseMask |= (1ULL << distance);
                }
            };

            markInUse(kUnsecuredSessionId); // kUnsecuredSessionId is never available
            mEntries.ForEachActiveObject([&](auto session) {
                markInUse(session->GetLocalSessionId());
                // No bits clear means this bucket is full; stop scanning early.
                return (inUseMask == UINT64_MAX) ? Loop::Break : Loop::Continue;
            });

            if (inUseMask == UINT64_MAX)
            {
                continue; // Every ID in this bucket is taken; try the next one.
            }

            // Any bit clear means we have an available ID in this bucket: pick
            // the lowest one.  The loop terminates because at least one bit
            // in the mask is clear.
            uint16_t firstFree = 0;
            while (inUseMask & 1)
            {
                inUseMask >>= 1;
                ++firstFree;
            }
            return MakeOptional<uint16_t>(static_cast<uint16_t>(bucketBase + firstFree));
        }

        return NullOptional;
    }

    BitMapObjectPool<SecureSession, kMaxSessionCount> mEntries;
    // Hint for where FindUnusedSessionId starts searching; randomised by Init().
    uint16_t mNextSessionId = 0;
};

} // namespace Transport
} // namespace chip
