//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

// This header contains a preview of a portability system that enables
// CUDA C++ development with NVC++, NVCC, and supported host compilers.
// These interfaces are not guaranteed to be stable.

#ifndef __NV_TARGET_H
#define __NV_TARGET_H

// Identify the compiler driving this translation unit. The branches are
// mutually exclusive, so at most one _NV_COMPILER_* macro is defined; when no
// test matches, none of them is defined.
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
// NVCC, or NVRTC (which sets __CUDACC_RTC__).
#  define _NV_COMPILER_NVCC
#elif defined(__NVCOMPILER) && defined(__cplusplus) && __cplusplus >= 201103L
// NVC++ compiling C++11 or later. `__cplusplus` is tested with `defined`
// first (as the _NV_TARGET_CPP11 check does) so that a C translation unit never
// evaluates an undefined macro inside `#elif`.
#  define _NV_COMPILER_NVCXX
#elif defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
// clang compiling CUDA code, device mode.
#  define _NV_COMPILER_CLANG_CUDA
#endif

// Define _NV_TARGET_CPP11 when this translation unit is compiled as C++11 or
// later. The language level is read from `__cplusplus`, or from `_MSVC_LANG`
// when `_MSC_VER` is defined. The macro is never defined when `__ibmxl__` is
// defined, regardless of the reported language level.
#if (!defined(__ibmxl__)) \
  && ((defined(__cplusplus) && __cplusplus >= 201103L) || (defined(_MSC_VER) && _MSVC_LANG >= 201103L))
#  define _NV_TARGET_CPP11
#endif

// Hide `if target` support from NVRTC
#if defined(_NV_TARGET_CPP11) && !defined(__CUDACC_RTC__)

// _NV_BITSET_ATTRIBUTE is placed on the target-set struct declaration. It
// expands to nothing unless the NVC++ compiler was detected, in which case it
// becomes the [[nv::__target_bitset]] attribute.
#  if !defined(_NV_COMPILER_NVCXX)
#    define _NV_BITSET_ATTRIBUTE
#  else
#    define _NV_BITSET_ATTRIBUTE [[nv::__target_bitset]]
#  endif

namespace nv
{
namespace target
{
namespace detail
{

// Integer type that holds a set of targets, one bit per target.
using base_int_t = unsigned long long;

// Bit 0 stands for the host as a whole: there is no host specialization.
constexpr base_int_t all_hosts = 1ULL << 0;

// NVIDIA GPUs: one bit per SM architecture, assigned in increasing
// architecture order so that a newer architecture always owns a higher bit
// (provides() relies on this ordering).
//
// The shifted literal is `1ULL` rather than `1` so every shift is evaluated in
// the width of base_int_t; an `int` shift would overflow as soon as a bit
// position of 31 or more is added to this table.
constexpr base_int_t sm_35_bit = 1ULL << 1;
constexpr base_int_t sm_37_bit = 1ULL << 2;
constexpr base_int_t sm_50_bit = 1ULL << 3;
constexpr base_int_t sm_52_bit = 1ULL << 4;
constexpr base_int_t sm_53_bit = 1ULL << 5;
constexpr base_int_t sm_60_bit = 1ULL << 6;
constexpr base_int_t sm_61_bit = 1ULL << 7;
constexpr base_int_t sm_62_bit = 1ULL << 8;
constexpr base_int_t sm_70_bit = 1ULL << 9;
constexpr base_int_t sm_72_bit = 1ULL << 10;
constexpr base_int_t sm_75_bit = 1ULL << 11;
constexpr base_int_t sm_80_bit = 1ULL << 12;
constexpr base_int_t sm_86_bit = 1ULL << 13;
constexpr base_int_t sm_87_bit = 1ULL << 14;
constexpr base_int_t sm_89_bit = 1ULL << 15;
constexpr base_int_t sm_90_bit = 1ULL << 16;

// Union of every device bit above; the host bit is deliberately not included.
constexpr base_int_t all_devices =
  sm_35_bit | sm_37_bit | sm_50_bit | sm_52_bit | sm_53_bit | sm_60_bit | sm_61_bit | sm_62_bit | sm_70_bit | sm_72_bit
  | sm_75_bit | sm_80_bit | sm_86_bit | sm_87_bit | sm_89_bit | sm_90_bit;

// Store a set of targets as a set of bits
// Wraps a bit mask in which each set bit selects one target (the host bit or
// one SM architecture bit). _NV_BITSET_ATTRIBUTE expands to
// [[nv::__target_bitset]] when NVC++ was detected and to nothing otherwise.
struct _NV_BITSET_ATTRIBUTE target_description
{
  // The selected targets, one bit per target.
  base_int_t targets;

