MLIR 23.0.0git
TransformTypes.cpp
Go to the documentation of this file.
1//===- TransformTypes.cpp - Transform Dialect Type Definitions ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
16#include "mlir/IR/Types.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/Compiler.h"
19
20using namespace mlir;
21
22// These are automatically generated by ODS but are not used as the Transform
23// dialect uses a different dispatch mechanism to support dialect extensions.
24[[maybe_unused]] static OptionalParseResult
25generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
26[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
27 AsmPrinter &printer);
28
29#define GET_TYPEDEF_CLASSES
30#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
31
32void transform::TransformDialect::initializeTypes() {
33 addTypesChecked<
34#define GET_TYPEDEF_LIST
35#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
36 >();
37}
38
39//===----------------------------------------------------------------------===//
40// transform::AffineMapParamType
41//===----------------------------------------------------------------------===//
42
44transform::AffineMapParamType::checkPayload(Location loc,
45 ArrayRef<Attribute> payload) const {
46 for (Attribute attr : payload) {
47 if (!mlir::isa<AffineMapAttr>(attr)) {
48 return emitSilenceableError(loc)
49 << "expected affine map attribute, got " << attr;
50 }
51 }
53}
54
55//===----------------------------------------------------------------------===//
56// transform::AnyOpType
57//===----------------------------------------------------------------------===//
58
60transform::AnyOpType::checkPayload(Location loc,
61 ArrayRef<Operation *> payload) const {
63}
64
65//===----------------------------------------------------------------------===//
66// transform::AnyValueType
67//===----------------------------------------------------------------------===//
68
70transform::AnyValueType::checkPayload(Location loc,
71 ArrayRef<Value> payload) const {
73}
74
75//===----------------------------------------------------------------------===//
76// transform::OperationType
77//===----------------------------------------------------------------------===//
78
80transform::OperationType::checkPayload(Location loc,
81 ArrayRef<Operation *> payload) const {
82 OperationName opName(getOperationName(), loc.getContext());
83 for (Operation *op : payload) {
84 if (opName != op->getName()) {
86 emitSilenceableError(loc)
87 << "incompatible payload operation name expected " << opName << " vs "
88 << op->getName() << " -> " << *op;
89 diag.attachNote(op->getLoc()) << "payload operation";
90 return diag;
91 }
92 }
93
95}
96
97//===----------------------------------------------------------------------===//
98// transform::AnyParamType
99//===----------------------------------------------------------------------===//
100
102transform::AnyParamType::checkPayload(Location loc,
103 ArrayRef<Attribute> payload) const {
105}
106
107//===----------------------------------------------------------------------===//
108// transform::NormalizedOpType
109//===----------------------------------------------------------------------===//
110
112transform::NormalizedOpType::checkPayload(Location loc,
113 ArrayRef<Operation *> payload) const {
114 // Only check payloads that are not already guaranteeing the required forms.
115 SmallVector<Operation *> payloadsToCheck =
116 llvm::filter_to_vector(payload, [this](Operation *op) {
117 auto normalFormCheckedOp = dyn_cast<NormalFormCheckedOpInterface>(op);
118 if (!normalFormCheckedOp)
119 return true;
120
121 SmallVector<NormalFormAttrInterface> checkedNormalForms;
122 normalFormCheckedOp.getCheckedNormalForms(checkedNormalForms);
123 return !llvm::all_of(
124 this->getNormalForms(), [&](NormalFormAttrInterface form) {
125 return llvm::is_contained(checkedNormalForms, form);
126 });
127 });
128 return detail::checkNormalForms(getNormalForms(), payloadsToCheck);
129}
130
131LogicalResult transform::NormalizedOpType::verify(
134 return detail::verifyNormalFormList(emitError, normalForms);
135}
136
137//===----------------------------------------------------------------------===//
138// transform::ParamType
139//===----------------------------------------------------------------------===//
140
141LogicalResult
142transform::ParamType::verify(function_ref<InFlightDiagnostic()> emitError,
143 Type type) {
144 IntegerType intType = llvm::dyn_cast<IntegerType>(type);
145 if (!intType || intType.getWidth() > 64)
146 return emitError() << "only supports integer types with width <=64";
147 return success();
148}
149
151transform::ParamType::checkPayload(Location loc,
152 ArrayRef<Attribute> payload) const {
153 for (Attribute attr : payload) {
154 auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
155 if (!integerAttr) {
156 return emitSilenceableError(loc)
157 << "expected parameter to be an integer attribute, got " << attr;
158 }
159 if (integerAttr.getType() != getType()) {
160 return emitSilenceableError(loc)
161 << "expected the type of the parameter attribute ("
162 << integerAttr.getType() << ") to match the parameter type ("
163 << getType() << ")";
164 }
165 }
167}
168
169//===----------------------------------------------------------------------===//
170// transform::TypeParamType
171//===----------------------------------------------------------------------===//
172
174transform::TypeParamType::checkPayload(Location loc,
175 ArrayRef<Attribute> payload) const {
176 for (Attribute attr : payload) {
177 if (!mlir::isa<TypeAttr>(attr)) {
178 return emitSilenceableError(loc)
179 << "expected type attribute, got " << attr;
180 }
181 }
183}
return success()
static std::string diag(const llvm::Value &value)
static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value)
static LogicalResult generatedTypePrinter(Type def, AsmPrinter &printer)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition Attributes.h:25
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147