MLIR  21.0.0git
NVVMRequiresSMTraits.h
Go to the documentation of this file.
1 //===--- NVVMRequiresSMTraits.h - NVVM Requires SM Traits -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SM-version requirement op traits for the NVVM dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
14 #define NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
15 
16 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/StringExtras.h"
19 
20 namespace mlir {
21 
22 namespace NVVM {
23 
24 // Struct to store and check compatibility of SM versions.
26  // Set to true if the SM version is accelerated (e.g., sm_90a).
28 
29  // List of SM versions.
30  // Typically only has one version except for cases where multiple
31  // arch-accelerated versions are supported.
32  // For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
34 
35  template <typename... Ints>
36  NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
37  : archAccelerated(archAccelerated), smVersionList({smVersions...}) {
38  assert((archAccelerated || smVersionList.size() == 1) &&
39  "non arch-accelerated SM version list must be a single version!");
40  }
41 
42  bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
43  assert(targetSM.smVersionList.size() == 1 &&
44  "target SM version list must be a single version!");
45 
46  if (archAccelerated) {
47  if (!targetSM.archAccelerated)
48  return false;
49 
50  for (auto version : smVersionList) {
51  if (version == targetSM.smVersionList[0])
52  return true;
53  }
54  } else {
55  return targetSM.smVersionList[0] >= smVersionList[0];
56  }
57 
58  return false;
59  }
60 
61  bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
62 
63  // Parses an SM version string and returns an equivalent NVVMCheckSMVersion
64  // object.
65  static const NVVMCheckSMVersion
66  getTargetSMVersionFromStr(StringRef smVersionString) {
67  bool isAA = smVersionString.back() == 'a';
68 
69  int smVersionInt;
70  smVersionString.drop_front(3)
71  .take_while([](char c) { return llvm::isDigit(c); })
72  .getAsInteger(10, smVersionInt);
73 
74  return NVVMCheckSMVersion(isAA, smVersionInt);
75  }
76 };
77 
78 } // namespace NVVM
79 } // namespace mlir
80 
81 #include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h.inc"
82 
83 namespace mlir {
84 
85 namespace OpTrait {
86 
87 template <int MinVersion>
89 public:
90  template <typename ConcreteOp>
91  class Impl
92  : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
93  public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
94  public:
96  return NVVM::NVVMCheckSMVersion(false, MinVersion);
97  }
98  };
99 };
100 
101 template <int... SMVersions>
103 public:
104  template <typename ConcreteOp>
105  class Impl : public OpTrait::TraitBase<ConcreteOp,
106  NVVMRequiresSMa<SMVersions...>::Impl>,
107  public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
108  public:
110  return NVVM::NVVMCheckSMVersion(true, SMVersions...);
111  }
112  };
113 };
114 
115 } // namespace OpTrait
116 } // namespace mlir
117 #endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const
Helper class for implementing traits.
Definition: OpDefinition.h:377
Include the generated interface declarations.
NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
llvm::SmallVector< int, 1 > smVersionList
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)