  // Converting constructor: it is not marked `explicit`, so a raw mask of
  // type base_int_t converts implicitly to a target_description.
  constexpr target_description(base_int_t a)
      : targets(a)
  {}
};

// The type of the user-visible names of the NVIDIA GPU targets
// Scoped enumeration with base_int_t as its underlying type. Each enumerator's
// value is the architecture number itself (sm_80 == 80), not a mask bit;
// bitexact() and bitrounddown() translate a selector into its mask bit.
enum class sm_selector : base_int_t
{
  sm_35 = 35,
  sm_37 = 37,
  sm_50 = 50,
  sm_52 = 52,
  sm_53 = 53,
  sm_60 = 60,
  sm_61 = 61,
  sm_62 = 62,
  sm_70 = 70,
  sm_72 = 72,
  sm_75 = 75,
  sm_80 = 80,
  sm_86 = 86,
  sm_87 = 87,
  sm_89 = 89,
  sm_90 = 90,
};

// Returns the numeric value stored in a selector (for example 80 for sm_80).
// A scoped enumeration never converts implicitly, hence the explicit cast.
constexpr base_int_t toint(sm_selector a)
{
  return base_int_t(a);
}

// Maps a selector to the single mask bit owned by exactly that architecture.
// A value that matches none of the enumerators yields 0, i.e. the empty set.
// Kept as one conditional expression so the function body remains a single
// return statement, as C++11 constexpr functions require.
constexpr base_int_t bitexact(sm_selector a)
{
  return a == sm_selector::sm_35 ? sm_35_bit
       : a == sm_selector::sm_37 ? sm_37_bit
       : a == sm_selector::sm_50 ? sm_50_bit
       : a == sm_selector::sm_52 ? sm_52_bit
       : a == sm_selector::sm_53 ? sm_53_bit
       : a == sm_selector::sm_60 ? sm_60_bit
       : a == sm_selector::sm_61 ? sm_61_bit
       : a == sm_selector::sm_62 ? sm_62_bit
       : a == sm_selector::sm_70 ? sm_70_bit
       : a == sm_selector::sm_72 ? sm_72_bit
       : a == sm_selector::sm_75 ? sm_75_bit
       : a == sm_selector::sm_80 ? sm_80_bit
       : a == sm_selector::sm_86 ? sm_86_bit
       : a == sm_selector::sm_87 ? sm_87_bit
       : a == sm_selector::sm_89 ? sm_89_bit
       : a == sm_selector::sm_90 ? sm_90_bit
                                 : 0;
}

// Maps a selector to the mask bit of the highest known architecture whose
// number does not exceed the selector's value. The chain walks the known
// architectures from lowest to highest: the first upper bound that the value
// falls under decides the result. Values below 35 yield 0; values of 90 or
// more yield sm_90_bit.
// Kept as one conditional expression so the function body remains a single
// return statement, as C++11 constexpr functions require.
constexpr base_int_t bitrounddown(sm_selector a)
{
  return toint(a) < 35 ? 0
       : toint(a) < 37 ? sm_35_bit
       : toint(a) < 50 ? sm_37_bit
       : toint(a) < 52 ? sm_50_bit
       : toint(a) < 53 ? sm_52_bit
       : toint(a) < 60 ? sm_53_bit
       : toint(a) < 61 ? sm_60_bit
       : toint(a) < 62 ? sm_61_bit
       : toint(a) < 70 ? sm_62_bit
       : toint(a) < 72 ? sm_70_bit
       : toint(a) < 75 ? sm_72_bit
       : toint(a) < 80 ? sm_75_bit
       : toint(a) < 86 ? sm_80_bit
       : toint(a) < 87 ? sm_86_bit
       : toint(a) < 89 ? sm_87_bit
       : toint(a) < 90 ? sm_89_bit
                       : sm_90_bit;
}

// Public API for NVIDIA GPUs

// Target set containing only the architecture named by `a`.
// The set is empty when `a` matches no known architecture, because bitexact()
// returns 0 in that case.
constexpr target_description is_exactly(sm_selector a)
{
  return target_description{bitexact(a)};
}

// Target set containing the architecture `a` rounds down to, plus every
// architecture with a higher bit.
//
// With `floor` being the bit returned by bitrounddown(a):
//   floor - 1      sets every bit below `floor`;
//   ~(floor - 1)   therefore keeps `floor` and everything above it;
//   & all_devices  discards bits that belong to no device.
// When bitrounddown(a) is 0, `floor - 1` wraps to all ones and the result is
// the empty set.
constexpr target_description provides(sm_selector a)
{
  return target_description(all_devices & ~(bitrounddown(a) - 1));
}

// Boolean operations on target sets

// Intersection: the targets selected by both `a` and `b`.
// Being an overloaded operator, this does not short-circuit; both operands
// are always evaluated.
constexpr target_description operator&&(target_description a, target_description b)
{
  return target_description{b.targets & a.targets};
}

// Union: the targets selected by `a`, by `b`, or by both.
// Being an overloaded operator, this does not short-circuit; both operands
// are always evaluated.
constexpr target_description operator||(target_description a, target_description b)
{
  return target_description{b.targets | a.targets};
}

// Complement: every known target (host bit plus all device bits) that `a`
// does not select. Masking with the known bits keeps unused high bits of the
// mask clear after the inversion.
constexpr target_description operator!(target_description a)
{
  return target_description{(all_hosts | all_devices) & ~a.targets};
}
} // namespace detail

using detail::sm_selector;
using detail::target_description;

// The predicates for basic host/device selection
// is_host selects only the host bit, is_device every device bit, any_target
// both groups, and no_target nothing at all.
constexpr target_description is_host(detail::all_hosts);
constexpr target_description is_device(detail::all_devices);
constexpr target_description any_target(detail::all_devices | detail::all_hosts);
constexpr target_description no_target(0);

// The public names for NVIDIA GPU architectures
// Each constant re-exports the sm_selector enumerator of the same name at
// nv::target scope, so callers can write nv::target::sm_80 without spelling
// out the scoped enumeration.
constexpr sm_selector sm_35 = sm_selector::sm_35;
constexpr sm_selector sm_37 = sm_selector::sm_37;
constexpr sm_selector sm_50 = sm_selector::sm_50;
constexpr sm_selector sm_52 = sm_selector::sm_52;
constexpr sm_selector sm_53 = sm_selector::sm_53;
constexpr sm_selector sm_60 = sm_selector::sm_60;
constexpr sm_selector sm_61 = sm_selector::sm_61;
constexpr sm_selector sm_62 = sm_selector::sm_62;
constexpr sm_selector sm_70 = sm_selector::sm_70;
constexpr sm_selector sm_72 = sm_selector::sm_72;
constexpr sm_selector sm_75 = sm_selector::sm_75;
constexpr sm_selector sm_80 = sm_selector::sm_80;
constexpr sm_selector sm_86 = sm_selector::sm_86;
constexpr sm_selector sm_87 = sm_selector::sm_87;
constexpr sm_selector sm_89 = sm_selector::sm_89;
constexpr sm_selector sm_90 = sm_selector::sm_90;

using detail::is_exactly;
using detail::provides;
} // namespace target
} // namespace nv

#endif // defined(_NV_TARGET_CPP11) && !defined(__CUDACC_RTC__)

#include <nv/detail/__target_macros>

#endif // __NV_TARGET_H
