//===---- LowerInvokeSimd.cpp - lower invoke_simd calls -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Finds and lowers __builtin_invoke_simd calls generated by invoke_simd library
// implementation:
// - Performs data flow analysis for the first argument to determine which
//   function link-time constant address it is guaranteed to hold, and replaces
//   the argument with found "target" function.
// - Marks target functions with VCStackCall attribute as required by the Intel
//   GPU backend.
// TODO:
// - move VCStackCall markup to Intel GPU-specific part (BE) or design a new
//   target-neutral attribute for markup.
// - allow "unknown" function pointers, where actual function's address is not
//   deducible when BE is ready

#include "llvm/SYCLLowerIR/LowerInvokeSimd.h"

#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
#include "llvm/SYCLLowerIR/SYCLUtils.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/GenXIntrinsics/GenXMetadata.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#define DEBUG_TYPE "LowerInvokeSimd"

using namespace llvm;
using namespace llvm::sycl::utils;

namespace {

constexpr char REQD_SUB_GROUP_SIZE_MD[] = "intel_reqd_sub_group_size";

class SYCLLowerInvokeSimdLegacyPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  SYCLLowerInvokeSimdLegacyPass() : ModulePass(ID) {
    initializeSYCLLowerInvokeSimdLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  // run the LowerESIMD pass on the specified module
  bool runOnModule(Module &M) override {
    ModuleAnalysisManager MAM;
    auto PA = Impl.run(M, MAM);
    return !PA.areAllPreserved();
  }

private:
  SYCLLowerInvokeSimdPass Impl;
};
} // namespace

// Definition of the pass identification member declared in
// SYCLLowerInvokeSimdLegacyPass.
char SYCLLowerInvokeSimdLegacyPass::ID = 0;
// Registers the legacy pass under the "SYCLLowerInvokeSimd" name with the
// description "Lower SYCL's invoke_simd calls".
// NOTE(review): the two trailing 'false' values are presumably the macro's
// "CFG only" and "is analysis" flags - confirm against the INITIALIZE_PASS
// definition.
INITIALIZE_PASS(SYCLLowerInvokeSimdLegacyPass, "SYCLLowerInvokeSimd",
                "Lower SYCL's invoke_simd calls", false, false)

// Public factory for the legacy pass manager flavour of the invoke_simd
// lowering. Returns a newly allocated SYCLLowerInvokeSimdLegacyPass as a raw
// ModulePass pointer.
ModulePass *llvm::createSYCLLowerInvokeSimdPass() {
  ModulePass *LegacyPass = new SYCLLowerInvokeSimdLegacyPass();
  return LegacyPass;
}

