MLIR 23.0.0git
NVVMRequiresSMTraits.h
Go to the documentation of this file.
1//===--- NVVMRequiresSMTraits.h - NVVM Requires SM Traits -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines op traits for the NVVM Dialect in MLIR
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
14#define NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
15
18#include "llvm/ADT/StringExtras.h"
19
20namespace mlir {
21
22namespace NVVM {
23
24// Struct to store and check compatibility of SM versions.
26 static constexpr char kArchAcceleratedSuffix = 'a';
27 static constexpr char kFamilySpecificSuffix = 'f';
28
29 // List of supported full SM versions.
30 // This is used to check compatibility with a target SM version.
31 // The full SM version is encoded as SM * 10 + ArchSuffixOffset where:
32 // - SM is the SM version (e.g., 100)
33 // - ArchSuffixOffset is 0 for base, 2 for family-specific, and 3 for
34 // architecture-accelerated
35 //
36 // For example, sm_100 is encoded as 1000 (100 * 10 + 0), sm_100f is encoded
37 // as 1002 (100 * 10 + 2) and sm_100a is encoded as 1003 (100 * 10 + 3).
39
40 template <typename... Versions>
41 NVVMCheckSMVersion(Versions... fullSmVersions)
42 : fullSmVersionList({fullSmVersions...}) {}
43
44 bool isCompatibleWith(const unsigned &targetFullSmVersion) const {
45 return llvm::any_of(
46 fullSmVersionList, [&](const unsigned &requiredFullSmVersion) {
47 if (hasArchAcceleratedFeatures(requiredFullSmVersion))
48 return hasArchAcceleratedFeatures(targetFullSmVersion) &&
49 (getSMVersion(targetFullSmVersion) ==
50 getSMVersion(requiredFullSmVersion));
51
52 if (hasFamilySpecificFeatures(requiredFullSmVersion))
53 return hasFamilySpecificFeatures(targetFullSmVersion) &&
54 (getSMFamily(targetFullSmVersion) ==
55 getSMFamily(requiredFullSmVersion)) &&
56 (getSMVersion(targetFullSmVersion) >=
57 getSMVersion(requiredFullSmVersion));
58
59 return targetFullSmVersion >= requiredFullSmVersion;
60 });
61 }
62
63 // Parses an SM version string and returns an equivalent full SM version
64 // integer.
65 static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString) {
66 bool isAA = smVersionString.back() == kArchAcceleratedSuffix;
67 bool isFS = smVersionString.back() == kFamilySpecificSuffix;
68
69 unsigned smVersion;
70 smVersionString.drop_front(3)
71 .take_while([](char c) { return llvm::isDigit(c); })
72 .getAsInteger(10, smVersion);
73
74 return smVersion * 10 + (isAA ? 3 : 0) + (isFS ? 2 : 0);
75 }
76
77 static bool isMinimumSMVersion(unsigned fullSmVersion) {
78 return getSMVersion(fullSmVersion) >= 20;
79 }
80
81private:
82 static bool hasFamilySpecificFeatures(unsigned fullSmVersion) {
83 return (fullSmVersion % 10) >= 2;
84 }
85
86 static bool hasArchAcceleratedFeatures(unsigned fullSmVersion) {
87 return (fullSmVersion % 10) == 3;
88 }
89
90 static unsigned getSMVersion(unsigned fullSmVersion) {
91 return fullSmVersion / 10;
92 }
93
94 static unsigned getSMFamily(unsigned fullSmVersion) {
95 return fullSmVersion / 100;
96 }
97};
98
99} // namespace NVVM
100} // namespace mlir
101
102#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h.inc"
103
104namespace mlir {
105
106namespace OpTrait {
107
108template <unsigned... FullSMVersions>
110public:
111 template <typename ConcreteOp>
112 class Impl
113 : public OpTrait::TraitBase<ConcreteOp,
114 NVVMRequiresSM<FullSMVersions...>::Impl>,
115 public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
116 public:
118 return NVVM::NVVMCheckSMVersion(FullSMVersions...);
119 }
120 };
121};
122} // namespace OpTrait
123} // namespace mlir
124#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const
Helper class for implementing traits.
Include the generated interface declarations.
NVVMCheckSMVersion(Versions... fullSmVersions)
static constexpr char kArchAcceleratedSuffix
static bool isMinimumSMVersion(unsigned fullSmVersion)
static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString)
llvm::SmallVector< unsigned > fullSmVersionList
bool isCompatibleWith(const unsigned &targetFullSmVersion) const
static constexpr char kFamilySpecificSuffix