MLIR 23.0.0git
ROCDLDialect.cpp
Go to the documentation of this file.
1//===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the ROCDL IR dialect in
10// MLIR, which extends the LLVM IR dialect. It also registers the dialect.
11//
12// The ROCDL dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
18
21#include "mlir/IR/Builders.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Operation.h"
27#include "llvm/ADT/StringExtras.h"
28#include "llvm/ADT/StringRef.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32
33using namespace mlir;
34using namespace ROCDL;
35
36#include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.cpp.inc"
37#include "mlir/Dialect/LLVMIR/ROCDLOpsEnums.cpp.inc"
38
39//===----------------------------------------------------------------------===//
40// ROCDLDialect initialization, type parsing, and registration.
41//===----------------------------------------------------------------------===//
42
43namespace {
44struct ROCDLInlinerInterface final : DialectInlinerInterface {
45 using DialectInlinerInterface::DialectInlinerInterface;
46 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
47 return true;
48 }
49};
50} // namespace
51
52// TODO: This should be the llvm.rocdl dialect once this is supported.
53void ROCDLDialect::initialize() {
54 addOperations<
55#define GET_OP_LIST
56#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
57 >();
58
59 addAttributes<
60#define GET_ATTRDEF_LIST
61#include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
62 >();
63
64 // Support unknown operations because not all ROCDL operations are registered.
65 allowUnknownOperations();
66 addInterfaces<ROCDLInlinerInterface>();
67 declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
68}
69
70LLVM::ModFlagBehavior
71BufferOOBModeModuleFlagAttr::getModuleFlagBehavior() const {
72 return LLVM::ModFlagBehavior::Max;
73}
74
75StringAttr BufferOOBModeModuleFlagAttr::getModuleFlagKey() const {
76 return StringAttr::get(getContext(),
77 ROCDLDialect::getModuleFlagKeyBufferOOBModeName());
78}
79
80Attribute BufferOOBModeModuleFlagAttr::getModuleFlagValue() const {
81 return BufferOOBModeAttr::get(getContext(), getValue());
82}
83
84LLVM::ModFlagBehavior
85TBufferOOBModeModuleFlagAttr::getModuleFlagBehavior() const {
86 return LLVM::ModFlagBehavior::Max;
87}
88
89StringAttr TBufferOOBModeModuleFlagAttr::getModuleFlagKey() const {
90 return StringAttr::get(getContext(),
91 ROCDLDialect::getModuleFlagKeyTBufferOOBModeName());
92}
93
94Attribute TBufferOOBModeModuleFlagAttr::getModuleFlagValue() const {
95 return BufferOOBModeAttr::get(getContext(), getValue());
96}
97
98LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
99 NamedAttribute attr) {
100 // Kernel function attribute should be attached to functions.
101 if (kernelAttrName.getName() == attr.getName()) {
102 if (!isa<LLVM::LLVMFuncOp>(op)) {
103 return op->emitError() << "'" << kernelAttrName.getName()
104 << "' attribute attached to unexpected op";
105 }
106 }
107 return success();
108}
109
110//===----------------------------------------------------------------------===//
111// ROCDL op custom parsers/printers.
112//===----------------------------------------------------------------------===//
113
114template <typename EnumAttrT, typename EnumT>
115static ParseResult parseCachePolicyEnum(OpAsmParser &parser,
116 Attribute &cachePolicy) {
117 if (parser.parseLess())
118 return failure();
119 FailureOr<EnumT> parsed = FieldParser<EnumT>::parse(parser);
120 if (failed(parsed))
121 return failure();
122 if (parser.parseGreater())
123 return failure();
124 cachePolicy = EnumAttrT::get(parser.getContext(), *parsed);
125 return success();
126}
127
128static ParseResult parseCachePolicy(OpAsmParser &parser,
129 Attribute &cachePolicy) {
130 uint32_t rawValue;
131 OptionalParseResult rawValueParseResult =
132 parser.parseOptionalInteger(rawValue);
133 if (rawValueParseResult.has_value()) {
134 if (failed(*rawValueParseResult))
135 return failure();
136 cachePolicy =
137 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), rawValue);
138 return success();
139 }
140
141 StringRef policyFamily;
142 auto loc = parser.getCurrentLocation();
143 if (failed(parser.parseOptionalKeyword(
144 &policyFamily, {"pre_gfx12", "gfx942", "gfx12", "gfx12_atomic"}))) {
145 return parser.emitError(loc)
146 << "expected cache policy family 'pre_gfx12', 'gfx942', 'gfx12', "
147 "'gfx12_atomic', or a 32-bit integer";
148 }
149
150 if (policyFamily == "pre_gfx12")
152 parser, cachePolicy);
153 if (policyFamily == "gfx942")
155 parser, cachePolicy);
156 if (policyFamily == "gfx12")
158 parser, cachePolicy);
159 return parseCachePolicyEnum<Gfx12AtomicCachePolicyAttr,
160 Gfx12AtomicCachePolicy>(parser, cachePolicy);
161}
162
163template <typename EnumAttrT>
164static void printCachePolicyEnum(OpAsmPrinter &printer, EnumAttrT cachePolicy,
165 StringRef family) {
166 printer << family << "<" << cachePolicy.getValue() << ">";
167}
168
170 Attribute cachePolicy) {
171 llvm::TypeSwitch<Attribute>(cachePolicy)
172 .Case<IntegerAttr>([&](IntegerAttr rawPolicy) {
173 printer << rawPolicy.getValue().getZExtValue();
174 })
175 .Case<PreGfx12CachePolicyAttr>([&](PreGfx12CachePolicyAttr policy) {
176 printCachePolicyEnum(printer, policy, "pre_gfx12");
177 })
178 .Case<Gfx942CachePolicyAttr>([&](Gfx942CachePolicyAttr policy) {
179 printCachePolicyEnum(printer, policy, "gfx942");
180 })
181 .Case<Gfx12CachePolicyAttr>([&](Gfx12CachePolicyAttr policy) {
182 printCachePolicyEnum(printer, policy, "gfx12");
183 })
184 .Case<Gfx12AtomicCachePolicyAttr>([&](Gfx12AtomicCachePolicyAttr policy) {
185 printCachePolicyEnum(printer, policy, "gfx12_atomic");
186 })
187 .DefaultUnreachable("unknown ROCDL cache policy attribute");
188}
189
190//===----------------------------------------------------------------------===//
191// ROCDL target attribute.
192//===----------------------------------------------------------------------===//
193LogicalResult
194ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
195 int optLevel, StringRef triple, StringRef chip,
196 StringRef features, StringRef abiVersion,
197 DictionaryAttr flags, ArrayAttr files) {
198 if (optLevel < 0 || optLevel > 3) {
199 emitError() << "The optimization level must be a number between 0 and 3.";
200 return failure();
201 }
202 if (triple.empty()) {
203 emitError() << "The target triple cannot be empty.";
204 return failure();
205 }
206 if (chip.empty()) {
207 emitError() << "The target chip cannot be empty.";
208 return failure();
209 }
210 if (abiVersion != "400" && abiVersion != "500" && abiVersion != "600") {
211 emitError() << "Invalid ABI version, it must be `400`, `500` or '600'.";
212 return failure();
213 }
214 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
215 return mlir::isa_and_nonnull<StringAttr>(attr);
216 })) {
217 emitError() << "All the elements in the `link` array must be strings.";
218 return failure();
219 }
220 return success();
221}
222
223#define GET_OP_CLASSES
224#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
225
226#define GET_ATTRDEF_CLASSES
227#include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
ArrayAttr()
b getContext())
static ParseResult parseCachePolicyEnum(OpAsmParser &parser, Attribute &cachePolicy)
static ParseResult parseCachePolicy(OpAsmParser &parser, Attribute &cachePolicy)
static void printCachePolicyEnum(OpAsmPrinter &printer, EnumAttrT cachePolicy, StringRef family)
static void printCachePolicy(OpAsmPrinter &printer, Operation *, Attribute cachePolicy)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseGreater()=0
Parse a '>' token.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
Provide a template class that can be specialized by users to dispatch to parsers.