namespace {
// TODO support lambda and functor overloads

// Shorthands for small pointer sets of IR values. The '*Impl' aliases name the
// SmallPtrSetImpl form (no size parameter), the others name SmallPtrSet with
// a size parameter of 4. The 'Const' variants hold pointers to const Value.
using ValueSetImpl = SmallPtrSetImpl<Value *>;
using ValueSet = SmallPtrSet<Value *, 4>;
using ConstValueSetImpl = SmallPtrSetImpl<const Value *>;
using ConstValueSet = SmallPtrSet<const Value *, 4>;

// Expects a call instruction in the form
// __builtin_invoke_simd(simd_func_call_helper, f,...);
// and returns <simd_func_call_helper, f> pair or
// <nullptr, nullptr> otherwise.
std::pair<Value *, Value *>
getHelperAndInvokeeIfInvokeSimdCall(const CallInst *CI) {
  Function *F = CI->getCalledFunction();

  if (F && F->getName().starts_with(esimd::INVOKE_SIMD_PREF)) {
    return {CI->getArgOperand(0), CI->getArgOperand(1)};
  }
  return {nullptr, nullptr};
}

// Deduce a single function whose address this value can only contain and
// return it, otherwise (if can't be deduced or multiple functions deduced)
// return nullptr.
//   I       - the value to trace back to a function; casts are stripped on
//             every step.
//   Visited - functions whose formal parameters have already been followed.
//             Shared between the recursive invocations to stop on call chains
//             which lead back to an already analyzed function.
// Supported data flow:
//   - the value is a function (possibly behind casts);
//   - the value is a formal parameter: every user of the parent function must
//     be a direct call to it, and all the calls must pass the same function;
//   - the value is loaded from an address for which exactly one possible
//     stored value can be collected.
Function *deduceFunction(Value *I, SmallPtrSetImpl<const Function *> &Visited) {
  // Loads already followed by this invocation. Reaching the same load twice
  // means the value cycles through memory (e.g. it is loaded from and stored
  // back to the same address), so the search would never terminate.
  SmallPtrSet<const Value *, 8> FollowedLoads;

  while (true) {
    I = stripCasts(I);

    if (auto *Known = dyn_cast<Function>(I)) {
      return Known;
    }
    if (Argument *Arg = dyn_cast<Argument>(I)) {
      // invoke_simd target is a function pointer which came via a formal
      // parameter - do inter-procedural analysis trying to determine
      // actual target
      Function *Parent = Arg->getParent();

      if (!Visited.insert(Parent).second) {
        // already visited - recursion detected
        return nullptr;
      }
      Function *Res = nullptr;

      // follow all calls to Parent and see what functions were passed as
      // actual argument
      for (const User *U : Parent->users()) {
        const auto *CI = dyn_cast<CallInst>(U);

        if (!CI || (CI->getCalledFunction() != Parent)) {
          // Parent is used by something other than a direct call to it, so
          // the set of actual arguments can't be enumerated. This depends on
          // the input IR, hence report a failed deduction rather than
          // treating it as unreachable code.
          return nullptr;
        }
        Value *ActualArg = CI->getArgOperand(Arg->getArgNo());
        Function *F = deduceFunction(ActualArg, Visited);

        if (!F || (Res && (Res != F))) {
          // deduction failed or a different (from previous iteration) function
          // deduced
          return nullptr;
        }
        Res = F;
      }
      // nullptr here if Parent has no users at all
      return Res;
    }
    auto *LI = dyn_cast<LoadInst>(I);

    if (!LI) {
      // Function object can't be deduced
      break;
    }
    if (!FollowedLoads.insert(LI).second) {
      // cyclic data flow through memory
      break;
    }
    ValueSet Vals;
    Value *Addr = LI->getPointerOperand();

    auto PtrEscapes = [](const CallInst *CI) {
      // only __builtin_invoke_simd is allowed, otherwise the pointer escapes
      return getHelperAndInvokeeIfInvokeSimdCall(CI).first == nullptr;
    };
    if (!sycl::utils::collectPossibleStoredVals(Addr, Vals, PtrEscapes) ||
        Vals.size() != 1) {
      // data flow through the address of the load instruction is too
      // complicated
      break;
    }
    // continue the search from the only value which could have been stored
    I = *Vals.begin();
  }
  return nullptr;
}

// Collects into Uses the uses of V, looking through casts and through memory:
// - a use which is not a store is recorded as is;
// - for a store of V, the uses (looking through casts) of all the values
//   loaded from the store address are recorded instead.
// Returns false if a store is found whose data flow is not supported:
// - the store address (casts stripped) is not an alloca;
// - possible stored values can't be collected for the address, or V is not
//   the only one;
// - the address has a user other than this store or a load.
// Uses may already have been partially filled when false is returned.
bool collectUsesLookTrhoughMemAndCasts(Value *V,
                                       SmallPtrSetImpl<const Use *> &Uses) {
  // only __builtin_invoke_simd is allowed, otherwise the pointer escapes
  auto PtrEscapes = [](const CallInst *CI) {
    return getHelperAndInvokeeIfInvokeSimdCall(CI).first == nullptr;
  };
  SmallPtrSet<const Use *, 4> DirectUses;
  collectUsesLookThroughCasts(V, DirectUses);

  for (const Use *DirectUse : DirectUses) {
    User *DirectUser = DirectUse->getUser();
    assert(!isCast(DirectUser));

    auto *Store = dyn_cast<StoreInst>(DirectUser);

    if (Store == nullptr) {
      Uses.insert(DirectUse);
      continue;
    }
    // V is stored to memory - follow it through the memory slot.
    assert((V == Store->getValueOperand()) &&
           "bad V param in collectUsesLookTrhoughMemAndCasts");
    Value *Slot = stripCasts(Store->getPointerOperand());

    if (!isa<AllocaInst>(Slot)) {
      // unsupported case of data flow through non-local memory
      return false;
    }
    // 1) V must be the only value which can possibly be stored to the slot.
    ValueSet SlotVals;

    if (!collectPossibleStoredVals(Slot, SlotVals, PtrEscapes)) {
      return false;
    }
    if (SlotVals.size() != 1) {
      return false;
    }
    if (*SlotVals.begin() != V) {
      return false;
    }
    // 2) Gather the uses of the values loaded back from the slot.
    SmallPtrSet<const Use *, 4> SlotUses;
    collectUsesLookThroughCasts(Slot, SlotUses);

    for (const Use *SlotUse : SlotUses) {
      User *SlotUser = SlotUse->getUser();
      assert(!isCast(SlotUser));

      if (SlotUser == Store) {
        // the store being analyzed
        continue;
      }
      if (isa<LoadInst>(SlotUser)) {
        collectUsesLookThroughCasts(SlotUser, Uses);
        continue;
      }
      // another store or some other unexpected use of the slot address
      return false;
    }
  }
  return true;
}

// Builds a new attribute list from Src where the attribute set of the first
// parameter is dropped and the remaining parameter attribute sets are shifted
// left by 1: the set found at index I + 1 in Src lands at index I.
// The index of the last parameter attribute set in the result is NParams-2.
// Function attributes are copied as is. Only function and parameter attributes
// are transferred to the result.
//   C       - the context to create attributes in.
//   Src     - the source attribute list.
//   NParams - the number of parameters Src describes. With 0 or 1 the result
//             carries no parameter attributes.
AttributeList removeFirstParamAttrAndShrink(LLVMContext &C, AttributeList Src,
                                            unsigned NParams) {
  AttributeList Dst;
  AttrBuilder FnAB(C, Src.getFnAttrs());
  Dst = Dst.addFnAttributes(C, FnAB);

  // Walk the source indices starting from the second parameter. Unlike
  // counting up to 'NParams - 1', this can't wrap around when NParams is 0.
  for (unsigned SrcParamNo = 1; SrcParamNo < NParams; ++SrcParamNo) {
    AttrBuilder AB(C, Src.getParamAttrs(SrcParamNo));
    Dst = Dst.addParamAttributes(C, SrcParamNo - 1, AB);
  }
  return Dst;
}

// TODO W/a IGC crash on functions with names containing '.' (e.g.
// CloneFunction creates such name) - replace '.' with '_'.
// Renaming can itself bring a '.' back (see below), so the replacement is
// retried a limited number of times; the result is checked with an assert
// only.
void fixFunctionName(Function *F) {
  constexpr unsigned MaxAttempts = 5;
  std::string CurName = F->getName().str();
  unsigned Attempt = 0;

  while ((Attempt < MaxAttempts) &&
         (CurName.find('.') != std::string::npos)) {
    for (char &Ch : CurName) {
      if (Ch == '.') {
        Ch = '_';
      }
    }
    // setName can actually add '.<suff>' if non-unique, so re-read the name
    // which was really assigned
    F->setName(CurName);
    CurName = F->getName().str();
    ++Attempt;
  }
  assert(CurName.find('.') == std::string::npos);
}

// Attaches the metadata which marks F as an ESIMD function, leaving already
// present metadata of the same kinds untouched:
// - esimd::ESIMD_MARKER_MD with an empty node;
// - REQD_SUB_GROUP_SIZE_MD with a node holding the i32 constant 1.
void markFunctionAsESIMD(Function *F) {
  LLVMContext &Ctx = F->getContext();

  if (F->getMetadata(esimd::ESIMD_MARKER_MD) == nullptr) {
    MDNode *EmptyNode = llvm::MDNode::get(Ctx, {});
    F->setMetadata(esimd::ESIMD_MARKER_MD, EmptyNode);
  }
  if (F->getMetadata(REQD_SUB_GROUP_SIZE_MD) != nullptr) {
    return;
  }
  auto *Int32Ty = Type::getInt32Ty(Ctx);
  auto *One = ConstantAsMetadata::get(ConstantInt::get(Int32Ty, 1));
  F->setMetadata(REQD_SUB_GROUP_SIZE_MD, MDNode::get(Ctx, One));
}

void adjustAddressSpace(Function *F, uint32_t ArgNo, uint32_t ArgAddrSpace) {
  Argument *Arg = F->getArg(ArgNo);
  for (User *ArgUse : Arg->users()) {
    Instruction *Instr = dyn_cast<Instruction>(ArgUse);
    if (!Instr)
      continue;
    const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(ArgUse);
    if (ASC) {
      if (ASC->getDestAddressSpace() == ArgAddrSpace)
        continue;
    }

    const CallInst *CI = dyn_cast<CallInst>(ArgUse);
    if (CI) {
      Function *Callee = CI->getCalledFunction();
      if (!Callee || Callee->isDeclaration())
        continue;

      for (uint32_t i = 0; i < CI->getNumOperands(); ++i) {
        if (CI->getOperand(i) == Arg) {
          adjustAddressSpace(Callee, i, ArgAddrSpace);
        }
      }
    } else {
      for (unsigned int i = 0; i < ArgUse->getNumOperands(); ++i) {
        if (ArgUse->getOperand(i) == Arg) {
          PointerType *NPT = PointerType::get(Arg->getContext(), ArgAddrSpace);

          auto *NewInstr = new AddrSpaceCastInst(ArgUse->getOperand(i), NPT);
          NewInstr->insertBefore(Instr);
          NewInstr->setDebugLoc(Instr->getDebugLoc());

          ArgUse->setOperand(i, NewInstr);
        }
      }
    }
  }
}

// Process 'invoke_simd(sub_group_obj, f, spmd_args...);' call.
//
// If f is a function name or a function pointer, this call is lowered into
//   >  __builtin_invoke_simd(simd_func_call_helper, f, unwrap(spmd_args)...);
// where simd_func_call_helper is a helper function defined by the SYCL library
// which performs parameter conversion and then calls f via a pointer to f:
//   > Ret simd_func_call_helper(Callable f, T ... args) {
//   >   args_conv... = C++-convert(args)...;
//   >   return f(args_conv...);
//   > }
// This function tries to determine actual function (F) given the pointer f. If
// successful, it transforms the __builtin_invoke_simd invocation and the helper
// to get rid of the indirect call of f:
// - clone simd_func_call_helper into simd_func_call_helper_unique_clone and
//   transform it as follows:
//   > Ret simd_func_call_helper_unique_clone(T ... args) {
//   >   args_conv... = C++-convert(args)...;
//   >   return F(args_conv...);
//   > }
// - remove f in __builtin_invoke_simd invocation and redirect it to the clone:
//   > __builtin_invoke_simd(
//   >   simd_func_call_helper_unique_clone,
//   >   unwrap(spmd_args)...
//   > );
// Parameters:
//   InvokeSimd    - a call to the __builtin_invoke_simd built-in. It is erased
//                   and replaced with a new call if the lowering succeeds.
//   ClonedHelpers - receives the original helper whenever a clone of it has
//                   been created.
// Returns true if the IR has been modified. The helper gets marked even when
// the invoke_simd target can't be deduced.
bool processInvokeSimdCall(CallInst *InvokeSimd,
                           SmallPtrSetImpl<Function *> &ClonedHelpers) {
  std::pair<Value *, Value *> HandI =
      getHelperAndInvokeeIfInvokeSimdCall(InvokeSimd);
  Value *H = HandI.first;
  Value *I = HandI.second;

  if (!H) {
    llvm_unreachable(
        ("bad use of " + Twine(esimd::INVOKE_SIMD_PREF)).str().c_str());
  }
  // "helper" defined in invoke_simd.hpp which converts arguments.
  auto *Helper = cast<Function>(H);
  // Mark helper as explicit SIMD function. Some BEs need this info.
  {
    markFunctionAsESIMD(Helper);
    // Fixup helper's linkage, which is linkonce_odr after the FE. It is dropped
    // from the ESIMD module after global DCE in post-link if not fixed up.
    Helper->setLinkage(GlobalValue::LinkageTypes::WeakODRLinkage);

    // VC backend requires the helper to always be marked VCStackCall
    if (!Helper->hasFnAttribute(llvm::genx::VCFunctionMD::VCStackCall)) {
      Helper->addFnAttr(llvm::genx::VCFunctionMD::VCStackCall);
    }
  }
  SmallPtrSet<const Function *, 8> Visited;
  Function *SimdF = deduceFunction(I, Visited);

  if (!SimdF) {
    // Call target is not known - leave the call as is. The helper has been
    // updated above though, so the IR must be reported as modified.
    return true;
  }
  if (!SimdF->hasFnAttribute(INVOKE_SIMD_DIRECT_TARGET_ATTR)) {
    SimdF->addFnAttr(INVOKE_SIMD_DIRECT_TARGET_ATTR);
  }

  if (!SimdF->isDeclaration()) {
    // The real arguments for invoke_simd callee start at index 2.
    for (uint32_t i = 2; i < InvokeSimd->arg_size(); ++i) {
      const uint32_t ParamNo = i - 2;

      if (ParamNo >= SimdF->arg_size()) {
        // the target has no formal parameter for this (and further) argument
        break;
      }
      const Value *Arg = InvokeSimd->getArgOperand(i);
      if (Arg->getType()->isPointerTy()) {
        uint32_t AddressSpace = Arg->getType()->getPointerAddressSpace();
        // NOTE(review): 4 is presumably the generic address space - confirm.
        // For such a pointer the address space it was cast from is used, if
        // the cast is visible; otherwise the argument is skipped.
        if (AddressSpace == 4) {
          const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(Arg);
          if (!ASC)
            continue;

          AddressSpace =
              ASC->getOperand(0)->getType()->getPointerAddressSpace();
        }
        adjustAddressSpace(SimdF, ParamNo, AddressSpace);
      }
    }
  }

  // The invoke_simd target is known at compile-time - optimize.
  // 1. find the call to f within the (not yet cloned) helper - f is its first
  //    parameter
  constexpr unsigned SimdCallTargetArgNo = 0;
  SmallPtrSet<const Use *, 4> Uses;
  Argument *SimdCallTargetArg = Helper->getArg(SimdCallTargetArgNo);

  if (!collectUsesLookTrhoughMemAndCasts(SimdCallTargetArg, Uses)) {
    llvm_unreachable("simd_func_call_helper broken");
  }
  CallInst *TheCall = nullptr;

  for (const Use *U : Uses) {
    auto *CI = dyn_cast<CallInst>(U->getUser());

    if (!CI) {
      continue;
    }
    assert(CI->getCalledOperand() == U->get());
    assert(!TheCall && "unexpected multiple calls to SIMD target");
    TheCall = CI;
    // With assertions disabled the first call found is taken. With assertions
    // enabled the scan goes on, so that the assert above can catch a second
    // call.
#ifdef NDEBUG
    break;
#endif
  }
  assert(TheCall && "simd_func_call_helper broken 1");

  // 2. Clone and transform simd_func_call_helper signature and code.
  Function *NewHelper = nullptr;
  {
    ValueToValueMapTy VMap;
    // mark the 'f' parameter for deletion:
    VMap[SimdCallTargetArg] = PoisonValue::get(SimdCallTargetArg->getType());
    NewHelper = CloneFunction(Helper, VMap);
    // make the call to the user simd function direct:
    CallInst *TheTformedCall = cast<CallInst>(VMap[TheCall]);
    TheTformedCall->setCalledFunction(SimdF);
    fixFunctionName(NewHelper);
    // When we will do ESIMD split, that helper will be moved into ESIMD module
    // where it has no uses. To prevent it being internalized and killed by DCE
    // during post-split cleanup, we need to add this attribute and set proper
    // linkage.
    NewHelper->addFnAttr("referenced-indirectly");
  }

  // 3. Clone and transform __builtin_invoke_simd call:
  //    - remove the 'f' formal parameter from the declaration
  //    - remove the 'f' actual argument (SIMD target function pointer)
  {
    LLVMContext &C = InvokeSimd->getContext();
    // 3.1. Create a new declaration for the intrinsic (with 1 parameter less):
    constexpr unsigned HelperArgNo = 0;
    Function *InvokeSimdF = InvokeSimd->getCalledFunction();
    assert(InvokeSimdF && "Unexpected IR for invoke_simd");
    // - type of the obsolete (unmodified) helper:
    PointerType *HelperArgTy =
        cast<PointerType>(InvokeSimdF->getArg(HelperArgNo)->getType());
    unsigned AS = HelperArgTy->getAddressSpace();
    FunctionType *InvokeSimdFTy = InvokeSimdF->getFunctionType();
    // - create the list of new formal parameter types (the old one, with the
    //   second element removed); the new helper is passed as a pointer in the
    //   address space of the old helper parameter:
    SmallVector<Type *, 8> NewArgTys;
    NewArgTys.push_back(PointerType::get(C, AS));
    std::copy(std::next(InvokeSimdFTy->param_begin(), 2),
              InvokeSimdFTy->param_end(), std::back_inserter(NewArgTys));
    FunctionType *NewInvokeSimdFTy =
        FunctionType::get(InvokeSimdFTy->getReturnType(), NewArgTys, false);
    // - create the new declaration:
    Function *NewInvokeSimdF =
        Function::Create(NewInvokeSimdFTy, InvokeSimdF->getLinkage(),
                         InvokeSimdF->getName(), InvokeSimdF->getParent());
    fixFunctionName(NewInvokeSimdF);
    // - truncate and shift-left parameter attributes in the new declaration:
    AttributeList NewAttrsF = removeFirstParamAttrAndShrink(
        C, InvokeSimdF->getAttributes(), InvokeSimdF->arg_size());
    NewInvokeSimdF->setAttributes(NewAttrsF);

    // 3.2. Create the new __builtin_invoke_simd call.
    // - create a list of new arguments (the old one with the second element
    //   removed):
    SmallVector<Value *, 4> NewInvokeSimdArgs;
    NewInvokeSimdArgs.push_back(NewHelper);
    auto ThirdArg = std::next(InvokeSimd->arg_begin(), 2);
    NewInvokeSimdArgs.append(ThirdArg, InvokeSimd->arg_end());
    CallInst *NewInvokeSimd = CallInst::Create(
        NewInvokeSimdF, NewInvokeSimdArgs, "", InvokeSimd->getIterator());
    // - transfer flags, debug location, attributes (with shrinking), calling
    //   convention:
    NewInvokeSimd->copyIRFlags(InvokeSimd);
    NewInvokeSimd->setDebugLoc(InvokeSimd->getDebugLoc());
    NewInvokeSimd->setCallingConv(InvokeSimd->getCallingConv());
    AttributeList NewAttrs = removeFirstParamAttrAndShrink(
        C, InvokeSimd->getAttributes(), InvokeSimd->arg_size());
    NewInvokeSimd->setAttributes(NewAttrs);

    InvokeSimd->replaceAllUsesWith(NewInvokeSimd);
    InvokeSimd->eraseFromParent();
  }
  // The original helper might have become unused - let the caller decide.
  ClonedHelpers.insert(Helper);
  return true;
}
} // namespace

