MLIR 22.0.0git
NVVMRequiresSMTraits.h
Go to the documentation of this file.
//===--- NVVMRequiresSMTraits.h - NVVM Requires SM Traits -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines op traits for the NVVM dialect in MLIR that declare which
// SM versions (e.g. sm_90, sm_90a) an op requires.
//
//===----------------------------------------------------------------------===//
12
13#ifndef NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
14#define NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
15
18#include "llvm/ADT/StringExtras.h"
19
20namespace mlir {
21
22namespace NVVM {
23
24// Struct to store and check compatibility of SM versions.
26 // Set to true if the SM version is accelerated (e.g., sm_90a).
28
29 // List of SM versions.
30 // Typically only has one version except for cases where multiple
31 // arch-accelerated versions are supported.
32 // For example, tcgen05.shift is supported on sm_100a, sm_101a, and sm_103a.
34
35 template <typename... Ints>
36 NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
37 : archAccelerated(archAccelerated), smVersionList({smVersions...}) {
38 assert((archAccelerated || smVersionList.size() == 1) &&
39 "non arch-accelerated SM version list must be a single version!");
40 }
41
42 bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const {
43 assert(targetSM.smVersionList.size() == 1 &&
44 "target SM version list must be a single version!");
45
46 if (archAccelerated) {
47 if (!targetSM.archAccelerated)
48 return false;
49
50 for (auto version : smVersionList) {
51 if (version == targetSM.smVersionList[0])
52 return true;
53 }
54 } else {
55 return targetSM.smVersionList[0] >= smVersionList[0];
56 }
57
58 return false;
59 }
60
61 bool isMinimumSMVersion() const { return smVersionList[0] >= 20; }
62
63 // Parses an SM version string and returns an equivalent NVVMCheckSMVersion
64 // object.
65 static const NVVMCheckSMVersion
66 getTargetSMVersionFromStr(StringRef smVersionString) {
67 bool isAA = smVersionString.back() == 'a';
68
69 int smVersionInt;
70 smVersionString.drop_front(3)
71 .take_while([](char c) { return llvm::isDigit(c); })
72 .getAsInteger(10, smVersionInt);
73
74 return NVVMCheckSMVersion(isAA, smVersionInt);
75 }
76};
77
78} // namespace NVVM
79} // namespace mlir
80
81#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h.inc"
82
83namespace mlir {
84
85namespace OpTrait {
86
87template <int MinVersion>
89public:
90 template <typename ConcreteOp>
91 class Impl
92 : public OpTrait::TraitBase<ConcreteOp, NVVMRequiresSM<MinVersion>::Impl>,
94 public:
96 return NVVM::NVVMCheckSMVersion(false, MinVersion);
97 }
98 };
99};
100
101template <int... SMVersions>
103public:
104 template <typename ConcreteOp>
105 class Impl : public OpTrait::TraitBase<ConcreteOp,
106 NVVMRequiresSMa<SMVersions...>::Impl>,
107 public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
108 public:
110 return NVVM::NVVMCheckSMVersion(true, SMVersions...);
111 }
112 };
113};
114
115} // namespace OpTrait
116} // namespace mlir
117#endif // NVVM_DIALECT_NVVM_IR_NVVMREQUIRESSMTRAITS_H_
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const
Helper class for implementing traits.
Include the generated interface declarations.
NVVMCheckSMVersion(bool archAccelerated, Ints... smVersions)
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
llvm::SmallVector< int, 1 > smVersionList
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)