/*
 *  Copyright (c) 2017, Facebook, Inc.
 *  All rights reserved.
 *
 *  This source code is licensed under the BSD-style license found in the
 *  LICENSE file in the root directory of this source tree. An additional grant
 *  of patent rights can be found in the PATENTS file in the same directory.
 *
 */

#include <wangle/concurrent/IOThreadPoolExecutor.h>

#include <glog/logging.h>

#include <folly/detail/MemoryIdler.h>

namespace wangle {

using folly::detail::MemoryIdler;
using folly::AsyncTimeout;
using folly::EventBase;
using folly::EventBaseManager;
using folly::Func;
using folly::RWSpinLock;

/* Class that will free jemalloc caches and madvise the stack away
 * if the event loop is unused for some period of time
 */
class MemoryIdlerTimeout
    : public AsyncTimeout , public EventBase::LoopCallback {
 public:
  explicit MemoryIdlerTimeout(EventBase* b) : AsyncTimeout(b), base_(b) {}

  void timeoutExpired() noexcept override { idled = true; }

  void runLoopCallback() noexcept override {
    if (idled) {
      MemoryIdler::flushLocalMallocCaches();
      MemoryIdler::unmapUnusedStack(MemoryIdler::kDefaultStackToRetain);

      idled = false;
    } else {
      std::chrono::steady_clock::duration idleTimeout =
        MemoryIdler::defaultIdleTimeout.load(
          std::memory_order_acquire);

      idleTimeout = MemoryIdler::getVariationTimeout(idleTimeout);

      scheduleTimeout(std::chrono::duration_cast<std::chrono::milliseconds>(
                        idleTimeout).count());
    }

    // reschedule this callback for the next event loop.
    base_->runBeforeLoop(this);
  }
 private:
  EventBase* base_;
  bool idled{false};
} ;

IOThreadPoolExecutor::IOThreadPoolExecutor(
    size_t numThreads,
    std::shared_ptr<ThreadFactory> threadFactory,
    EventBaseManager* ebm,
    bool waitForAll)
  : ThreadPoolExecutor(numThreads, std::move(threadFactory), waitForAll),
    nextThread_(0),
    eventBaseManager_(ebm) {
  addThreads(numThreads);
  CHECK(threadList_.get().size() == numThreads);
}

// Destructor: calls stop() (not defined in this file) before the object goes
// away.
IOThreadPoolExecutor::~IOThreadPoolExecutor() {
  stop();
}

// Convenience overload: forwards func to the expiration-taking add() with an
// expiration of zero milliseconds. The expire callback is left to whatever
// default the declaration provides (declared outside this file).
void IOThreadPoolExecutor::add(Func func) {
  constexpr std::chrono::milliseconds zeroExpiration{0};
  add(std::move(func), zeroExpiration);
}

// Enqueues func on one of the pool's IO threads.
//
// Holds threadListLock_ for reading while a thread is picked and the work is
// handed to that thread's event base. The chosen thread's pendingTasks counter
// is incremented before the hand-off and decremented once the task has run.
//
// Throws std::runtime_error if the pool has no threads, or if the event base
// refuses the function (in which case the counter is rolled back first).
void IOThreadPoolExecutor::add(
    Func func,
    std::chrono::milliseconds expiration,
    Func expireCallback) {
  RWSpinLock::ReadHolder readGuard{&threadListLock_};
  if (threadList_.get().empty()) {
    throw std::runtime_error("No threads available");
  }
  auto target = pickThread();

  Task queuedTask(std::move(func), expiration, std::move(expireCallback));
  // The closure owns the task and keeps the target thread alive until it runs.
  auto runOnTarget = [target, queuedTask = std::move(queuedTask)]() mutable {
    runTask(target, std::move(queuedTask));
    --target->pendingTasks;
  };

  // Count the task before it can possibly start executing.
  ++target->pendingTasks;
  const bool accepted =
      target->eventBase->runInEventBaseThread(std::move(runOnTarget));
  if (accepted) {
    return;
  }

  // The event base did not take the function: undo the accounting.
  --target->pendingTasks;
  throw std::runtime_error("Unable to run func in event base thread");
}

// Chooses the IO thread that should run the next piece of work.
//
// Preference order:
//   1. the calling thread itself, if it is one of this pool's threads and is
//      still present in threadList_;
//   2. otherwise the next thread in round-robin order (nextThread_).
//
// Throws std::runtime_error("No threads available") when neither applies
// because the thread list is empty; previously that case evaluated
// `x % 0`, which is undefined behavior.
//
// This function reads threadList_ without locking; add() calls it while
// holding threadListLock_ for reading.
std::shared_ptr<IOThreadPoolExecutor::IOThread>
IOThreadPoolExecutor::pickThread() {
  auto& me = *thisThread_;
  auto& ths = threadList_.get();
  // When new task is added to IOThreadPoolExecutor, a thread is chosen for it
  // to be executed on, thisThread_ is by default chosen, however, if the new
  // task is added by the clean up operations on thread destruction, thisThread_
  // is not an available thread anymore, thus, always check whether or not
  // thisThread_ is an available thread before choosing it.
  if (me && std::find(ths.cbegin(), ths.cend(), me) != ths.cend()) {
    return me;
  }
  // Guard the modulo below: with no threads there is nothing to pick, so fail
  // the same way add() does for an empty pool.
  if (ths.empty()) {
    throw std::runtime_error("No threads available");
  }
  return std::static_pointer_cast<IOThread>(ths[nextThread_++ % ths.size()]);
}

// Returns the EventBase of a thread chosen by pickThread() (the calling
// thread if it belongs to this pool, otherwise the next one round-robin).
//
// pickThread() iterates threadList_, which stopThreads() modifies while
// threadListLock_ is write-locked, so take the read lock here exactly as
// add() does before calling pickThread().
//
// NOTE: must not be called by a thread that already holds threadListLock_
// for writing.
EventBase* IOThreadPoolExecutor::getEventBase() {
  RWSpinLock::ReadHolder r{&threadListLock_};
  return pickThread()->eventBase;
}

// Returns the EventBase stored in the given thread handle, or nullptr when
// the handle is null or is not actually an IOThread (checked via
// dynamic_cast).
EventBase* IOThreadPoolExecutor::getEventBase(
    ThreadPoolExecutor::ThreadHandle* h) {
  if (auto* ioThread = dynamic_cast<IOThread*>(h)) {
    return ioThread->eventBase;
  }
  return nullptr;
}

// Returns the EventBaseManager pointer that was passed to the constructor.
EventBaseManager* IOThreadPoolExecutor::getEventBaseManager() {
  return eventBaseManager_;
}

// Creates a new IOThread constructed with a pointer to this pool and returns
// it through the base-class Thread pointer type.
std::shared_ptr<ThreadPoolExecutor::Thread>
IOThreadPoolExecutor::makeThread() {
  std::shared_ptr<ThreadPoolExecutor::Thread> created =
      std::make_shared<IOThread>(this);
  return created;
}

// Body executed by each IO thread.
//
// Sequence: attach the thread to an EventBase obtained from
// eventBaseManager_, install the memory idler, run the event loop until the
// thread is told to stop, optionally drain remaining work, then detach from
// the EventBase. The order of the shutdown steps below is significant.
void IOThreadPoolExecutor::threadRun(ThreadPtr thread) {
  const auto ioThread = std::static_pointer_cast<IOThread>(thread);
  ioThread->eventBase = eventBaseManager_->getEventBase();
  // Remember which pool thread we are so pickThread() can prefer it.
  thisThread_.reset(new std::shared_ptr<IOThread>(ioThread));

  // Register the idle-memory reclaimer as a before-loop callback; it keeps
  // re-registering itself for as long as it is alive.
  auto idler = std::make_unique<MemoryIdlerTimeout>(ioThread->eventBase);
  ioThread->eventBase->runBeforeLoop(idler.get());

  // Post startupBaton from inside the event base, i.e. only once the loop is
  // actually processing functions.
  ioThread->eventBase->runInEventBaseThread(
      [thread]{ thread->startupBaton.post(); });
  // stopThreads() clears shouldRun and then calls terminateLoopSoon(); if
  // loopForever() returns while shouldRun is still set, simply re-enter it.
  while (ioThread->shouldRun) {
    ioThread->eventBase->loopForever();
  }
  if (isJoin_) {
    // Keep iterating the loop until every task counted by add() has run.
    while (ioThread->pendingTasks > 0) {
      ioThread->eventBase->loopOnce();
    }
  }
  // Destroy the idler before the final loop() below.
  // NOTE(review): presumably so its self-rescheduling callback/timer does not
  // take part in that last loop run - confirm.
  idler.reset();
  if (isWaitForAll_) {
    // some tasks, like thrift asynchronous calls, create additional
    // event base hookups, let's wait till all of them complete.
    ioThread->eventBase->loop();
  }

  // stopThreads() takes the same mutex before it touches ioThread->eventBase,
  // so clearing the pointer here cannot interleave with its
  // terminateLoopSoon() call.
  std::lock_guard<std::mutex> guard(ioThread->eventBaseShutdownMutex_);
  ioThread->eventBase = nullptr;
  eventBaseManager_->clearEventBase();
}

// threadListLock_ is writelocked
void IOThreadPoolExecutor::stopThreads(size_t n) {
  std::vector<ThreadPtr> stoppedThreads;
  stoppedThreads.reserve(n);
  for (size_t i = 0; i < n; i++) {
    const auto ioThread = std::static_pointer_cast<IOThread>(
        threadList_.get()[i]);
    for (auto& o : observers_) {
      o->threadStopped(ioThread.get());
    }
    ioThread->shouldRun = false;
    stoppedThreads.push_back(ioThread);
    std::lock_guard<std::mutex> guard(ioThread->eventBaseShutdownMutex_);
    if (ioThread->eventBase) {
      ioThread->eventBase->terminateLoopSoon();
    }
  }
  for (auto thread : stoppedThreads) {
    stoppedThreads_.add(thread);
    threadList_.remove(thread);
  }
}

// threadListLock_ is readlocked
//
// Sums pendingTasks over every thread in threadList_. A thread that has
// pending work and is not idle contributes one less than its counter: add()
// decrements pendingTasks only after a task has finished running, so a task
// that is currently executing is still included in the counter.
// NOTE(review): this presumes `idle` is false exactly while a task is
// executing - confirm against the Thread definition.
uint64_t IOThreadPoolExecutor::getPendingTaskCount() {
  uint64_t total = 0;
  for (const auto& entry : threadList_.get()) {
    const auto ioThread = std::static_pointer_cast<IOThread>(entry);
    const size_t queued = ioThread->pendingTasks;
    const bool discountRunning = queued > 0 && !ioThread->idle;
    total += discountRunning ? queued - 1 : queued;
  }
  return total;
}

} // namespace wangle