namespace llvm {
// Lowers all the calls to the invoke_simd built-in found in the module M:
// 1. gathers the users of every declaration whose name starts with
//    esimd::INVOKE_SIMD_PREF (each user is expected to be a call);
// 2. processes each gathered call with processInvokeSimdCall;
// 3. erases the helpers which were cloned and are left without uses.
// MAM is not used. Returns PreservedAnalyses::none() if any call processing
// reported a change and PreservedAnalyses::all() otherwise.
PreservedAnalyses SYCLLowerInvokeSimdPass::run(Module &M,
                                               ModuleAnalysisManager &MAM) {
  SetVector<CallInst *> ISCalls;

  for (Function &F : M) {
    const bool IsInvokeSimdDecl =
        F.isDeclaration() && F.getName().starts_with(esimd::INVOKE_SIMD_PREF);

    if (!IsInvokeSimdDecl) {
      continue;
    }
    for (User *Usr : F.users()) {
      // a call can be the only use of the invoke_simd built-in
      ISCalls.insert(cast<CallInst>(Usr));
    }
  }
  // A set of encountered instances of the 'simd_func_call_helper' function
  // template defined in the invoke_simd.hpp, which are cloned. Those might no
  // longer be used after the pass finishes.
  SmallPtrSet<Function *, 4> ClonedHelpers;
  bool Modified = false;

  for (CallInst *CI : ISCalls) {
    if (processInvokeSimdCall(CI, ClonedHelpers)) {
      Modified = true;
    }
  }
  for (Function *Helper : ClonedHelpers) {
    if (Helper->use_empty()) {
      Helper->eraseFromParent();
    }
  }
  if (!Modified) {
    return PreservedAnalyses::all();
  }
  return PreservedAnalyses::none();
}
} // namespace llvm
