MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Matchers.h"
24#include "mlir/IR/SymbolTable.h"
27
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cstddef>
39#include <iterator>
40#include <optional>
41#include <variant>
42
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
47
48using namespace mlir;
49using namespace mlir::omp;
50
53 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
54}
55
58 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
59}
60
63 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
64}
65
66namespace {
67struct MemRefPointerLikeModel
68 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 MemRefType> {
70 Type getElementType(Type pointer) const {
71 return llvm::cast<MemRefType>(pointer).getElementType();
72 }
73};
74
75struct LLVMPointerPointerLikeModel
76 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
78 Type getElementType(Type pointer) const { return Type(); }
79};
80} // namespace
81
82/// Generate a name of a canonical loop nest of the format
83/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
84/// argument index of an operation that has multiple regions, if the operation
85/// has multiple regions.
86/// `_s<idx>` identifies the position of an operation within a region, where
87/// only operations that may potentially contain loops ("container operations"
88/// i.e. have region arguments) are counted. Again, it is omitted if there is
89/// only one such operation in a region. If there are canonical loops nested
90/// inside each other, also may also use the format `_d<num>` where <num> is the
91/// nesting depth of the loop.
92///
93/// The generated name is a best-effort to make canonical loop unique within an
94/// SSA namespace. This also means that regions with IsolatedFromAbove property
95/// do not consider any parents or siblings.
96static std::string generateLoopNestingName(StringRef prefix,
97 CanonicalLoopOp op) {
98 struct Component {
99 /// If true, this component describes a region operand of an operation (the
100 /// operand's owner) If false, this component describes an operation located
101 /// in a parent region
102 bool isRegionArgOfOp;
103 bool skip = false;
104 bool isUnique = false;
105
106 size_t idx;
107 Operation *op;
108 Region *parentRegion;
109 size_t loopDepth;
110
111 Operation *&getOwnerOp() {
112 assert(isRegionArgOfOp && "Must describe a region operand");
113 return op;
114 }
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp && "Must describe a region operand");
117 return idx;
118 }
119
120 Operation *&getContainerOp() {
121 assert(!isRegionArgOfOp && "Must describe a operation of a region");
122 return op;
123 }
124 size_t &getOpPos() {
125 assert(!isRegionArgOfOp && "Must describe a operation of a region");
126 return idx;
127 }
128 bool isLoopOp() const {
129 assert(!isRegionArgOfOp && "Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
131 }
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp && "Must describe a operation of a region");
134 return parentRegion;
135 }
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp && "Must describe a operation of a region");
138 return loopDepth;
139 }
140
141 void skipIf(bool v = true) { skip = skip || v; }
142 };
143
144 // List of ancestors, from inner to outer.
145 // Alternates between
146 // * region argument of an operation
147 // * operation within a region
148 SmallVector<Component> components;
149
150 // Gather a list of parent regions and operations, and the position within
151 // their parent
152 Operation *o = op.getOperation();
153 while (o) {
154 // Operation within a region
155 Region *r = o->getParentRegion();
156 if (!r)
157 break;
158
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
160 size_t idx = 0;
161 bool found = false;
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp = true;
164 for (Block *b : traversal) {
165 for (Operation &op : *b) {
166 if (&op == o && !found) {
167 sequentialIdx = idx;
168 found = true;
169 }
170 if (op.getNumRegions()) {
171 idx += 1;
172 if (idx > 1)
173 isOnlyContainerOp = false;
174 }
175 if (found && !isOnlyContainerOp)
176 break;
177 }
178 }
179
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp = false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
186
187 Operation *parent = r->getParentOp();
188
189 // Region argument of an operation
190 Component &regionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp = true;
192 regionArgOfOperation.isUnique = true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
195
196 // The IsolatedFromAbove trait of the parent operation implies that each
197 // individual region argument has its own separate namespace, so no
198 // ambiguity.
199 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
200 break;
201
202 // Component only needed if operation has multiple region operands. Region
203 // arguments may be optional, but we currently do not consider this.
204 if (parent->getRegions().size() > 1) {
205 auto getRegionIndex = [](Operation *o, Region *r) {
206 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
207 if (&region == r)
208 return idx;
209 }
210 llvm_unreachable("Region not child of its parent operation");
211 };
212 regionArgOfOperation.isUnique = false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
214 }
215
216 // next parent
217 o = parent;
218 }
219
220 // Determine whether a region-argument component is not needed
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
223
224 // Find runs of nested loops and determine each loop's depth in the loop nest
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
227 if (c.skip)
228 continue;
229
230 // non-skipped multi-argument operands interrupt the loop nest
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
233 continue;
234 }
235
236 // Multiple loops in a region means each of them is the outermost loop of a
237 // new loop nest
238 if (!c.isUnique)
239 numSurroundingLoops = 0;
240
241 c.getLoopDepth() = numSurroundingLoops;
242
243 // Next loop is surrounded by one more loop
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
246 }
247
248 // In loop nests, skip all but the innermost loop that contains the depth
249 // number
250 bool isLoopNest = false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
253 continue;
254
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
256 // Innermost loop of a loop nest of at least two loops
257 isLoopNest = true;
258 } else if (isLoopNest) {
259 // Non-innermost loop of a loop nest
260 c.skipIf(c.isUnique);
261
262 // If there is no surrounding loop left, this must have been the outermost
263 // loop; leave loop-nest mode for the next iteration
264 if (c.getLoopDepth() == 0)
265 isLoopNest = false;
266 }
267 }
268
269 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
270 // we mark them as skipped only after computing loop nests)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
274
275 // Components can be skipped if they are already disambiguated by their parent
276 // (or does not have a parent)
277 bool newRegion = true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
280
281 // non-skipped components disambiguate unique children
282 if (!c.skip)
283 newRegion = true;
284
285 // ...except canonical loops that need a suffix for each nest
286 if (!c.isRegionArgOfOp && c.getContainerOp())
287 newRegion = false;
288 }
289
290 // Compile the nesting name string
291 SmallString<64> Name{prefix};
292 llvm::raw_svector_ostream NameOS(Name);
293 for (auto &c : llvm::reverse(components)) {
294 if (c.skip)
295 continue;
296
297 if (c.isRegionArgOfOp)
298 NameOS << "_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS << "_d" << c.getLoopDepth();
301 else
302 NameOS << "_s" << c.getOpPos();
303 }
304
305 return NameOS.str().str();
306}
307
308void OpenMPDialect::initialize() {
309 addOperations<
310#define GET_OP_LIST
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312 >();
313 addAttributes<
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316 >();
317 addTypes<
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 >();
321
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
326 *getContext());
327
328 // Attach default offload module interface to module op to access
329 // offload functionality through
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
331 *getContext());
332
333 // Attach default declare target interfaces to operations which can be marked
334 // as declare target (Global Operations and Functions/Subroutines in dialects
335 // that Fortran (or other languages that lower to MLIR) translates too
336 mlir::LLVM::GlobalOp::attachInterface<
338 *getContext());
339 mlir::LLVM::LLVMFuncOp::attachInterface<
341 *getContext());
342 mlir::func::FuncOp::attachInterface<
344}
345
346//===----------------------------------------------------------------------===//
347// Parser and printer for Allocate Clause
348//===----------------------------------------------------------------------===//
349
350/// Parse an allocate clause with allocators and a list of operands with types.
351///
352/// allocate-operand-list :: = allocate-operand |
353/// allocator-operand `,` allocate-operand-list
354/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
355/// ssa-id-and-type ::= ssa-id `:` type
356static ParseResult parseAllocateAndAllocator(
357 OpAsmParser &parser,
359 SmallVectorImpl<Type> &allocateTypes,
361 SmallVectorImpl<Type> &allocatorTypes) {
362
363 return parser.parseCommaSeparatedList([&]() {
365 Type type;
366 if (parser.parseOperand(operand) || parser.parseColonType(type))
367 return failure();
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
370 if (parser.parseArrow())
371 return failure();
372 if (parser.parseOperand(operand) || parser.parseColonType(type))
373 return failure();
374
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
377 return success();
378 });
379}
380
381/// Print allocate clause
383 OperandRange allocateVars,
384 TypeRange allocateTypes,
385 OperandRange allocatorVars,
386 TypeRange allocatorTypes) {
387 for (unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
389 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
390 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
391 }
392}
393
394//===----------------------------------------------------------------------===//
395// Parser and printer for a clause attribute (StringEnumAttr)
396//===----------------------------------------------------------------------===//
397
398template <typename ClauseAttr>
399static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
400 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
401 StringRef enumStr;
402 SMLoc loc = parser.getCurrentLocation();
403 if (parser.parseKeyword(&enumStr))
404 return failure();
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.getContext(), *enumValue);
407 return success();
408 }
409 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
410}
411
412template <typename ClauseAttr>
413static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
414 p << stringifyEnum(attr.getValue());
415}
416
417//===----------------------------------------------------------------------===//
418// Parser and printer for Linear Clause
419//===----------------------------------------------------------------------===//
420
421/// linear ::= `linear` `(` linear-list `)`
422/// linear-list := linear-val | linear-val linear-list
423/// linear-val := ssa-id-and-type `=` ssa-id-and-type
424/// | `val` `(` ssa-id-and-type `=` ssa-id-and-type `)`
425/// | `ref` `(` ssa-id-and-type `=` ssa-id-and-type `)`
426/// | `uval` `(` ssa-id-and-type `=` ssa-id-and-type `)`
427static ParseResult parseLinearClause(
428 OpAsmParser &parser,
430 SmallVectorImpl<Type> &linearTypes,
432 SmallVectorImpl<Type> &linearStepTypes, ArrayAttr &linearModifiers) {
433 SmallVector<Attribute> modifiers;
434 auto result = parser.parseCommaSeparatedList([&]() {
436 Type type, stepType;
438
439 std::optional<omp::LinearModifier> linearModifier;
440 if (succeeded(parser.parseOptionalKeyword("val"))) {
441 linearModifier = omp::LinearModifier::val;
442 } else if (succeeded(parser.parseOptionalKeyword("ref"))) {
443 linearModifier = omp::LinearModifier::ref;
444 } else if (succeeded(parser.parseOptionalKeyword("uval"))) {
445 linearModifier = omp::LinearModifier::uval;
446 }
447
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.parseLParen())
450 return failure();
451
452 if (parser.parseOperand(var) || parser.parseColonType(type) ||
453 parser.parseEqual() || parser.parseOperand(stepVar) ||
454 parser.parseColonType(stepType))
455 return failure();
456
457 if (hasLinearModifierParens && parser.parseRParen())
458 return failure();
459
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
465 modifiers.push_back(
466 omp::LinearModifierAttr::get(parser.getContext(), *linearModifier));
467 } else {
468 modifiers.push_back(UnitAttr::get(parser.getContext()));
469 }
470 return success();
471 });
472 if (failed(result))
473 return failure();
474 linearModifiers = ArrayAttr::get(parser.getContext(), modifiers);
475 return success();
476}
477
478/// Print Linear Clause
480 ValueRange linearVars, TypeRange linearTypes,
481 ValueRange linearStepVars, TypeRange stepVarTypes,
482 ArrayAttr linearModifiers) {
483 size_t linearVarsSize = linearVars.size();
484 for (unsigned i = 0; i < linearVarsSize; ++i) {
485 if (i != 0)
486 p << ", ";
487 // Print modifier keyword wrapper if present.
488 Attribute modAttr = linearModifiers ? linearModifiers[i] : nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) : nullptr;
490 if (mod) {
491 p << omp::stringifyLinearModifier(mod.getValue()) << "(";
492 }
493 p << linearVars[i] << " : " << linearTypes[i];
494 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
495 if (mod)
496 p << ")";
497 }
498}
499
500//===----------------------------------------------------------------------===//
501// Verifier for Linear modifier
502//===----------------------------------------------------------------------===//
503
504/// OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or
505/// uval only on a declare simd directive."
506/// Also verifies that modifier count matches variable count.
507static LogicalResult
508verifyLinearModifiers(Operation *op, std::optional<ArrayAttr> linearModifiers,
509 OperandRange linearVars, bool isDeclareSimd = false) {
510 if (!linearModifiers)
511 return success();
512 if (linearModifiers->size() != linearVars.size())
513 return op->emitOpError()
514 << "expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (Attribute attr : *linearModifiers) {
517 if (!attr)
518 continue;
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
520 if (!modAttr)
521 continue;
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
524 return op->emitOpError()
525 << "linear modifier '" << omp::stringifyLinearModifier(mod)
526 << "' may only be specified on a declare simd directive";
527 }
528 }
529 return success();
530}
531
532//===----------------------------------------------------------------------===//
533// Verifier for Nontemporal Clause
534//===----------------------------------------------------------------------===//
535
536static LogicalResult verifyNontemporalClause(Operation *op,
537 OperandRange nontemporalVars) {
538
539 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
540 DenseSet<Value> nontemporalItems;
541 for (const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->emitOpError() << "nontemporal variable used more than once";
544
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Parser, verifier and printer for Aligned Clause
550//===----------------------------------------------------------------------===//
551static LogicalResult verifyAlignedClause(Operation *op,
552 std::optional<ArrayAttr> alignments,
553 OperandRange alignedVars) {
554 // Check if number of alignment values equals to number of aligned variables
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
557 return op->emitOpError()
558 << "expected as many alignment values as aligned variables";
559 } else {
560 if (alignments)
561 return op->emitOpError() << "unexpected alignment values attribute";
562 return success();
563 }
564
565 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
566 DenseSet<Value> alignedItems;
567 for (auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->emitOpError() << "aligned variable used more than once";
570
571 if (!alignments)
572 return success();
573
574 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
575 for (unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->emitOpError() << "alignment should be greater than 0";
579 } else {
580 return op->emitOpError() << "expected integer alignment";
581 }
582 }
583
584 return success();
585}
586
587/// aligned ::= `aligned` `(` aligned-list `)`
588/// aligned-list := aligned-val | aligned-val aligned-list
589/// aligned-val := ssa-id-and-type `->` alignment
590static ParseResult
593 SmallVectorImpl<Type> &alignedTypes,
594 ArrayAttr &alignmentsAttr) {
595 SmallVector<Attribute> alignmentVec;
596 if (failed(parser.parseCommaSeparatedList([&]() {
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
601 return failure();
602 }
603 return success();
604 })))
605 return failure();
606 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
608 return success();
609}
610
611/// Print Aligned Clause
613 ValueRange alignedVars, TypeRange alignedTypes,
614 std::optional<ArrayAttr> alignments) {
615 for (unsigned i = 0; i < alignedVars.size(); ++i) {
616 if (i != 0)
617 p << ", ";
618 p << alignedVars[i] << " : " << alignedVars[i].getType();
619 p << " -> " << (*alignments)[i];
620 }
621}
622
623//===----------------------------------------------------------------------===//
624// Parser, printer and verifier for Schedule Clause
625//===----------------------------------------------------------------------===//
626
627static ParseResult
629 SmallVectorImpl<SmallString<12>> &modifiers) {
630 if (modifiers.size() > 2)
631 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
632 for (const auto &mod : modifiers) {
633 // Translate the string. If it has no value, then it was not a valid
634 // modifier!
635 auto symbol = symbolizeScheduleModifier(mod);
636 if (!symbol)
637 return parser.emitError(parser.getNameLoc())
638 << " unknown modifier type: " << mod;
639 }
640
641 // If we have one modifier that is "simd", then stick a "none" modiifer in
642 // index 0.
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
647 }
648 } else if (modifiers.size() == 2) {
649 // If there are two modifier:
650 // First modifier should not be simd, second one should be simd
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
653 return parser.emitError(parser.getNameLoc())
654 << " incorrect modifier order";
655 }
656 return success();
657}
658
659/// schedule ::= `schedule` `(` sched-list `)`
660/// sched-list ::= sched-val | sched-val sched-list |
661/// sched-val `,` sched-modifier
662/// sched-val ::= sched-with-chunk | sched-wo-chunk
663/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
664/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
665/// sched-wo-chunk ::= `auto` | `runtime`
666/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
667/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
668static ParseResult
669parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
672 Type &chunkType) {
673 StringRef keyword;
674 if (parser.parseKeyword(&keyword))
675 return failure();
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
678 if (!schedule)
679 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
680
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
682 switch (*schedule) {
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
686 if (succeeded(parser.parseOptionalEqual())) {
687 chunkSize = OpAsmParser::UnresolvedOperand{};
688 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
689 return failure();
690 } else {
691 chunkSize = std::nullopt;
692 }
693 break;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
698 }
699
700 // If there is a comma, we have one or more modifiers..
702 while (succeeded(parser.parseOptionalComma())) {
703 StringRef mod;
704 if (parser.parseKeyword(&mod))
705 return failure();
706 modifiers.push_back(mod);
707 }
708
709 if (verifyScheduleModifiers(parser, modifiers))
710 return failure();
711
712 if (!modifiers.empty()) {
713 SMLoc loc = parser.getCurrentLocation();
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
717 } else {
718 return parser.emitError(loc, "invalid schedule modifier");
719 }
720 // Only SIMD attribute is allowed here!
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
723 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
724 }
725 }
726
727 return success();
728}
729
730/// Print schedule clause
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd, Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
737 if (scheduleChunk)
738 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
739 if (scheduleMod)
740 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
741 if (scheduleSimd)
742 p << ", simd";
743}
744
745//===----------------------------------------------------------------------===//
746// Parser and printer for Order Clause
747//===----------------------------------------------------------------------===//
748
749// order ::= `order` `(` [order-modifier ':'] concurrent `)`
750// order-modifier ::= reproducible | unconstrained
751static ParseResult parseOrderClause(OpAsmParser &parser,
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
754 StringRef enumStr;
755 SMLoc loc = parser.getCurrentLocation();
756 if (parser.parseKeyword(&enumStr))
757 return failure();
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
761 if (parser.parseOptionalColon())
762 return failure();
763 loc = parser.getCurrentLocation();
764 if (parser.parseKeyword(&enumStr))
765 return failure();
766 }
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
770 return success();
771 }
772 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
773}
774
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
778 if (orderMod)
779 p << stringifyOrderModifier(orderMod.getValue()) << ":";
780 if (order)
781 p << stringifyClauseOrderKind(order.getValue());
782}
783
784template <typename ClauseTypeAttr, typename ClauseType>
785static ParseResult
786parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
788 Type &operandType,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
791 StringRef enumStr;
792 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
795 if (parser.parseComma())
796 return failure();
797 } else {
798 return parser.emitError(parser.getCurrentLocation())
799 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
800 ;
801 }
802 }
803
805 if (succeeded(parser.parseOperand(var))) {
806 operand = var;
807 } else {
808 return parser.emitError(parser.getCurrentLocation())
809 << "expected " << clauseName << " operand";
810 }
811
812 if (operand.has_value()) {
813 if (parser.parseColonType(operandType))
814 return failure();
815 }
816
817 return success();
818}
819
820template <typename ClauseTypeAttr, typename ClauseType>
821static void
823 ClauseTypeAttr prescriptiveness, Value operand,
824 mlir::Type operandType,
825 StringRef (*stringifyClauseType)(ClauseType)) {
826
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
829
830 if (operand)
831 p << operand << ": " << operandType;
832}
833
834//===----------------------------------------------------------------------===//
835// Parser and printer for grainsize Clause
836//===----------------------------------------------------------------------===//
837
838// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
839static ParseResult
840parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType, "grainsize");
846}
847
849 ClauseGrainsizeTypeAttr grainsizeMod,
850 Value grainsize, mlir::Type grainsizeType) {
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
854}
855
856//===----------------------------------------------------------------------===//
857// Parser and printer for num_tasks Clause
858//===----------------------------------------------------------------------===//
859
860// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
861static ParseResult
862parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
867 "num_tasks");
868}
869
871 ClauseNumTasksTypeAttr numTasksMod,
872 Value numTasks, mlir::Type numTasksType) {
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
875}
876
877//===----------------------------------------------------------------------===//
878// Parser and printer for Heap Alloc Clause
879//===----------------------------------------------------------------------===//
880
881/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
882static ParseResult parseHeapAllocClause(
883 OpAsmParser &parser, TypeAttr &inTypeAttr,
885 SmallVectorImpl<Type> &typeparamsTypes,
887 SmallVectorImpl<Type> &shapeTypes) {
888 mlir::Type inType;
889 if (parser.parseType(inType))
890 return mlir::failure();
891 inTypeAttr = TypeAttr::get(inType);
892
893 if (!parser.parseOptionalLParen()) {
894 // parse the LEN params of the derived type. (<params> : <types>)
895 if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
896 parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
897 return failure();
898 }
899
900 if (!parser.parseOptionalComma()) {
901 // parse size to scale by, vector of n dimensions of type index
903 return failure();
904
905 // TODO: This overrides the actual types of the operands, which might cause
906 // issues when they don't match. At the moment this is done in place of
907 // making the corresponding operand type `Variadic<Index>` because index
908 // types are lowered to I64 prior to LLVM IR translation.
909 shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
910 }
911
912 return success();
913}
914
916 TypeAttr inType, ValueRange typeparams,
917 TypeRange typeparamsTypes, ValueRange shape,
918 TypeRange shapeTypes) {
919 p << inType;
920 if (!typeparams.empty()) {
921 p << '(' << typeparams << " : " << typeparamsTypes << ')';
922 }
923 for (auto sh : shape) {
924 p << ", ";
925 p.printOperand(sh);
926 }
927}
928
929//===----------------------------------------------------------------------===//
930// Parser, printer and verify for dyn_groupprivate Clause
931//===----------------------------------------------------------------------===//
932
933static LogicalResult
934verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup,
935 FallbackModifierAttr fallback,
936 Value dynGroupprivateSize) {
937 if (!dynGroupprivateSize && (accessGroup || fallback))
938 return op->emitOpError("dyn_groupprivate modifiers require a size operand");
939
940 return success();
941}
942
943static ParseResult parseDynGroupprivateClause(
944 OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
945 FallbackModifierAttr &fallbackAttr,
946 std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
947 Type &sizeType) {
948
949 bool parsedAccessGroup = false;
950 bool parsedFallback = false;
951 bool parsedSize = false;
952
953 return parser.parseCommaSeparatedList([&]() -> ParseResult {
954 // Parse AccessGroupModifier.
955 if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
956 if (parsedAccessGroup)
957 return parser.emitError(parser.getCurrentLocation(),
958 "duplicate access group modifier");
959 accessGroupAttr = AccessGroupModifierAttr::get(
960 parser.getContext(), AccessGroupModifier::cgroup);
961 parsedAccessGroup = true;
962 return success();
963 }
964 // Parse FallbackModifier.
965 if (succeeded(parser.parseOptionalKeyword("fallback"))) {
966 if (parsedFallback)
967 return parser.emitError(parser.getCurrentLocation(),
968 "duplicate fallback modifier");
969 if (parser.parseLParen())
970 return parser.emitError(parser.getCurrentLocation(),
971 "expected '(' after 'fallback'");
972 llvm::StringRef fbKind;
973 if (parser.parseKeyword(&fbKind))
974 return parser.emitError(
975 parser.getCurrentLocation(),
976 "expected fallback modifier (abort/null/default_mem)");
977 std::optional<FallbackModifier> fbEnum;
978 if (fbKind == "abort")
979 fbEnum = FallbackModifier::abort;
980 else if (fbKind == "null")
981 fbEnum = FallbackModifier::null;
982 else if (fbKind == "default_mem")
983 fbEnum = FallbackModifier::default_mem;
984 else
985 return parser.emitError(parser.getCurrentLocation(),
986 "invalid fallback modifier '" + fbKind + "'");
987 fallbackAttr = FallbackModifierAttr::get(parser.getContext(), *fbEnum);
988 if (parser.parseRParen())
989 return parser.emitError(parser.getCurrentLocation(),
990 "expected ')' after fallback modifier");
991 parsedFallback = true;
992 return success();
993 }
994 // Parse size operand.
996 if (succeeded(parser.parseOperand(operand))) {
997 if (parsedSize)
998 return parser.emitError(parser.getCurrentLocation(),
999 "duplicate size operand");
1000 dynGroupprivateSize = operand;
1001 parsedSize = true;
1002 if (failed(parser.parseColon()) || failed(parser.parseType(sizeType)))
1003 return parser.emitError(parser.getCurrentLocation(),
1004 "expected ':' and type after size operand");
1005 return success();
1006 }
1007 return parser.emitError(parser.getCurrentLocation(),
1008 "expected dyn_groupprivate_size operand");
1009 });
1010}
1011
1013 AccessGroupModifierAttr modifierFirst,
1014 FallbackModifierAttr modifierSecond,
1015 Value dynGroupprivateSize,
1016 Type sizeType) {
1017
1018 bool needsComma = false;
1019
1020 if (modifierFirst) {
1021 printer << modifierFirst.getValue();
1022 needsComma = true;
1023 }
1024
1025 if (modifierSecond) {
1026 if (needsComma)
1027 printer << ", ";
1028 printer << "fallback(";
1029 printer << modifierSecond.getValue();
1030 printer << ")";
1031 needsComma = true;
1032 }
1033
1034 if (dynGroupprivateSize) {
1035 if (needsComma)
1036 printer << ", ";
1037 printer << dynGroupprivateSize << " : " << sizeType;
1038 }
1039}
1040
1041//===----------------------------------------------------------------------===//
1042// Parsers for operations including clauses that define entry block arguments.
1043//===----------------------------------------------------------------------===//
1044
1045namespace {
1046struct MapParseArgs {
1047 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1048 SmallVectorImpl<Type> &types;
1049 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1050 SmallVectorImpl<Type> &types)
1051 : vars(vars), types(types) {}
1052};
1053struct PrivateParseArgs {
1054 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1055 llvm::SmallVectorImpl<Type> &types;
1056 ArrayAttr &syms;
1057 UnitAttr &needsBarrier;
1058 DenseI64ArrayAttr *mapIndices;
1059 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1060 SmallVectorImpl<Type> &types, ArrayAttr &syms,
1061 UnitAttr &needsBarrier,
1062 DenseI64ArrayAttr *mapIndices = nullptr)
1063 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1064 mapIndices(mapIndices) {}
1065};
1066
1067struct ReductionParseArgs {
1068 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1069 SmallVectorImpl<Type> &types;
1070 DenseBoolArrayAttr &byref;
1071 ArrayAttr &syms;
1072 ReductionModifierAttr *modifier;
1073 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1074 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
1075 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
1076 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1077};
1078
1079struct AllRegionParseArgs {
1080 std::optional<MapParseArgs> hasDeviceAddrArgs;
1081 std::optional<MapParseArgs> hostEvalArgs;
1082 std::optional<ReductionParseArgs> inReductionArgs;
1083 std::optional<MapParseArgs> mapArgs;
1084 std::optional<PrivateParseArgs> privateArgs;
1085 std::optional<ReductionParseArgs> reductionArgs;
1086 std::optional<ReductionParseArgs> taskReductionArgs;
1087 std::optional<MapParseArgs> useDeviceAddrArgs;
1088 std::optional<MapParseArgs> useDevicePtrArgs;
1089};
1090} // namespace
1091
1092static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
1093 return "private_barrier";
1094}
1095
1096static ParseResult parseClauseWithRegionArgs(
1097 OpAsmParser &parser,
1099 SmallVectorImpl<Type> &types,
1100 SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
1101 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
1102 DenseBoolArrayAttr *byref = nullptr,
1103 ReductionModifierAttr *modifier = nullptr,
1104 UnitAttr *needsBarrier = nullptr) {
1106 SmallVector<int64_t> mapIndicesVec;
1107 SmallVector<bool> isByRefVec;
1108 unsigned regionArgOffset = regionPrivateArgs.size();
1109
1110 if (parser.parseLParen())
1111 return failure();
1112
1113 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
1114 StringRef enumStr;
1115 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
1116 parser.parseComma())
1117 return failure();
1118 std::optional<ReductionModifier> enumValue =
1119 symbolizeReductionModifier(enumStr);
1120 if (!enumValue.has_value())
1121 return failure();
1122 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
1123 if (!*modifier)
1124 return failure();
1125 }
1126
1127 if (parser.parseCommaSeparatedList([&]() {
1128 if (byref)
1129 isByRefVec.push_back(
1130 parser.parseOptionalKeyword("byref").succeeded());
1131
1132 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
1133 return failure();
1134
1135 if (parser.parseOperand(operands.emplace_back()) ||
1136 parser.parseArrow() ||
1137 parser.parseArgument(regionPrivateArgs.emplace_back()))
1138 return failure();
1139
1140 if (mapIndices) {
1141 if (parser.parseOptionalLSquare().succeeded()) {
1142 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
1143 parser.parseInteger(mapIndicesVec.emplace_back()) ||
1144 parser.parseRSquare())
1145 return failure();
1146 } else {
1147 mapIndicesVec.push_back(-1);
1148 }
1149 }
1150
1151 return success();
1152 }))
1153 return failure();
1154
1155 if (parser.parseColon())
1156 return failure();
1157
1158 if (parser.parseCommaSeparatedList([&]() {
1159 if (parser.parseType(types.emplace_back()))
1160 return failure();
1161
1162 return success();
1163 }))
1164 return failure();
1165
1166 if (operands.size() != types.size())
1167 return failure();
1168
1169 if (parser.parseRParen())
1170 return failure();
1171
1172 if (needsBarrier) {
1174 .succeeded())
1175 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
1176 }
1177
1178 auto *argsBegin = regionPrivateArgs.begin();
1179 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
1180 argsBegin + regionArgOffset + types.size());
1181 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1182 prv.type = type;
1183 }
1184
1185 if (symbols) {
1186 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
1187 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
1188 }
1189
1190 if (!mapIndicesVec.empty())
1191 *mapIndices =
1192 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
1193
1194 if (byref)
1195 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
1196
1197 return success();
1198}
1199
1200static ParseResult parseBlockArgClause(
1201 OpAsmParser &parser,
1203 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1204 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1205 if (!mapArgs)
1206 return failure();
1207
1208 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
1209 entryBlockArgs)))
1210 return failure();
1211 }
1212 return success();
1213}
1214
1215static ParseResult parseBlockArgClause(
1216 OpAsmParser &parser,
1218 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1219 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1220 if (!privateArgs)
1221 return failure();
1222
1223 if (failed(parseClauseWithRegionArgs(
1224 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1225 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1226 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
1227 return failure();
1228 }
1229 return success();
1230}
1231
1232static ParseResult parseBlockArgClause(
1233 OpAsmParser &parser,
1235 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1236 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1237 if (!reductionArgs)
1238 return failure();
1239 if (failed(parseClauseWithRegionArgs(
1240 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1241 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1242 reductionArgs->modifier)))
1243 return failure();
1244 }
1245 return success();
1246}
1247
1248static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1249 AllRegionParseArgs args) {
1251
1252 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1253 args.hasDeviceAddrArgs)))
1254 return parser.emitError(parser.getCurrentLocation())
1255 << "invalid `has_device_addr` format";
1256
1257 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1258 args.hostEvalArgs)))
1259 return parser.emitError(parser.getCurrentLocation())
1260 << "invalid `host_eval` format";
1261
1262 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1263 args.inReductionArgs)))
1264 return parser.emitError(parser.getCurrentLocation())
1265 << "invalid `in_reduction` format";
1266
1267 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1268 args.mapArgs)))
1269 return parser.emitError(parser.getCurrentLocation())
1270 << "invalid `map_entries` format";
1271
1272 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1273 args.privateArgs)))
1274 return parser.emitError(parser.getCurrentLocation())
1275 << "invalid `private` format";
1276
1277 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1278 args.reductionArgs)))
1279 return parser.emitError(parser.getCurrentLocation())
1280 << "invalid `reduction` format";
1281
1282 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1283 args.taskReductionArgs)))
1284 return parser.emitError(parser.getCurrentLocation())
1285 << "invalid `task_reduction` format";
1286
1287 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1288 args.useDeviceAddrArgs)))
1289 return parser.emitError(parser.getCurrentLocation())
1290 << "invalid `use_device_addr` format";
1291
1292 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1293 args.useDevicePtrArgs)))
1294 return parser.emitError(parser.getCurrentLocation())
1295 << "invalid `use_device_addr` format";
1296
1297 return parser.parseRegion(region, entryBlockArgs);
1298}
1299
1300// These parseXyz functions correspond to the custom<Xyz> definitions
1301// in the .td file(s).
1302static ParseResult parseTargetOpRegion(
1303 OpAsmParser &parser, Region &region,
1305 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1307 SmallVectorImpl<Type> &hostEvalTypes,
1309 SmallVectorImpl<Type> &inReductionTypes,
1310 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1312 SmallVectorImpl<Type> &mapTypes,
1314 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1315 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1316 AllRegionParseArgs args;
1317 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1318 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1319 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1320 inReductionByref, inReductionSyms);
1321 args.mapArgs.emplace(mapVars, mapTypes);
1322 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1323 privateNeedsBarrier, &privateMaps);
1324 return parseBlockArgRegion(parser, region, args);
1325}
1326
1328 OpAsmParser &parser, Region &region,
1330 SmallVectorImpl<Type> &inReductionTypes,
1331 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1333 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1334 UnitAttr &privateNeedsBarrier) {
1335 AllRegionParseArgs args;
1336 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1337 inReductionByref, inReductionSyms);
1338 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1339 privateNeedsBarrier);
1340 return parseBlockArgRegion(parser, region, args);
1341}
1342
1344 OpAsmParser &parser, Region &region,
1346 SmallVectorImpl<Type> &inReductionTypes,
1347 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1349 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1350 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1352 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1353 ArrayAttr &reductionSyms) {
1354 AllRegionParseArgs args;
1355 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1356 inReductionByref, inReductionSyms);
1357 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1358 privateNeedsBarrier);
1359 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1360 reductionSyms, &reductionMod);
1361 return parseBlockArgRegion(parser, region, args);
1362}
1363
1364static ParseResult parsePrivateRegion(
1365 OpAsmParser &parser, Region &region,
1367 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1368 UnitAttr &privateNeedsBarrier) {
1369 AllRegionParseArgs args;
1370 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1371 privateNeedsBarrier);
1372 return parseBlockArgRegion(parser, region, args);
1373}
1374
1376 OpAsmParser &parser, Region &region,
1378 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1379 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1381 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1382 ArrayAttr &reductionSyms) {
1383 AllRegionParseArgs args;
1384 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1385 privateNeedsBarrier);
1386 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1387 reductionSyms, &reductionMod);
1388 return parseBlockArgRegion(parser, region, args);
1389}
1390
1391static ParseResult parseTaskReductionRegion(
1392 OpAsmParser &parser, Region &region,
1394 SmallVectorImpl<Type> &taskReductionTypes,
1395 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1396 AllRegionParseArgs args;
1397 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1398 taskReductionByref, taskReductionSyms);
1399 return parseBlockArgRegion(parser, region, args);
1400}
1401
1403 OpAsmParser &parser, Region &region,
1405 SmallVectorImpl<Type> &useDeviceAddrTypes,
1407 SmallVectorImpl<Type> &useDevicePtrTypes) {
1408 AllRegionParseArgs args;
1409 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1410 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1411 return parseBlockArgRegion(parser, region, args);
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// Printers for operations including clauses that define entry block arguments.
1416//===----------------------------------------------------------------------===//
1417
1418namespace {
1419struct MapPrintArgs {
1420 ValueRange vars;
1421 TypeRange types;
1422 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1423};
1424struct PrivatePrintArgs {
1425 ValueRange vars;
1426 TypeRange types;
1427 ArrayAttr syms;
1428 UnitAttr needsBarrier;
1429 DenseI64ArrayAttr mapIndices;
1430 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1431 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1432 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1433 mapIndices(mapIndices) {}
1434};
1435struct ReductionPrintArgs {
1436 ValueRange vars;
1437 TypeRange types;
1438 DenseBoolArrayAttr byref;
1439 ArrayAttr syms;
1440 ReductionModifierAttr modifier;
1441 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1442 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1443 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1444};
1445struct AllRegionPrintArgs {
1446 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1447 std::optional<MapPrintArgs> hostEvalArgs;
1448 std::optional<ReductionPrintArgs> inReductionArgs;
1449 std::optional<MapPrintArgs> mapArgs;
1450 std::optional<PrivatePrintArgs> privateArgs;
1451 std::optional<ReductionPrintArgs> reductionArgs;
1452 std::optional<ReductionPrintArgs> taskReductionArgs;
1453 std::optional<MapPrintArgs> useDeviceAddrArgs;
1454 std::optional<MapPrintArgs> useDevicePtrArgs;
1455};
1456} // namespace
1457
1459 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1460 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1461 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1462 DenseBoolArrayAttr byref = nullptr,
1463 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1464 if (argsSubrange.empty())
1465 return;
1466
1467 p << clauseName << "(";
1468
1469 if (modifier)
1470 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1471
1472 if (!symbols) {
1473 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1474 symbols = ArrayAttr::get(ctx, values);
1475 }
1476
1477 if (!mapIndices) {
1478 llvm::SmallVector<int64_t> values(operands.size(), -1);
1479 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1480 }
1481
1482 if (!byref) {
1483 mlir::SmallVector<bool> values(operands.size(), false);
1484 byref = DenseBoolArrayAttr::get(ctx, values);
1485 }
1486
1487 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1488 mapIndices.asArrayRef(),
1489 byref.asArrayRef()),
1490 p, [&p](auto t) {
1491 auto [op, arg, sym, map, isByRef] = t;
1492 if (isByRef)
1493 p << "byref ";
1494 if (sym)
1495 p << sym << " ";
1496
1497 p << op << " -> " << arg;
1498
1499 if (map != -1)
1500 p << " [map_idx=" << map << "]";
1501 });
1502 p << " : ";
1503 llvm::interleaveComma(types, p);
1504 p << ") ";
1505
1506 if (needsBarrier)
1507 p << getPrivateNeedsBarrierSpelling() << " ";
1508}
1509
1511 StringRef clauseName, ValueRange argsSubrange,
1512 std::optional<MapPrintArgs> mapArgs) {
1513 if (mapArgs)
1514 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1515 mapArgs->types);
1516}
1517
1519 StringRef clauseName, ValueRange argsSubrange,
1520 std::optional<PrivatePrintArgs> privateArgs) {
1521 if (privateArgs)
1523 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1524 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1525 /*modifier=*/nullptr, privateArgs->needsBarrier);
1526}
1527
1528static void
1529printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1530 ValueRange argsSubrange,
1531 std::optional<ReductionPrintArgs> reductionArgs) {
1532 if (reductionArgs)
1533 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1534 reductionArgs->vars, reductionArgs->types,
1535 reductionArgs->syms, /*mapIndices=*/nullptr,
1536 reductionArgs->byref, reductionArgs->modifier);
1537}
1538
1540 const AllRegionPrintArgs &args) {
1541 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1542 MLIRContext *ctx = op->getContext();
1543
1544 printBlockArgClause(p, ctx, "has_device_addr",
1545 iface.getHasDeviceAddrBlockArgs(),
1546 args.hasDeviceAddrArgs);
1547 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1548 args.hostEvalArgs);
1549 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1550 args.inReductionArgs);
1551 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1552 args.mapArgs);
1553 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1554 args.privateArgs);
1555 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1556 args.reductionArgs);
1557 printBlockArgClause(p, ctx, "task_reduction",
1558 iface.getTaskReductionBlockArgs(),
1559 args.taskReductionArgs);
1560 printBlockArgClause(p, ctx, "use_device_addr",
1561 iface.getUseDeviceAddrBlockArgs(),
1562 args.useDeviceAddrArgs);
1563 printBlockArgClause(p, ctx, "use_device_ptr",
1564 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1565
1566 p.printRegion(region, /*printEntryBlockArgs=*/false);
1567}
1568
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1572 OpAsmPrinter &p, Operation *op, Region &region,
1573 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1574 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1575 ValueRange inReductionVars, TypeRange inReductionTypes,
1576 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1577 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1578 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1579 DenseI64ArrayAttr privateMaps) {
1580 AllRegionPrintArgs args;
1581 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1582 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1583 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1584 inReductionByref, inReductionSyms);
1585 args.mapArgs.emplace(mapVars, mapTypes);
1586 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1587 privateNeedsBarrier, privateMaps);
1588 printBlockArgRegion(p, op, region, args);
1589}
1590
1592 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1593 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1594 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1595 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1596 AllRegionPrintArgs args;
1597 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1598 inReductionByref, inReductionSyms);
1599 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1600 privateNeedsBarrier,
1601 /*mapIndices=*/nullptr);
1602 printBlockArgRegion(p, op, region, args);
1603}
1604
1606 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1607 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1608 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1609 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1610 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1611 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1612 ArrayAttr reductionSyms) {
1613 AllRegionPrintArgs args;
1614 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1615 inReductionByref, inReductionSyms);
1616 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1617 privateNeedsBarrier,
1618 /*mapIndices=*/nullptr);
1619 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1620 reductionSyms, reductionMod);
1621 printBlockArgRegion(p, op, region, args);
1622}
1623
1625 ValueRange privateVars, TypeRange privateTypes,
1626 ArrayAttr privateSyms,
1627 UnitAttr privateNeedsBarrier) {
1628 AllRegionPrintArgs args;
1629 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1630 privateNeedsBarrier,
1631 /*mapIndices=*/nullptr);
1632 printBlockArgRegion(p, op, region, args);
1633}
1634
1636 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1637 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1638 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1639 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1640 ArrayAttr reductionSyms) {
1641 AllRegionPrintArgs args;
1642 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1643 privateNeedsBarrier,
1644 /*mapIndices=*/nullptr);
1645 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1646 reductionSyms, reductionMod);
1647 printBlockArgRegion(p, op, region, args);
1648}
1649
1651 Region &region,
1652 ValueRange taskReductionVars,
1653 TypeRange taskReductionTypes,
1654 DenseBoolArrayAttr taskReductionByref,
1655 ArrayAttr taskReductionSyms) {
1656 AllRegionPrintArgs args;
1657 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1658 taskReductionByref, taskReductionSyms);
1659 printBlockArgRegion(p, op, region, args);
1660}
1661
1663 Region &region,
1664 ValueRange useDeviceAddrVars,
1665 TypeRange useDeviceAddrTypes,
1666 ValueRange useDevicePtrVars,
1667 TypeRange useDevicePtrTypes) {
1668 AllRegionPrintArgs args;
1669 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1670 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1671 printBlockArgRegion(p, op, region, args);
1672}
1673
1674template <typename ParsePrefixFn>
1675static ParseResult parseSplitIteratedList(
1676 OpAsmParser &parser,
1678 SmallVectorImpl<Type> &iteratedTypes,
1680 SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {
1681
1682 return parser.parseCommaSeparatedList([&]() -> ParseResult {
1683 if (failed(parsePrefix()))
1684 return failure();
1685
1687 Type ty;
1688 if (parser.parseOperand(v) || parser.parseColonType(ty))
1689 return failure();
1690
1691 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1692 iteratedVars.push_back(v);
1693 iteratedTypes.push_back(ty);
1694 } else {
1695 plainVars.push_back(v);
1696 plainTypes.push_back(ty);
1697 }
1698 return success();
1699 });
1700}
1701
1702template <typename PrintPrefixFn>
1704 TypeRange iteratedTypes,
1705 ValueRange plainVars, TypeRange plainTypes,
1706 PrintPrefixFn &&printPrefixForPlain,
1707 PrintPrefixFn &&printPrefixForIterated) {
1708
1709 bool first = true;
1710 auto emit = [&](Value v, Type t, auto &&printPrefix) {
1711 if (!first)
1712 p << ", ";
1713 printPrefix(v, t);
1714 p << v << " : " << t;
1715 first = false;
1716 };
1717
1718 for (unsigned i = 0; i < iteratedVars.size(); ++i)
1719 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1720 for (unsigned i = 0; i < plainVars.size(); ++i)
1721 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1722}
1723
1724/// Verifies Reduction Clause
1725static LogicalResult
1726verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1727 OperandRange reductionVars,
1728 std::optional<ArrayRef<bool>> reductionByref) {
1729 if (!reductionVars.empty()) {
1730 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1731 return op->emitOpError()
1732 << "expected as many reduction symbol references "
1733 "as reduction variables";
1734 if (reductionByref && reductionByref->size() != reductionVars.size())
1735 return op->emitError() << "expected as many reduction variable by "
1736 "reference attributes as reduction variables";
1737 } else {
1738 if (reductionSyms)
1739 return op->emitOpError() << "unexpected reduction symbol references";
1740 return success();
1741 }
1742
1743 // TODO: The followings should be done in
1744 // SymbolUserOpInterface::verifySymbolUses.
1745 DenseSet<Value> accumulators;
1746 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1747 Value accum = std::get<0>(args);
1748
1749 if (!accumulators.insert(accum).second)
1750 return op->emitOpError() << "accumulator variable used more than once";
1751
1752 Type varType = accum.getType();
1753 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1754 auto decl =
1756 if (!decl)
1757 return op->emitOpError() << "expected symbol reference " << symbolRef
1758 << " to point to a reduction declaration";
1759
1760 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1761 return op->emitOpError()
1762 << "expected accumulator (" << varType
1763 << ") to be the same type as reduction declaration ("
1764 << decl.getAccumulatorType() << ")";
1765 }
1766
1767 return success();
1768}
1769
1770//===----------------------------------------------------------------------===//
1771// Parser, printer and verifier for Copyprivate
1772//===----------------------------------------------------------------------===//
1773
1774/// copyprivate-entry-list ::= copyprivate-entry
1775/// | copyprivate-entry-list `,` copyprivate-entry
1776/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1777static ParseResult parseCopyprivate(
1778 OpAsmParser &parser,
1780 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1782 if (failed(parser.parseCommaSeparatedList([&]() {
1783 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1784 parser.parseArrow() ||
1785 parser.parseAttribute(symsVec.emplace_back()) ||
1786 parser.parseColonType(copyprivateTypes.emplace_back()))
1787 return failure();
1788 return success();
1789 })))
1790 return failure();
1791 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1792 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1793 return success();
1794}
1795
1796/// Print Copyprivate clause
1798 OperandRange copyprivateVars,
1799 TypeRange copyprivateTypes,
1800 std::optional<ArrayAttr> copyprivateSyms) {
1801 if (!copyprivateSyms.has_value())
1802 return;
1803 llvm::interleaveComma(
1804 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1805 [&](const auto &args) {
1806 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1807 << std::get<2>(args);
1808 });
1809}
1810
/// Verifies the copyprivate clause of `op`:
///  - the number of copy-function symbols equals the number of variables;
///  - each symbol resolves to either a `func.func` or an `llvm.func`;
///  - that function takes exactly two arguments of the same type; and
///  - that argument type equals the type of the corresponding variable.
/// Returns failure (after emitting an op error) on the first violation.
1811/// Verifies CopyPrivate Clause
1812static LogicalResult
1814 std::optional<ArrayAttr> copyprivateSyms) {
1815 size_t copyprivateSymsSize =
1816 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1817 if (copyprivateSymsSize != copyprivateVars.size())
1818 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1819 << copyprivateVars.size()
1820 << ") and functions (= " << copyprivateSymsSize
1821 << "), both must be equal";
1822 if (!copyprivateSyms.has_value())
1823 return success();
1824
1825 for (auto copyprivateVarAndSym :
1826 llvm::zip(copyprivateVars, *copyprivateSyms)) {
// NOTE(review): llvm::cast asserts if the entry is not a SymbolRefAttr. The
// custom parser accepts any attribute after `->`, so this relies on the
// attribute's ODS constraint (not visible here) rejecting non-symbol entries
// before this verifier runs -- confirm.
1827 auto symbolRef =
1828 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1829 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1830 funcOp;
1831 if (mlir::func::FuncOp mlirFuncOp =
1833 symbolRef))
1834 funcOp = mlirFuncOp;
1835 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1837 op, symbolRef))
1838 funcOp = llvmFuncOp;
1839
// Both helpers below dereference `funcOp`; they are only invoked after the
// `!funcOp` check further down.
1840 auto getNumArguments = [&] {
1841 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1842 };
1843
1844 auto getArgumentType = [&](unsigned i) {
1845 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1846 *funcOp);
1847 };
1848
1849 if (!funcOp)
1850 return op->emitOpError() << "expected symbol reference " << symbolRef
1851 << " to point to a copy function";
1852
1853 if (getNumArguments() != 2)
1854 return op->emitOpError()
1855 << "expected copy function " << symbolRef << " to have 2 operands";
1856
1857 Type argTy = getArgumentType(0);
1858 if (argTy != getArgumentType(1))
1859 return op->emitOpError() << "expected copy function " << symbolRef
1860 << " arguments to have the same type";
1861
1862 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1863 if (argTy != varType)
1864 return op->emitOpError()
1865 << "expected copy function arguments' type (" << argTy
1866 << ") to be the same as copyprivate variable's type (" << varType
1867 << ")";
1868 }
1869
1870 return success();
1871}
1872
1873//===----------------------------------------------------------------------===//
1874// Parser, printer and verifier for DependVarList
1875//===----------------------------------------------------------------------===//
1876
/// Parses the entries of a depend clause. Each entry is
/// `kind -> %operand : type`. Entries whose type is an `omp` IteratedType are
/// appended to the iterated lists (`iteratedVars`/`iteratedTypes`/
/// `iteratedKinds`), and all others to the plain depend lists. Source order is
/// preserved within each list. Both kind ArrayAttrs are always assigned, and
/// may be empty.
///
/// NOTE(review): an unrecognized depend-kind keyword makes the element callback
/// return failure() without emitting any diagnostic; consider reporting it via
/// parser.emitError so users see why parsing failed.
1877/// depend-entry-list ::= depend-entry
1878/// | depend-entry-list `,` depend-entry
1879/// depend-entry ::= depend-kind `->` ssa-id `:` type
1880/// | depend-kind `->` ssa-id `:` iterated-type
1881static ParseResult parseDependVarList(
1882 OpAsmParser &parser,
1884 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
1886 SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
1889 if (failed(parser.parseCommaSeparatedList([&]() {
1890 StringRef keyword;
1891 OpAsmParser::UnresolvedOperand operand;
1892 Type ty;
1893 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1894 parser.parseOperand(operand) || parser.parseColonType(ty))
1895 return failure();
1896 std::optional<ClauseTaskDepend> keywordDepend =
1897 symbolizeClauseTaskDepend(keyword);
1898 if (!keywordDepend)
1899 return failure();
1900 auto kindAttr =
1901 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
// The entry's type decides which operand/type/kind lists receive it.
1902 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1903 iteratedVars.push_back(operand);
1904 iteratedTypes.push_back(ty);
1905 iterKindsVec.push_back(kindAttr);
1906 } else {
1907 dependVars.push_back(operand);
1908 dependTypes.push_back(ty);
1909 kindsVec.push_back(kindAttr);
1910 }
1911 return success();
1912 })))
1913 return failure();
1914 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1915 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1916 SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
1917 iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
1918 return success();
1919}
1920
/// Prints the depend clause as a comma-separated list of
/// `kind -> %var : type` entries: first all plain depend entries, then all
/// iterated ones. This matches the partitioning done by parseDependVarList.
///
/// `(*kinds)[i]` assumes the kinds attribute is present and has at least as
/// many entries as `vars`; verifyDependVarList enforces this for verified ops.
1921/// Print Depend clause
1923 OperandRange dependVars, TypeRange dependTypes,
1924 std::optional<ArrayAttr> dependKinds,
1925 OperandRange iteratedVars,
1926 TypeRange iteratedTypes,
1927 std::optional<ArrayAttr> iteratedKinds) {
// Shared across both printEntries calls so the separator is emitted between
// the last plain entry and the first iterated entry too.
1928 bool first = true;
1929 auto printEntries = [&](OperandRange vars, TypeRange types,
1930 std::optional<ArrayAttr> kinds) {
1931 for (unsigned i = 0, e = vars.size(); i < e; ++i) {
1932 if (!first)
1933 p << ", ";
1934 p << stringifyClauseTaskDepend(
1935 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1936 .getValue())
1937 << " -> " << vars[i] << " : " << types[i];
1938 first = false;
1939 }
1940 };
1941 printEntries(dependVars, dependTypes, dependKinds);
1942 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1943}
1944
1945/// Verifies Depend clause
1946static LogicalResult verifyDependVarList(Operation *op,
1947 std::optional<ArrayAttr> dependKinds,
1948 OperandRange dependVars,
1949 std::optional<ArrayAttr> iteratedKinds,
1950 OperandRange iteratedVars) {
1951 if (!dependVars.empty()) {
1952 if (!dependKinds || dependKinds->size() != dependVars.size())
1953 return op->emitOpError() << "expected as many depend values"
1954 " as depend variables";
1955 } else {
1956 if (dependKinds && !dependKinds->empty())
1957 return op->emitOpError() << "unexpected depend values";
1958 }
1959
1960 if (!iteratedVars.empty()) {
1961 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1962 return op->emitOpError() << "expected as many depend iterated values"
1963 " as depend iterated variables";
1964 } else {
1965 if (iteratedKinds && !iteratedKinds->empty())
1966 return op->emitOpError() << "unexpected depend iterated values";
1967 }
1968
1969 return success();
1970}
1971
1972//===----------------------------------------------------------------------===//
1973// Parser, printer and verifier for Synchronization Hint (2.17.12)
1974//===----------------------------------------------------------------------===//
1975
1976/// Parses a Synchronization Hint clause. The value of hint is an integer
1977/// which is a combination of different hints from `omp_sync_hint_t`.
1978///
1979/// hint-clause = `hint` `(` hint-value `)`
1980static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1981 IntegerAttr &hintAttr) {
1982 StringRef hintKeyword;
1983 int64_t hint = 0;
1984 if (succeeded(parser.parseOptionalKeyword("none"))) {
1985 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1986 return success();
1987 }
1988 auto parseKeyword = [&]() -> ParseResult {
1989 if (failed(parser.parseKeyword(&hintKeyword)))
1990 return failure();
1991 if (hintKeyword == "uncontended")
1992 hint |= 1;
1993 else if (hintKeyword == "contended")
1994 hint |= 2;
1995 else if (hintKeyword == "nonspeculative")
1996 hint |= 4;
1997 else if (hintKeyword == "speculative")
1998 hint |= 8;
1999 else
2000 return parser.emitError(parser.getCurrentLocation())
2001 << hintKeyword << " is not a valid hint";
2002 return success();
2003 };
2004 if (parser.parseCommaSeparatedList(parseKeyword))
2005 return failure();
2006 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
2007 return success();
2008}
2009
/// Prints `none` for a zero hint. Otherwise prints the names of the set bits
/// 0-3 (uncontended, contended, nonspeculative, speculative), comma-separated
/// and in that fixed order.
///
/// NOTE(review): if only bits above 3 are set, nothing at all is printed. That
/// output presumably cannot be read back by parseSynchronizationHint, which
/// expects at least one keyword, and verifySynchronizationHint does not reject
/// such bits.
2010/// Prints a Synchronization Hint clause
2012 IntegerAttr hintAttr) {
2013 int64_t hint = hintAttr.getInt();
2014
2015 if (hint == 0) {
2016 p << "none";
2017 return;
2018 }
2019
// `bitn` narrows the int64_t hint to int; that is harmless here because only
// bits 0-3 are inspected.
2020 // Helper function to get n-th bit from the right end of `value`
2021 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
2022
2023 bool uncontended = bitn(hint, 0);
2024 bool contended = bitn(hint, 1);
2025 bool nonspeculative = bitn(hint, 2);
2026 bool speculative = bitn(hint, 3);
2027
2029 if (uncontended)
2030 hints.push_back("uncontended");
2031 if (contended)
2032 hints.push_back("contended");
2033 if (nonspeculative)
2034 hints.push_back("nonspeculative");
2035 if (speculative)
2036 hints.push_back("speculative");
2037
2038 llvm::interleaveComma(hints, p);
2039}
2040
2041/// Verifies a synchronization hint clause
2042static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
2043
2044 // Helper function to get n-th bit from the right end of `value`
2045 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
2046
2047 bool uncontended = bitn(hint, 0);
2048 bool contended = bitn(hint, 1);
2049 bool nonspeculative = bitn(hint, 2);
2050 bool speculative = bitn(hint, 3);
2051
2052 if (uncontended && contended)
2053 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
2054 "omp_sync_hint_contended cannot be combined";
2055 if (nonspeculative && speculative)
2056 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
2057 "omp_sync_hint_speculative cannot be combined.";
2058 return success();
2059}
2060
2061//===----------------------------------------------------------------------===//
2062// Parser, printer and verifier for Target
2063//===----------------------------------------------------------------------===//
2064
2065// Helper function to get bitwise AND of `value` and 'flag' then return it as a
2066// boolean
2067static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
2068 return (value & flag) == flag;
2069}
2070
2071/// Parses a map_entries map type from a string format back into its numeric
2072/// value.
2073///
2074/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
2075/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
2076static ParseResult parseMapClause(OpAsmParser &parser,
2077 ClauseMapFlagsAttr &mapType) {
2078 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
2079 // This simply verifies the correct keyword is read in, the
2080 // keyword itself is stored inside of the operation
2081 auto parseTypeAndMod = [&]() -> ParseResult {
2082 StringRef mapTypeMod;
2083 if (parser.parseKeyword(&mapTypeMod))
2084 return failure();
2085
2086 if (mapTypeMod == "always")
2087 mapTypeBits |= ClauseMapFlags::always;
2088
2089 if (mapTypeMod == "implicit")
2090 mapTypeBits |= ClauseMapFlags::implicit;
2091
2092 if (mapTypeMod == "ompx_hold")
2093 mapTypeBits |= ClauseMapFlags::ompx_hold;
2094
2095 if (mapTypeMod == "close")
2096 mapTypeBits |= ClauseMapFlags::close;
2097
2098 if (mapTypeMod == "present")
2099 mapTypeBits |= ClauseMapFlags::present;
2100
2101 if (mapTypeMod == "to")
2102 mapTypeBits |= ClauseMapFlags::to;
2103
2104 if (mapTypeMod == "from")
2105 mapTypeBits |= ClauseMapFlags::from;
2106
2107 if (mapTypeMod == "tofrom")
2108 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
2109
2110 if (mapTypeMod == "delete")
2111 mapTypeBits |= ClauseMapFlags::del;
2112
2113 if (mapTypeMod == "storage")
2114 mapTypeBits |= ClauseMapFlags::storage;
2115
2116 if (mapTypeMod == "return_param")
2117 mapTypeBits |= ClauseMapFlags::return_param;
2118
2119 if (mapTypeMod == "private")
2120 mapTypeBits |= ClauseMapFlags::priv;
2121
2122 if (mapTypeMod == "literal")
2123 mapTypeBits |= ClauseMapFlags::literal;
2124
2125 if (mapTypeMod == "attach")
2126 mapTypeBits |= ClauseMapFlags::attach;
2127
2128 if (mapTypeMod == "attach_always")
2129 mapTypeBits |= ClauseMapFlags::attach_always;
2130
2131 if (mapTypeMod == "attach_never")
2132 mapTypeBits |= ClauseMapFlags::attach_never;
2133
2134 if (mapTypeMod == "attach_auto")
2135 mapTypeBits |= ClauseMapFlags::attach_auto;
2136
2137 if (mapTypeMod == "ref_ptr")
2138 mapTypeBits |= ClauseMapFlags::ref_ptr;
2139
2140 if (mapTypeMod == "ref_ptee")
2141 mapTypeBits |= ClauseMapFlags::ref_ptee;
2142
2143 if (mapTypeMod == "is_device_ptr")
2144 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2145
2146 return success();
2147 };
2148
2149 if (parser.parseCommaSeparatedList(parseTypeAndMod))
2150 return failure();
2151
2152 mapType =
2153 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
2154
2155 return success();
2156}
2157
/// Keywords are emitted in a fixed canonical order, independent of the order
/// in which they were written. An empty flag set prints as `none`. The output
/// round-trips through parseMapClause.
2158/// Prints a map_entries map type from its numeric value out into its string
2159/// format.
2160static void printMapClause(OpAsmPrinter &p, Operation *op,
2161 ClauseMapFlagsAttr mapType) {
2163 ClauseMapFlags mapFlags = mapType.getValue();
2164
2165 // Modifiers (always, implicit, ompx_hold, close, present) are printed first
2166 // to aid readability.
2167 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
2168 mapTypeStrs.push_back("always");
2169 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
2170 mapTypeStrs.push_back("implicit");
2171 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
2172 mapTypeStrs.push_back("ompx_hold");
2173 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
2174 mapTypeStrs.push_back("close");
2175 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
2176 mapTypeStrs.push_back("present");
2177
2178 // `to` and `from` are folded into a single keyword: `tofrom` when both bits
2179 // are set, otherwise whichever of the two is set. When neither is set nothing
2180 // is printed here; a completely empty flag set is printed as `none` below.
2181 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
2182 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
2183
2184 if (to && from)
2185 mapTypeStrs.push_back("tofrom");
2186 else if (from)
2187 mapTypeStrs.push_back("from");
2188 else if (to)
2189 mapTypeStrs.push_back("to");
2190
2191 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
2192 mapTypeStrs.push_back("delete");
2193 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
2194 mapTypeStrs.push_back("return_param");
2195 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
2196 mapTypeStrs.push_back("storage");
2197 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
2198 mapTypeStrs.push_back("private");
2199 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
2200 mapTypeStrs.push_back("literal");
2201 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
2202 mapTypeStrs.push_back("attach");
2203 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
2204 mapTypeStrs.push_back("attach_always");
2205 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
2206 mapTypeStrs.push_back("attach_never");
2207 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
2208 mapTypeStrs.push_back("attach_auto");
2209 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
2210 mapTypeStrs.push_back("ref_ptr");
2211 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
2212 mapTypeStrs.push_back("ref_ptee");
2213 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
2214 mapTypeStrs.push_back("is_device_ptr");
2215 if (mapFlags == ClauseMapFlags::none)
2216 mapTypeStrs.push_back("none");
2217
// Emit the collected keywords separated by ", ".
2218 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2219 p << mapTypeStrs[i];
2220 if (i + 1 < mapTypeStrs.size()) {
2221 p << ", ";
2222 }
2223 }
2224}
2225
2226static ParseResult parseMembersIndex(OpAsmParser &parser,
2227 ArrayAttr &membersIdx) {
2228 SmallVector<Attribute> values, memberIdxs;
2229
2230 auto parseIndices = [&]() -> ParseResult {
2231 int64_t value;
2232 if (parser.parseInteger(value))
2233 return failure();
2234 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
2235 APInt(64, value, /*isSigned=*/false)));
2236 return success();
2237 };
2238
2239 do {
2240 if (failed(parser.parseLSquare()))
2241 return failure();
2242
2243 if (parser.parseCommaSeparatedList(parseIndices))
2244 return failure();
2245
2246 if (failed(parser.parseRSquare()))
2247 return failure();
2248
2249 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
2250 values.clear();
2251 } while (succeeded(parser.parseOptionalComma()));
2252
2253 if (!memberIdxs.empty())
2254 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
2255
2256 return success();
2257}
2258
2259static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
2260 ArrayAttr membersIdx) {
2261 if (!membersIdx)
2262 return;
2263
2264 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
2265 p << "[";
2266 auto memberIdx = cast<ArrayAttr>(v);
2267 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2268 p << cast<IntegerAttr>(v2).getInt();
2269 });
2270 p << "]";
2271 });
2272}
2273
// Prints the capture kind of a map entry as one of the keywords `ByRef`,
// `ByCopy`, `VLAType` or `This`; these are the spellings parseCaptureType
// recognizes. Any other enum value prints an empty string.
2275 VariableCaptureKindAttr mapCaptureType) {
2276 std::string typeCapStr;
2277 llvm::raw_string_ostream typeCap(typeCapStr);
2278 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2279 typeCap << "ByRef";
2280 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2281 typeCap << "ByCopy";
2282 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2283 typeCap << "VLAType";
2284 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2285 typeCap << "This";
2286 p << typeCapStr;
2287}
2288
2289static ParseResult parseCaptureType(OpAsmParser &parser,
2290 VariableCaptureKindAttr &mapCaptureType) {
2291 StringRef mapCaptureKey;
2292 if (parser.parseKeyword(&mapCaptureKey))
2293 return failure();
2294
2295 if (mapCaptureKey == "This")
2296 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2297 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2298 if (mapCaptureKey == "ByRef")
2299 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2300 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2301 if (mapCaptureKey == "ByCopy")
2302 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2303 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2304 if (mapCaptureKey == "VLAType")
2305 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2306 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2307
2308 return success();
2309}
2310
/// Verifies the map operands of `op`. Every map operand must have a defining
/// op. If that op is an `omp.map.info`, its map-type flags must be legal for
/// the kind of `op` (target/target_data, target_enter_data, target_exit_data
/// or target_update), and `varPtrPtr`/`varPtrPtrType` must be given together.
/// A defining op that is not an `omp.map.info` is only accepted when `op` is a
/// DeclareMapperInfoOp. Diagnostics are emitted at `op`'s location.
2311static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
2314
2315 for (auto mapOp : mapVars) {
2316 if (!mapOp.getDefiningOp())
2317 return emitError(op->getLoc(), "missing map operation");
2318
2319 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2320 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2321
2322 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2323 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2324 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2325
2326 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2327 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2328 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2329 bool attach = mapTypeToBool(mapTypeBits, ClauseMapFlags::attach);
2330
2331 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2332 return emitError(op->getLoc(),
2333 "to, from, tofrom and alloc map types are permitted");
2334
2335 if (isa<TargetEnterDataOp>(op) && (from || del))
2336 return emitError(op->getLoc(), "to and alloc map types are permitted");
2337
2338 if (isa<TargetExitDataOp>(op) && to)
2339 return emitError(op->getLoc(),
2340 "from, release and delete map types are permitted");
2341
// target_update: each entry must be exactly one of to/from (or an attach map),
// and a given var_ptr may not be mapped `to` in one entry and `from` in
// another. updateToVars/updateFromVars record the var_ptrs already seen.
2342 if (isa<TargetUpdateOp>(op)) {
2343 if (del) {
2344 return emitError(op->getLoc(),
2345 "at least one of to or from map types must be "
2346 "specified, other map types are not permitted");
2347 }
2348
2349 if (!to && !from && !attach) {
2350 return emitError(
2351 op->getLoc(),
2352 "at least one of to or from or attach map types must be "
2353 "specified, other map types are not permitted");
2354 }
2355
2356 auto updateVar = mapInfoOp.getVarPtr();
2357
2358 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2359 (from && updateToVars.contains(updateVar))) {
2360 return emitError(
2361 op->getLoc(),
2362 "either to or from map types can be specified, not both");
2363 }
2364
2365 if (always || close || implicit) {
2366 return emitError(
2367 op->getLoc(),
2368 "present, mapper and iterator map type modifiers are permitted");
2369 }
2370
2371 // It's possible we have an attach map, in which case if there is no to
2372 // or from tied to it, we skip insertion.
2373 if (to || from) {
2374 to ? updateToVars.insert(updateVar)
2375 : updateFromVars.insert(updateVar);
2376 }
2377 }
2378
2379 if ((mapInfoOp.getVarPtrPtr() && !mapInfoOp.getVarPtrPtrType()) ||
2380 (!mapInfoOp.getVarPtrPtr() && mapInfoOp.getVarPtrPtrType())) {
2381 return emitError(
2382 op->getLoc(),
2383 "if varPtrPtr or varPtrPtrType is specified, then both "
2384 "must be present");
2385 }
2386 } else if (!isa<DeclareMapperInfoOp>(op)) {
2387 return emitError(op->getLoc(),
2388 "map argument is not a map entry operation");
2389 }
2390 }
2391
2392 return success();
2393}
2394
// Forward declaration so that TargetOp::verify can call the private-variable
// verifier before its definition. Being a static function template, it is
// defined elsewhere in this translation unit.
2395template <typename OpType>
2396static LogicalResult verifyPrivateVarList(OpType &op);
2397
2398static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2399 std::optional<DenseI64ArrayAttr> privateMapIndices =
2400 targetOp.getPrivateMapsAttr();
2401
2402 // None of the private operands are mapped.
2403 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2404 return success();
2405
2406 OperandRange privateVars = targetOp.getPrivateVars();
2407
2408 if (privateMapIndices.value().size() !=
2409 static_cast<int64_t>(privateVars.size()))
2410 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2411 "`private_maps` attribute mismatch");
2412
2413 return success();
2414}
2415
2416//===----------------------------------------------------------------------===//
2417// MapInfoOp
2418//===----------------------------------------------------------------------===//
2419
2420static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2421 StringRef clauseName,
2422 OperandRange vars) {
2423 for (Value var : vars)
2424 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2425 return op->emitOpError()
2426 << "'" << clauseName
2427 << "' arguments must be defined by 'omp.map.info' ops";
2428 return success();
2429}
2430
/// Verifies an `omp.map.info` op. When a mapper id is present, the symbol
/// lookup performed with `*this` and the mapper id attribute must succeed;
/// presumably this is a nearest-symbol lookup of the mapper declaration, TODO
/// confirm. Otherwise "invalid mapper id" is reported. Every `members` operand
/// must itself be defined by an `omp.map.info` op.
2431LogicalResult MapInfoOp::verify() {
2432 if (getMapperId() &&
2434 *this, getMapperIdAttr())) {
2435 return emitError("invalid mapper id");
2436 }
2437
2438 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2439 return failure();
2440
2441 return success();
2442}
2443
2444//===----------------------------------------------------------------------===//
2445// TargetDataOp
2446//===----------------------------------------------------------------------===//
2447
2448void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2449 const TargetDataOperands &clauses) {
2450 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2451 clauses.mapVars, clauses.useDeviceAddrVars,
2452 clauses.useDevicePtrVars);
2453}
2454
2455LogicalResult TargetDataOp::verify() {
2456 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2457 getUseDeviceAddrVars().empty()) {
2458 return ::emitError(this->getLoc(),
2459 "At least one of map, use_device_ptr_vars, or "
2460 "use_device_addr_vars operand must be present");
2461 }
2462
2463 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2464 getUseDevicePtrVars())))
2465 return failure();
2466
2467 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2468 getUseDeviceAddrVars())))
2469 return failure();
2470
2471 return verifyMapClause(*this, getMapVars());
2472}
2473
2474//===----------------------------------------------------------------------===//
2475// TargetEnterDataOp
2476//===----------------------------------------------------------------------===//
2477
2478void TargetEnterDataOp::build(
2479 OpBuilder &builder, OperationState &state,
2480 const TargetEnterExitUpdateDataOperands &clauses) {
2481 MLIRContext *ctx = builder.getContext();
2482 TargetEnterDataOp::build(
2483 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2484 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2485 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2486 clauses.nowait);
2487}
2488
2489LogicalResult TargetEnterDataOp::verify() {
2490 LogicalResult verifyDependVars =
2491 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2492 getDependIteratedKinds(), getDependIterated());
2493 return failed(verifyDependVars) ? verifyDependVars
2494 : verifyMapClause(*this, getMapVars());
2495}
2496
2497//===----------------------------------------------------------------------===//
2498// TargetExitDataOp
2499//===----------------------------------------------------------------------===//
2500
2501void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2502 const TargetEnterExitUpdateDataOperands &clauses) {
2503 MLIRContext *ctx = builder.getContext();
2504 TargetExitDataOp::build(
2505 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2506 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2507 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2508 clauses.nowait);
2509}
2510
2511LogicalResult TargetExitDataOp::verify() {
2512 LogicalResult verifyDependVars =
2513 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2514 getDependIteratedKinds(), getDependIterated());
2515 return failed(verifyDependVars) ? verifyDependVars
2516 : verifyMapClause(*this, getMapVars());
2517}
2518
2519//===----------------------------------------------------------------------===//
2520// TargetUpdateOp
2521//===----------------------------------------------------------------------===//
2522
2523void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2524 const TargetEnterExitUpdateDataOperands &clauses) {
2525 MLIRContext *ctx = builder.getContext();
2526 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2527 clauses.dependVars,
2528 makeArrayAttr(ctx, clauses.dependIteratedKinds),
2529 clauses.dependIterated, clauses.device, clauses.ifExpr,
2530 clauses.mapVars, clauses.nowait);
2531}
2532
2533LogicalResult TargetUpdateOp::verify() {
2534 LogicalResult verifyDependVars =
2535 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2536 getDependIteratedKinds(), getDependIterated());
2537 return failed(verifyDependVars) ? verifyDependVars
2538 : verifyMapClause(*this, getMapVars());
2539}
2540
2541//===----------------------------------------------------------------------===//
2542// TargetOp
2543//===----------------------------------------------------------------------===//
2544
2545void TargetOp::build(OpBuilder &builder, OperationState &state,
2546 const TargetOperands &clauses) {
2547 MLIRContext *ctx = builder.getContext();
2548 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2549 // inReductionByref, inReductionSyms.
2550 TargetOp::build(
2551 builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
2552 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2553 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2554 clauses.device, clauses.dynGroupprivateAccessGroup,
2555 clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
2556 clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
2557 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2558 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2559 clauses.nowait, clauses.privateVars,
2560 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2561 clauses.threadLimitVars,
2562 /*private_maps=*/nullptr);
2563}
2564
/// Verifies an `omp.target` operation. Checks run in this order and stop at
/// the first failure:
///  1. the depend clause;
///  2. that `has_device_addr` operands are defined by `omp.map.info` ops;
///  3. the map clause entries;
///  4. the dyn_groupprivate clause (access group, fallback and size);
///  5. the private variable list; and
///  6. that `private_maps`, if present, has one entry per private variable.
2565LogicalResult TargetOp::verify() {
2566 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
2567 getDependIteratedKinds(),
2568 getDependIterated())))
2569 return failure();
2570
2571 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2572 getHasDeviceAddrVars())))
2573 return failure();
2574
2575 if (failed(verifyMapClause(*this, getMapVars())))
2576 return failure();
2577
2579 *this, getDynGroupprivateAccessGroupAttr(),
2580 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
2581 return failure();
2582
2583 if (failed(verifyPrivateVarList(*this)))
2584 return failure();
2585
2586 return verifyPrivateVarsMapping(*this);
2587}
2588
/// Region verification for `omp.target`:
///  - at most one `omp.teams` op may be directly nested in the target region;
///  - every use of a host_eval block argument must be one of:
///      * a num_teams (lower or upper) or thread_limit operand of `omp.teams`;
///      * a num_threads operand of an `omp.parallel` that is an ancestor of
///        the captured op, when the kernel is SPMD;
///      * a lower bound, upper bound or step of the captured `omp.loop_nest`,
///        when its trip count is evaluated on the host.
///    Any other use is diagnosed.
2589LogicalResult TargetOp::verifyRegions() {
2590 auto teamsOps = getOps<TeamsOp>();
2591 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2592 return emitError("target containing multiple 'omp.teams' nested ops");
2593
2594 // Check that host_eval values are only used in legal ways.
// hostEvalTripCount is initialized by getKernelExecFlags.
2595 bool hostEvalTripCount;
2596 Operation *capturedOp = getInnermostCapturedOmpOp();
2597 TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
2598 for (Value hostEvalArg :
2599 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2600 for (Operation *user : hostEvalArg.getUsers()) {
2601 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2602 // Legal as num_teams_lower, any num_teams_upper_vars or any thread_limit_vars.
2603 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2604 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2605 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2606 continue;
2607
2608 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2609 "and 'thread_limit' in 'omp.teams'";
2610 }
2611 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
// capturedOp may be null, but getKernelExecFlags returns `generic` unless the
// captured op is an omp.loop_nest. The spmd check therefore short-circuits
// before isAncestor could be called with null.
2612 if (execMode == TargetExecMode::spmd &&
2613 parallelOp->isAncestor(capturedOp) &&
2614 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2615 continue;
2616
2617 return emitOpError()
2618 << "host_eval argument only legal as 'num_threads' in "
2619 "'omp.parallel' when representing target SPMD";
2620 }
2621 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2622 if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
2623 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2624 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2625 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2626 continue;
2627
2628 return emitOpError() << "host_eval argument only legal as loop bounds "
2629 "and steps in 'omp.loop_nest' when trip count "
2630 "must be evaluated in the host";
2631 }
2632
2633 return emitOpError() << "host_eval argument illegal use in '"
2634 << user->getName() << "' operation";
2635 }
2636 }
2637 return success();
2638}
2639
/// Walks `rootOp` in pre-order and returns the innermost nested operation that
/// can be "captured", or nullptr if there is none.
///
/// Candidates are ops that belong to the same dialect as `rootOp` and have at
/// least one region; all other ops are skipped, and their regions are not
/// visited. A candidate is captured only if:
///  - every other op in its parent region satisfies `siblingAllowedFn`; and
///  - when `checkSingleMandatoryExec` is set, it cannot run more than once
///    and is guaranteed to run before every exit of its parent region.
/// The walk is interrupted at the first rejected candidate, or right after an
/// `omp.loop_nest` is captured. The result is the last candidate accepted on
/// the way down.
2640static Operation *
2641findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2642 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2643 assert(rootOp && "expected valid operation");
2644
// Note: this is the dialect of `rootOp` itself; callers pass an OpenMP op.
2645 Dialect *ompDialect = rootOp->getDialect();
2646 Operation *capturedOp = nullptr;
2647 DominanceInfo domInfo;
2648
2649 // Process in pre-order to check operations from outermost to innermost,
2650 // ensuring we only enter the region of an operation if it meets the criteria
2651 // for being captured. We stop the exploration of nested operations as soon as
2652 // we process a region holding no operations to be captured.
2653 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2654 if (op == rootOp)
2655 return WalkResult::advance();
2656
2657 // Ignore operations of other dialects or omp operations with no regions,
2658 // because these will only be checked if they are siblings of an omp
2659 // operation that can potentially be captured.
2660 bool isOmpDialect = op->getDialect() == ompDialect;
2661 bool hasRegions = op->getNumRegions() > 0;
2662 if (!isOmpDialect || !hasRegions)
2663 return WalkResult::skip();
2664
2665 // This operation cannot be captured if it can be executed more than once
2666 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2667 // be executed before all exits of the region (i.e. it doesn't dominate all
2668 // blocks with no successors reachable from the entry block).
2669 if (checkSingleMandatoryExec) {
2670 Region *parentRegion = op->getParentRegion();
2671 Block *parentBlock = op->getBlock();
2672
2673 for (Block *successor : parentBlock->getSuccessors())
2674 if (successor->isReachable(parentBlock))
2675 return WalkResult::interrupt();
2676
2677 for (Block &block : *parentRegion)
2678 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2679 !domInfo.dominates(parentBlock, &block))
2680 return WalkResult::interrupt();
2681 }
2682
2683 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2684 // into nested operations.
2685 for (Operation &sibling : op->getParentRegion()->getOps())
2686 if (&sibling != op && !siblingAllowedFn(&sibling))
2687 return WalkResult::interrupt();
2688
2689 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2690 // Otherwise, process the contents of this operation.
2691 capturedOp = op;
2692 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2694 });
2695
2696 return capturedOp;
2697}
2698
/// Returns the innermost OpenMP op captured by this target region, as
/// determined by findCapturedOmpOp with the single-mandatory-execution check
/// enabled. Returns nullptr if nothing is captured.
2699Operation *TargetOp::getInnermostCapturedOmpOp() {
2700 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2701
2702 // Allowed siblings: OpenMP terminators, and non-OpenMP ops that either do not
2703 // implement MemoryEffectOpInterface or report no write to an
// AutomaticAllocationScopeResource.
// NOTE(review): the previous wording ("non-OpenMP ops that have known memory
// effects, but don't include a memory write effect") does not match the
// predicate below. Ops without MemoryEffectOpInterface are accepted (the final
// `return true`), and only writes to AutomaticAllocationScopeResource are
// rejected. Confirm which behavior is intended.
2704 return findCapturedOmpOp(
2705 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2706 if (!sibling)
2707 return false;
2708
2709 if (ompDialect == sibling->getDialect())
2710 return sibling->hasTrait<OpTrait::IsTerminator>();
2711
2712 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2714 effects;
2715 memOp.getEffects(effects);
2716 return !llvm::any_of(
2717 effects, [&](MemoryEffects::EffectInstance &effect) {
2718 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2719 isa<SideEffects::AutomaticAllocationScopeResource>(
2720 effect.getResource());
2721 });
2722 }
2723 return true;
2724 });
2725}
2726
2727/// Check if we can promote SPMD kernel to No-Loop kernel.
2728static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2729 WsloopOp *wsLoopOp) {
2730 // num_teams clause can break no-loop teams/threads assumption.
2731 if (!teamsOp.getNumTeamsUpperVars().empty())
2732 return false;
2733
2734 // Reduction kernels are slower in no-loop mode.
2735 if (teamsOp.getNumReductionVars())
2736 return false;
2737 if (wsLoopOp->getNumReductionVars())
2738 return false;
2739
2740 // Check if the user allows the promotion of kernels to no-loop mode.
2741 OffloadModuleInterface offloadMod =
2742 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2743 if (!offloadMod)
2744 return false;
2745 auto ompFlags = offloadMod.getFlags();
2746 if (!ompFlags)
2747 return false;
2748 return ompFlags.getAssumeTeamsOversubscription() &&
2749 ompFlags.getAssumeThreadsOversubscription();
2750}
2751
/// Classify the kernel execution mode for the op captured by
/// getInnermostCapturedOmpOp().
///
/// Recognized shapes, each required to be directly nested in the TargetOp
/// that encloses `capturedOp`:
///  - teams { parallel { distribute { wsloop [{ simd }] } } }
///      -> spmd, or no_loop when canPromoteToNoLoop() allows it;
///  - teams { distribute [{ simd }] } -> generic;
///  - teams { loop }                  -> spmd;
///  - parallel { wsloop [{ simd }] }  -> spmd.
/// Anything else, including a null or non-loop captured op, is generic.
///
/// \param capturedOp innermost captured op of its parent TargetOp, or null.
///        A non-null value must be exactly what getInnermostCapturedOmpOp()
///        returns for that TargetOp (checked by an assertion).
/// \param hostEvalTripCount optional out-parameter. It is reset to false on
///        entry and set to true only for the teams-based shapes above (the
///        name suggests the loop trip count is then evaluated on the host).
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
                                            bool *hostEvalTripCount) {
  // TODO: Support detection of bare kernel mode.
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  if (hostEvalTripCount)
    *hostEvalTripCount = false;

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetExecMode::generic;

  // Get the innermost non-simd loop wrapper.
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  // The code below treats loopWrappers as ordered innermost first: step over
  // a leading omp.simd so that only the non-simd wrappers are counted.
  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetExecMode::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
    if (!wsloopOp)
      return TargetExecMode::generic;

    // The next wrapper outwards must be omp.distribute.
    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetExecMode::generic;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
    if (!teamsOp)
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() == targetOp.getOperation()) {
      TargetExecMode result = TargetExecMode::spmd;
      if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
        result = TargetExecMode::no_loop;
      if (hostEvalTripCount)
        *hostEvalTripCount = true;
      return result;
    }
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetExecMode::generic;

    if (hostEvalTripCount)
      *hostEvalTripCount = true;

    // Only the omp.loop form runs in SPMD mode; a bare distribute stays
    // generic.
    if (isa<LoopOp>(innermostWrapper))
      return TargetExecMode::spmd;

    return TargetExecMode::generic;
  }
  // Detect target-parallel-wsloop[-simd].
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetExecMode::spmd;
  }

  return TargetExecMode::generic;
}
2839
2840//===----------------------------------------------------------------------===//
2841// ParallelOp
2842//===----------------------------------------------------------------------===//
2843
2844void ParallelOp::build(OpBuilder &builder, OperationState &state,
2845 ArrayRef<NamedAttribute> attributes) {
2846 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2847 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2848 /*num_threads_vars=*/ValueRange(),
2849 /*private_vars=*/ValueRange(),
2850 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2851 /*proc_bind_kind=*/nullptr,
2852 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2853 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2854 state.addAttributes(attributes);
2855}
2856
2857void ParallelOp::build(OpBuilder &builder, OperationState &state,
2858 const ParallelOperands &clauses) {
2859 MLIRContext *ctx = builder.getContext();
2860 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2861 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2862 makeArrayAttr(ctx, clauses.privateSyms),
2863 clauses.privateNeedsBarrier, clauses.procBindKind,
2864 clauses.reductionMod, clauses.reductionVars,
2865 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2866 makeArrayAttr(ctx, clauses.reductionSyms));
2867}
2868
/// Verify the private clause of `op`.
///
/// Checks that the number of private variables equals the number of
/// privatizer symbols, that every symbol resolves to a PrivateClauseOp, and,
/// when that privatizer reports an argument type, that it equals the type of
/// the corresponding private variable.
///
/// \tparam OpType an op exposing getPrivateVars() and getPrivateSymsAttr().
/// \returns failure (after emitting an error on `op`) on the first violation.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privateSyms = op.getPrivateSymsAttr();

  // No variables and no (or an empty) symbol list: nothing to verify.
  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
    return success();

  // A missing symbol attribute counts as zero symbols.
  auto numPrivateVars = privateVars.size();
  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

  if (numPrivateVars != numPrivateSyms)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivateSyms;

  // Sizes are equal here, so each variable pairs with exactly one symbol.
  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
    Type varType = std::get<0>(privateVarInfo).getType();
    SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
    PrivateClauseOp privatizerOp =

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privateSym << "'";

    // A null argument type on the privatizer skips the type comparison.
    Type privatizerType = privatizerOp.getArgType();

    if (privatizerType && (varType != privatizerType))
      return op.emitError()
             << "type mismatch between a "
             << (privatizerOp.getDataSharingType() ==
                         DataSharingClauseType::Private
                     ? "private"
                     : "firstprivate")
             << " variable and its privatizer op, var type: " << varType
             << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}
2911
2912LogicalResult ParallelOp::verify() {
2913 if (getAllocateVars().size() != getAllocatorVars().size())
2914 return emitError(
2915 "expected equal sizes for allocate and allocator variables");
2916
2917 if (failed(verifyPrivateVarList(*this)))
2918 return failure();
2919
2920 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2921 getReductionByref());
2922}
2923
2924LogicalResult ParallelOp::verifyRegions() {
2925 auto distChildOps = getOps<DistributeOp>();
2926 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2927 if (numDistChildOps > 1)
2928 return emitError()
2929 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2930
2931 if (numDistChildOps == 1) {
2932 if (!isComposite())
2933 return emitError()
2934 << "'omp.composite' attribute missing from composite operation";
2935
2936 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2937 Operation &distributeOp = **distChildOps.begin();
2938 for (Operation &childOp : getOps()) {
2939 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2940 continue;
2941
2942 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2943 return emitError() << "unexpected OpenMP operation inside of composite "
2944 "'omp.parallel': "
2945 << childOp.getName();
2946 }
2947 } else if (isComposite()) {
2948 return emitError()
2949 << "'omp.composite' attribute present in non-composite operation";
2950 }
2951 return success();
2952}
2953
2954//===----------------------------------------------------------------------===//
2955// TeamsOp
2956//===----------------------------------------------------------------------===//
2957
2959 while ((op = op->getParentOp()))
2960 if (isa<OpenMPDialect>(op->getDialect()))
2961 return false;
2962 return true;
2963}
2964
2965void TeamsOp::build(OpBuilder &builder, OperationState &state,
2966 const TeamsOperands &clauses) {
2967 MLIRContext *ctx = builder.getContext();
2968 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2969 TeamsOp::build(
2970 builder, state, clauses.allocateVars, clauses.allocatorVars,
2971 clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
2972 clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
2973 clauses.numTeamsUpperVars, /*private_vars=*/{}, /*private_syms=*/nullptr,
2974 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2975 clauses.reductionVars,
2976 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2977 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2978}
2979
2980// Verify num_teams clause
2981static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2982 OperandRange numTeamsUpperVars) {
2983 // If lower is specified, upper must have exactly one value
2984 if (numTeamsLower) {
2985 if (numTeamsUpperVars.size() != 1)
2986 return op->emitError(
2987 "expected exactly one num_teams upper bound when lower bound is "
2988 "specified");
2989 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2990 return op->emitError(
2991 "expected num_teams upper bound and lower bound to be "
2992 "the same type");
2993 }
2994
2995 return success();
2996}
2997
/// Verify an omp.teams op.
///
/// Checks run in this order:
///  - placement: per the diagnostic, the op must be directly inside
///    omp.target or not nested in any OpenMP dialect op;
///  - the num_teams clause;
///  - the allocate clause (equal numbers of allocate and allocator vars);
///  - the dyn_groupprivate clause attributes and size;
///  - the private clause and the reduction clause.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  if (!isa<TargetOp>(op->getParentOp()) &&
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
                                  this->getNumTeamsUpperVars())))
    return failure();

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  // Check the dyn_groupprivate clause (access group, fallback and size).
          op, getDynGroupprivateAccessGroupAttr(),
          getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
    return failure();

  // Check the private and reduction clauses.
  if (failed(verifyPrivateVarList(*this)))
    return failure();

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
3031
3032//===----------------------------------------------------------------------===//
3033// SectionOp
3034//===----------------------------------------------------------------------===//
3035
3036OperandRange SectionOp::getPrivateVars() {
3037 return getParentOp().getPrivateVars();
3038}
3039
3040OperandRange SectionOp::getReductionVars() {
3041 return getParentOp().getReductionVars();
3042}
3043
3044//===----------------------------------------------------------------------===//
3045// SectionsOp
3046//===----------------------------------------------------------------------===//
3047
3048void SectionsOp::build(OpBuilder &builder, OperationState &state,
3049 const SectionsOperands &clauses) {
3050 MLIRContext *ctx = builder.getContext();
3051 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3052 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3053 clauses.nowait, /*private_vars=*/{},
3054 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
3055 clauses.reductionMod, clauses.reductionVars,
3056 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3057 makeArrayAttr(ctx, clauses.reductionSyms));
3058}
3059
3060LogicalResult SectionsOp::verify() {
3061 if (getAllocateVars().size() != getAllocatorVars().size())
3062 return emitError(
3063 "expected equal sizes for allocate and allocator variables");
3064
3065 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3066 getReductionByref());
3067}
3068
3069LogicalResult SectionsOp::verifyRegions() {
3070 for (auto &inst : *getRegion().begin()) {
3071 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
3072 return emitOpError()
3073 << "expected omp.section op or terminator op inside region";
3074 }
3075 }
3076
3077 return success();
3078}
3079
3080//===----------------------------------------------------------------------===//
3081// ScopeOp
3082//===----------------------------------------------------------------------===//
3083
3084void ScopeOp::build(OpBuilder &builder, OperationState &state,
3085 const ScopeOperands &clauses) {
3086 MLIRContext *ctx = builder.getContext();
3087 ScopeOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3088 clauses.nowait, clauses.privateVars,
3089 makeArrayAttr(ctx, clauses.privateSyms),
3090 clauses.privateNeedsBarrier, clauses.reductionMod,
3091 clauses.reductionVars,
3092 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3093 makeArrayAttr(ctx, clauses.reductionSyms));
3094}
3095
3096LogicalResult ScopeOp::verify() {
3097 if (getAllocateVars().size() != getAllocatorVars().size())
3098 return emitError(
3099 "expected equal sizes for allocate and allocator variables");
3100
3101 if (failed(verifyPrivateVarList(*this)))
3102 return failure();
3103
3104 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3105 getReductionByref());
3106}
3107
3108//===----------------------------------------------------------------------===//
3109// SingleOp
3110//===----------------------------------------------------------------------===//
3111
3112void SingleOp::build(OpBuilder &builder, OperationState &state,
3113 const SingleOperands &clauses) {
3114 MLIRContext *ctx = builder.getContext();
3115 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3116 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3117 clauses.copyprivateVars,
3118 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
3119 /*private_vars=*/{}, /*private_syms=*/nullptr,
3120 /*private_needs_barrier=*/nullptr);
3121}
3122
3123LogicalResult SingleOp::verify() {
3124 // Check for allocate clause restrictions
3125 if (getAllocateVars().size() != getAllocatorVars().size())
3126 return emitError(
3127 "expected equal sizes for allocate and allocator variables");
3128
3129 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
3130 getCopyprivateSyms());
3131}
3132
3133//===----------------------------------------------------------------------===//
3134// WorkshareOp
3135//===----------------------------------------------------------------------===//
3136
3137void WorkshareOp::build(OpBuilder &builder, OperationState &state,
3138 const WorkshareOperands &clauses) {
3139 WorkshareOp::build(builder, state, clauses.nowait);
3140}
3141
3142//===----------------------------------------------------------------------===//
3143// WorkshareLoopWrapperOp
3144//===----------------------------------------------------------------------===//
3145
3146LogicalResult WorkshareLoopWrapperOp::verify() {
3147 if (!(*this)->getParentOfType<WorkshareOp>())
3148 return emitOpError() << "must be nested in an omp.workshare";
3149 return success();
3150}
3151
3152LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
3153 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3154 getNestedWrapper())
3155 return emitOpError() << "expected to be a standalone loop wrapper";
3156
3157 return success();
3158}
3159
3160//===----------------------------------------------------------------------===//
3161// LoopWrapperInterface
3162//===----------------------------------------------------------------------===//
3163
/// Verify the structure shared by every loop wrapper op.
///
/// Requires the NoTerminator and SingleBlock traits (as the diagnostic
/// states), exactly one region holding exactly one op, and that op being
/// either another loop wrapper or an omp.loop_nest.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  // The single nested op either continues the wrapper chain or is the loop
  // nest itself.
  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}
3186
3187//===----------------------------------------------------------------------===//
3188// LoopOp
3189//===----------------------------------------------------------------------===//
3190
3191void LoopOp::build(OpBuilder &builder, OperationState &state,
3192 const LoopOperands &clauses) {
3193 MLIRContext *ctx = builder.getContext();
3194
3195 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
3196 makeArrayAttr(ctx, clauses.privateSyms),
3197 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
3198 clauses.reductionMod, clauses.reductionVars,
3199 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3200 makeArrayAttr(ctx, clauses.reductionSyms));
3201}
3202
3203LogicalResult LoopOp::verify() {
3204 if (failed(verifyPrivateVarList(*this)))
3205 return failure();
3206
3207 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3208 getReductionByref());
3209}
3210
3211LogicalResult LoopOp::verifyRegions() {
3212 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3213 getNestedWrapper())
3214 return emitOpError() << "expected to be a standalone loop wrapper";
3215
3216 return success();
3217}
3218
3219//===----------------------------------------------------------------------===//
3220// WsloopOp
3221//===----------------------------------------------------------------------===//
3222
3223void WsloopOp::build(OpBuilder &builder, OperationState &state,
3224 ArrayRef<NamedAttribute> attributes) {
3225 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
3226 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
3227 /*linear_var_types*/ nullptr, /*linear_modifiers=*/nullptr,
3228 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
3229 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
3230 /*private_needs_barrier=*/false,
3231 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
3232 /*reduction_byref=*/nullptr,
3233 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
3234 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
3235 /*schedule_simd=*/false);
3236 state.addAttributes(attributes);
3237}
3238
3239void WsloopOp::build(OpBuilder &builder, OperationState &state,
3240 const WsloopOperands &clauses) {
3241 MLIRContext *ctx = builder.getContext();
3242 // TODO: Store clauses in op: allocateVars, allocatorVars
3243 WsloopOp::build(
3244 builder, state,
3245 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
3246 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3247 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3248 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3249 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3250 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3251 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3252 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3253}
3254
3255LogicalResult WsloopOp::verify() {
3256 if (failed(
3257 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3258 return failure();
3259 if (getLinearVars().size() &&
3260 getLinearVarTypes().value().size() != getLinearVars().size())
3261 return emitError() << "Ill-formed type attributes for linear variables";
3262
3263 if (failed(verifyPrivateVarList(*this)))
3264 return failure();
3265
3266 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3267 getReductionByref());
3268}
3269
3270LogicalResult WsloopOp::verifyRegions() {
3271 bool isCompositeChildLeaf =
3272 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3273
3274 if (LoopWrapperInterface nested = getNestedWrapper()) {
3275 if (!isComposite())
3276 return emitError()
3277 << "'omp.composite' attribute missing from composite wrapper";
3278
3279 // Check for the allowed leaf constructs that may appear in a composite
3280 // construct directly after DO/FOR.
3281 if (!isa<SimdOp>(nested))
3282 return emitError() << "only supported nested wrapper is 'omp.simd'";
3283
3284 } else if (isComposite() && !isCompositeChildLeaf) {
3285 return emitError()
3286 << "'omp.composite' attribute present in non-composite wrapper";
3287 } else if (!isComposite() && isCompositeChildLeaf) {
3288 return emitError()
3289 << "'omp.composite' attribute missing from composite wrapper";
3290 }
3291
3292 return success();
3293}
3294
3295//===----------------------------------------------------------------------===//
3296// Simd construct [2.9.3.1]
3297//===----------------------------------------------------------------------===//
3298
3299void SimdOp::build(OpBuilder &builder, OperationState &state,
3300 const SimdOperands &clauses) {
3301 MLIRContext *ctx = builder.getContext();
3302 SimdOp::build(builder, state, clauses.alignedVars,
3303 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
3304 clauses.linearVars, clauses.linearStepVars,
3305 clauses.linearVarTypes, clauses.linearModifiers,
3306 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3307 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3308 clauses.privateNeedsBarrier, clauses.reductionMod,
3309 clauses.reductionVars,
3310 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3311 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
3312 clauses.simdlen);
3313}
3314
3315LogicalResult SimdOp::verify() {
3316 if (getSimdlen().has_value() && getSafelen().has_value() &&
3317 getSimdlen().value() > getSafelen().value())
3318 return emitOpError()
3319 << "simdlen clause and safelen clause are both present, but the "
3320 "simdlen value is not less than or equal to safelen value";
3321
3322 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
3323 return failure();
3324
3325 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
3326 return failure();
3327
3328 if (failed(
3329 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3330 return failure();
3331
3332 bool isCompositeChildLeaf =
3333 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3334
3335 if (!isComposite() && isCompositeChildLeaf)
3336 return emitError()
3337 << "'omp.composite' attribute missing from composite wrapper";
3338
3339 if (isComposite() && !isCompositeChildLeaf)
3340 return emitError()
3341 << "'omp.composite' attribute present in non-composite wrapper";
3342
3343 // Firstprivate is not allowed for SIMD in the standard. Check that none of
3344 // the private decls are for firstprivate.
3345 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3346 if (privateSyms) {
3347 for (const Attribute &sym : *privateSyms) {
3348 auto symRef = cast<SymbolRefAttr>(sym);
3349 omp::PrivateClauseOp privatizer =
3351 getOperation(), symRef);
3352 if (!privatizer)
3353 return emitError() << "Cannot find privatizer '" << symRef << "'";
3354 if (privatizer.getDataSharingType() ==
3355 DataSharingClauseType::FirstPrivate)
3356 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
3357 }
3358 }
3359
3360 if (failed(verifyPrivateVarList(*this)))
3361 return failure();
3362
3363 if (getLinearVars().size() &&
3364 getLinearVarTypes().value().size() != getLinearVars().size())
3365 return emitError() << "Ill-formed type attributes for linear variables";
3366 return success();
3367}
3368
3369LogicalResult SimdOp::verifyRegions() {
3370 if (getNestedWrapper())
3371 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3372
3373 return success();
3374}
3375
3376//===----------------------------------------------------------------------===//
3377// Distribute construct [2.9.4.1]
3378//===----------------------------------------------------------------------===//
3379
3380void DistributeOp::build(OpBuilder &builder, OperationState &state,
3381 const DistributeOperands &clauses) {
3382 DistributeOp::build(builder, state, clauses.allocateVars,
3383 clauses.allocatorVars, clauses.distScheduleStatic,
3384 clauses.distScheduleChunkSize, clauses.order,
3385 clauses.orderMod, clauses.privateVars,
3386 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3387 clauses.privateNeedsBarrier);
3388}
3389
3390LogicalResult DistributeOp::verify() {
3391 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3392 return emitOpError() << "chunk size set without "
3393 "dist_schedule_static being present";
3394
3395 if (getAllocateVars().size() != getAllocatorVars().size())
3396 return emitError(
3397 "expected equal sizes for allocate and allocator variables");
3398
3399 if (failed(verifyPrivateVarList(*this)))
3400 return failure();
3401
3402 return success();
3403}
3404
3405LogicalResult DistributeOp::verifyRegions() {
3406 if (LoopWrapperInterface nested = getNestedWrapper()) {
3407 if (!isComposite())
3408 return emitError()
3409 << "'omp.composite' attribute missing from composite wrapper";
3410 // Check for the allowed leaf constructs that may appear in a composite
3411 // construct directly after DISTRIBUTE.
3412 if (isa<WsloopOp>(nested)) {
3413 Operation *parentOp = (*this)->getParentOp();
3414 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3415 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3416 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3417 "when a composite 'omp.parallel' is the direct "
3418 "parent";
3419 }
3420 } else if (!isa<SimdOp>(nested))
3421 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3422 "'omp.wsloop'";
3423 } else if (isComposite()) {
3424 return emitError()
3425 << "'omp.composite' attribute present in non-composite wrapper";
3426 }
3427
3428 return success();
3429}
3430
3431//===----------------------------------------------------------------------===//
3432// DeclareMapperOp / DeclareMapperInfoOp
3433//===----------------------------------------------------------------------===//
3434
3435LogicalResult DeclareMapperInfoOp::verify() {
3436 return verifyMapClause(*this, getMapVars());
3437}
3438
3439LogicalResult DeclareMapperOp::verifyRegions() {
3440 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3441 getRegion().getBlocks().front().getTerminator()))
3442 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3443
3444 return success();
3445}
3446
3447//===----------------------------------------------------------------------===//
3448// DeclareReductionOp
3449//===----------------------------------------------------------------------===//
3450
3451LogicalResult DeclareReductionOp::verifyRegions() {
3452 if (!getAllocRegion().empty()) {
3453 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3454 if (yieldOp.getResults().size() != 1 ||
3455 yieldOp.getResults().getTypes()[0] != getType())
3456 return emitOpError() << "expects alloc region to yield a value "
3457 "of the reduction type";
3458 }
3459 }
3460
3461 if (getInitializerRegion().empty())
3462 return emitOpError() << "expects non-empty initializer region";
3463 Block &initializerEntryBlock = getInitializerRegion().front();
3464
3465 if (initializerEntryBlock.getNumArguments() == 1) {
3466 if (!getAllocRegion().empty())
3467 return emitOpError() << "expects two arguments to the initializer region "
3468 "when an allocation region is used";
3469 } else if (initializerEntryBlock.getNumArguments() == 2) {
3470 if (getAllocRegion().empty())
3471 return emitOpError() << "expects one argument to the initializer region "
3472 "when no allocation region is used";
3473 } else {
3474 return emitOpError()
3475 << "expects one or two arguments to the initializer region";
3476 }
3477
3478 for (mlir::Value arg : initializerEntryBlock.getArguments())
3479 if (arg.getType() != getType())
3480 return emitOpError() << "expects initializer region argument to match "
3481 "the reduction type";
3482
3483 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3484 if (yieldOp.getResults().size() != 1 ||
3485 yieldOp.getResults().getTypes()[0] != getType())
3486 return emitOpError() << "expects initializer region to yield a value "
3487 "of the reduction type";
3488 }
3489
3490 if (getReductionRegion().empty())
3491 return emitOpError() << "expects non-empty reduction region";
3492 Block &reductionEntryBlock = getReductionRegion().front();
3493 if (reductionEntryBlock.getNumArguments() != 2 ||
3494 reductionEntryBlock.getArgumentTypes()[0] !=
3495 reductionEntryBlock.getArgumentTypes()[1] ||
3496 reductionEntryBlock.getArgumentTypes()[0] != getType())
3497 return emitOpError() << "expects reduction region with two arguments of "
3498 "the reduction type";
3499 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3500 if (yieldOp.getResults().size() != 1 ||
3501 yieldOp.getResults().getTypes()[0] != getType())
3502 return emitOpError() << "expects reduction region to yield a value "
3503 "of the reduction type";
3504 }
3505
3506 if (!getAtomicReductionRegion().empty()) {
3507 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3508 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3509 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3510 atomicReductionEntryBlock.getArgumentTypes()[1])
3511 return emitOpError() << "expects atomic reduction region with two "
3512 "arguments of the same type";
3513 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3514 atomicReductionEntryBlock.getArgumentTypes()[0]);
3515 if (!ptrType ||
3516 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3517 return emitOpError() << "expects atomic reduction region arguments to "
3518 "be accumulators containing the reduction type";
3519 }
3520
3521 if (getCleanupRegion().empty())
3522 return success();
3523 Block &cleanupEntryBlock = getCleanupRegion().front();
3524 if (cleanupEntryBlock.getNumArguments() != 1 ||
3525 cleanupEntryBlock.getArgument(0).getType() != getType())
3526 return emitOpError() << "expects cleanup region with one argument "
3527 "of the reduction type";
3528
3529 return success();
3530}
3531
3532//===----------------------------------------------------------------------===//
3533// TaskOp
3534//===----------------------------------------------------------------------===//
3535
3536void TaskOp::build(OpBuilder &builder, OperationState &state,
3537 const TaskOperands &clauses) {
3538 MLIRContext *ctx = builder.getContext();
3539 TaskOp::build(
3540 builder, state, clauses.iterated, clauses.affinityVars,
3541 clauses.allocateVars, clauses.allocatorVars,
3542 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3543 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3544 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3545 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3546 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3547 clauses.priority, /*private_vars=*/clauses.privateVars,
3548 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3549 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3550}
3551
3552LogicalResult TaskOp::verify() {
3553 LogicalResult verifyDependVars =
3554 verifyDependVarList(*this, getDependKinds(), getDependVars(),
3555 getDependIteratedKinds(), getDependIterated());
3556 if (failed(verifyDependVars))
3557 return verifyDependVars;
3558
3559 if (failed(verifyPrivateVarList(*this)))
3560 return failure();
3561
3562 return verifyReductionVarList(*this, getInReductionSyms(),
3563 getInReductionVars(), getInReductionByref());
3564}
3565
3566//===----------------------------------------------------------------------===//
3567// TaskgroupOp
3568//===----------------------------------------------------------------------===//
3569
3570void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3571 const TaskgroupOperands &clauses) {
3572 MLIRContext *ctx = builder.getContext();
3573 TaskgroupOp::build(builder, state, clauses.allocateVars,
3574 clauses.allocatorVars, clauses.taskReductionVars,
3575 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3576 makeArrayAttr(ctx, clauses.taskReductionSyms));
3577}
3578
3579LogicalResult TaskgroupOp::verify() {
3580 return verifyReductionVarList(*this, getTaskReductionSyms(),
3581 getTaskReductionVars(),
3582 getTaskReductionByref());
3583}
3584
3585//===----------------------------------------------------------------------===//
3586// TaskloopContextOp
3587//===----------------------------------------------------------------------===//
3588
3589void TaskloopContextOp::build(OpBuilder &builder, OperationState &state,
3590 const TaskloopContextOperands &clauses) {
3591 MLIRContext *ctx = builder.getContext();
3592 TaskloopContextOp::build(
3593 builder, state, clauses.allocateVars, clauses.allocatorVars,
3594 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3595 clauses.inReductionVars,
3596 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3597 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3598 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3599 /*private_vars=*/clauses.privateVars,
3600 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3601 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3602 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3603 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3604}
3605
3606TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3607 return cast<TaskloopWrapperOp>(
3608 *llvm::find_if(getRegion().front(), [](mlir::Operation &op) {
3609 return isa<TaskloopWrapperOp>(op);
3610 }));
3611}
3612
3613LogicalResult TaskloopContextOp::verify() {
3614 if (getAllocateVars().size() != getAllocatorVars().size())
3615 return emitError(
3616 "expected equal sizes for allocate and allocator variables");
3617
3618 if (failed(verifyPrivateVarList(*this)))
3619 return failure();
3620
3621 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3622 getReductionVars(), getReductionByref())) ||
3623 failed(verifyReductionVarList(*this, getInReductionSyms(),
3624 getInReductionVars(),
3625 getInReductionByref())))
3626 return failure();
3627
3628 if (!getReductionVars().empty() && getNogroup())
3629 return emitError("if a reduction clause is present on the taskloop "
3630 "directive, the nogroup clause must not be specified");
3631 for (auto var : getReductionVars()) {
3632 if (llvm::is_contained(getInReductionVars(), var))
3633 return emitError("the same list item cannot appear in both a reduction "
3634 "and an in_reduction clause");
3635 }
3636
3637 if (getGrainsize() && getNumTasks()) {
3638 return emitError(
3639 "the grainsize clause and num_tasks clause are mutually exclusive and "
3640 "may not appear on the same taskloop directive");
3641 }
3642
3643 return success();
3644}
3645
3646LogicalResult TaskloopContextOp::verifyRegions() {
3647 Region &region = getRegion();
3648 if (region.empty())
3649 return emitOpError() << "expected non-empty region";
3650
3651 auto count = llvm::count_if(region.front(), [](mlir::Operation &op) {
3652 return isa<TaskloopWrapperOp>(op);
3653 });
3654 if (count != 1)
3655 return emitOpError()
3656 << "expected exactly 1 TaskloopWrapperOp directly nested in "
3657 "the region, but "
3658 << count << " were found";
3659 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3660
3661 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3662 // This will fail the verifier for TaskloopWrapperOp and print an error
3663 // message there.
3664 if (!loopNestOp)
3665 return failure();
3666
3667 std::function<bool(Value)> isValidBoundValue = [&](Value value) -> bool {
3668 Region *valueRegion = value.getParentRegion();
3669 // A loop bound value defined outside of the taskloop context region is
3670 // valid. A region is considered an ancestor of itself.
3671 if (!region.isAncestor(valueRegion))
3672 return true;
3673
3674 Operation *defOp = value.getDefiningOp();
3675 if (!defOp || defOp->getNumRegions() != 0 || !isPure(defOp))
3676 return false;
3677
3678 return llvm::all_of(defOp->getOperands(), isValidBoundValue);
3679 };
3680 auto hasUnsupportedTaskloopLocalBound = [&](OperandRange range) -> bool {
3681 return llvm::any_of(range,
3682 [&](Value value) { return !isValidBoundValue(value); });
3683 };
3684
3685 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3686 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3687 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3688 return emitOpError()
3689 << "expects loop bounds and steps to be defined outside of the "
3690 "taskloop.context region or by pure, regionless operations "
3691 "that do not depend on block arguments";
3692 }
3693
3694 return success();
3695}
3696
3697//===----------------------------------------------------------------------===//
3698// TaskloopWrapperOp
3699//===----------------------------------------------------------------------===//
3700
3701void TaskloopWrapperOp::build(OpBuilder &builder, OperationState &state,
3702 const TaskloopWrapperOperands &clauses) {
3703 TaskloopWrapperOp::build(builder, state);
3704}
3705
3706TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3707 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3708}
3709
3710LogicalResult TaskloopWrapperOp::verify() {
3711 TaskloopContextOp context = getTaskloopContext();
3712 if (!context)
3713 return emitOpError() << "expected to be nested in a taskloop context op";
3714 return success();
3715}
3716
3717LogicalResult TaskloopWrapperOp::verifyRegions() {
3718 if (LoopWrapperInterface nested = getNestedWrapper()) {
3719 if (!isComposite())
3720 return emitError()
3721 << "'omp.composite' attribute missing from composite wrapper";
3722
3723 // Check for the allowed leaf constructs that may appear in a composite
3724 // construct directly after TASKLOOP.
3725 if (!isa<SimdOp>(nested))
3726 return emitError() << "only supported nested wrapper is 'omp.simd'";
3727 } else if (isComposite()) {
3728 return emitError()
3729 << "'omp.composite' attribute present in non-composite wrapper";
3730 }
3731
3732 return success();
3733}
3734
3735//===----------------------------------------------------------------------===//
3736// LoopNestOp
3737//===----------------------------------------------------------------------===//
3738
/// Parse the custom form of a loop nest, in this order:
///   `(` ivs `)` `:` type `=` `(` lbs `)` `to` `(` ubs `)` [`inclusive`]
///   `step` `(` steps `)` [`collapse` `(` int `)`] [`tiles` `(` int, ... `)`]
///   region [attr-dict]
/// Every induction variable and every bound/step operand gets the single
/// parsed `loopVarType`.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse an opening `(` followed by induction variables followed by `)`
  Type loopVarType;
      parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // All induction variables share the single type written after `:`.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  auto *ctx = parser.getBuilder().getContext();
  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive", UnitAttr::get(ctx));

  // Parse step values.
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse collapse. parseOptionalKeyword returns success (i.e. converts to
  // false) when the keyword is present. Only values greater than 1 are
  // recorded as an attribute.
  int64_t value = 0;
  if (!parser.parseOptionalKeyword("collapse") &&
      (parser.parseLParen() || parser.parseInteger(value) ||
       parser.parseRParen()))
    return failure();
  if (value > 1)
    result.addAttribute(
        "collapse_num_loops",
        IntegerAttr::get(parser.getBuilder().getI64Type(), value));

  // Parse tiles: a parenthesized, comma-separated list of integers.
  auto parseTiles = [&]() -> ParseResult {
    int64_t tile;
    if (parser.parseInteger(tile))
      return failure();
    tiles.push_back(tile);
    return success();
  };

  if (!parser.parseOptionalKeyword("tiles") &&
      (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
       parser.parseRParen()))
    return failure();

  if (tiles.size() > 0)
    result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));

  // Parse the body; the induction variables become its entry block arguments.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands. The operand order is: lower bounds, upper bounds, steps.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list (it follows the region).
  return parser.parseOptionalAttrDict(result.attributes);
}
3810
3811void LoopNestOp::print(OpAsmPrinter &p) {
3812 Region &region = getRegion();
3813 auto args = region.getArguments();
3814 p << " (" << args << ") : " << args[0].getType() << " = ("
3815 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3816 if (getLoopInclusive())
3817 p << "inclusive ";
3818 p << "step (" << getLoopSteps() << ") ";
3819 if (int64_t numCollapse = getCollapseNumLoops())
3820 if (numCollapse > 1)
3821 p << "collapse(" << numCollapse << ") ";
3822
3823 if (const auto tiles = getTileSizes())
3824 p << "tiles(" << tiles.value() << ") ";
3825
3826 p.printRegion(region, /*printEntryBlockArgs=*/false);
3827}
3828
3829void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3830 const LoopNestOperands &clauses) {
3831 MLIRContext *ctx = builder.getContext();
3832 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3833 clauses.loopLowerBounds, clauses.loopUpperBounds,
3834 clauses.loopSteps, clauses.loopInclusive,
3835 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3836}
3837
3838LogicalResult LoopNestOp::verify() {
3839 if (getLoopLowerBounds().empty())
3840 return emitOpError() << "must represent at least one loop";
3841
3842 if (getLoopLowerBounds().size() != getIVs().size())
3843 return emitOpError() << "number of range arguments and IVs do not match";
3844
3845 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3846 if (lb.getType() != iv.getType())
3847 return emitOpError()
3848 << "range argument type does not match corresponding IV type";
3849 }
3850
3851 uint64_t numIVs = getIVs().size();
3852
3853 if (const auto &numCollapse = getCollapseNumLoops())
3854 if (numCollapse > numIVs)
3855 return emitOpError()
3856 << "collapse value is larger than the number of loops";
3857
3858 if (const auto &tiles = getTileSizes())
3859 if (tiles.value().size() > numIVs)
3860 return emitOpError() << "too few canonical loops for tile dimensions";
3861
3862 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3863 return emitOpError() << "expects parent op to be a loop wrapper";
3864
3865 return success();
3866}
3867
/// Append to `wrappers` every loop wrapper enclosing this loop nest, walking
/// outward from the immediate parent (innermost first). The walk stops at the
/// first ancestor that is not a LoopWrapperInterface, or when there is no
/// parent.
void LoopNestOp::gatherWrappers(
  Operation *parent = (*this)->getParentOp();
  while (auto wrapper =
             llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
    wrappers.push_back(wrapper);
    parent = parent->getParentOp();
  }
}
3877
3878//===----------------------------------------------------------------------===//
3879// OpenMP canonical loop handling
3880//===----------------------------------------------------------------------===//
3881
3882std::tuple<NewCliOp, OpOperand *, OpOperand *>
3883mlir::omp ::decodeCli(Value cli) {
3884
3885 // Defining a CLI for a generated loop is optional; if there is none then
3886 // there is no followup-tranformation
3887 if (!cli)
3888 return {{}, nullptr, nullptr};
3889
3890 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3891 "Unexpected type of cli");
3892
3893 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3894 OpOperand *gen = nullptr;
3895 OpOperand *cons = nullptr;
3896 for (OpOperand &use : cli.getUses()) {
3897 auto op = cast<LoopTransformationInterface>(use.getOwner());
3898
3899 unsigned opnum = use.getOperandNumber();
3900 if (op.isGeneratee(opnum)) {
3901 assert(!gen && "Each CLI may have at most one def");
3902 gen = &use;
3903 } else if (op.isApplyee(opnum)) {
3904 assert(!cons && "Each CLI may have at most one consumer");
3905 cons = &use;
3906 } else {
3907 llvm_unreachable("Unexpected operand for a CLI");
3908 }
3909 }
3910
3911 return {create, gen, cons};
3912}
3913
3914void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3915 ::mlir::OperationState &odsState) {
3916 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3917}
3918
/// Suggest a readable SSA name for the CLI result, derived from the operation
/// that generates the loop it refers to (see the naming scheme below). If no
/// operation generates the loop, the name stays "cli".
void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  Value result = getResult();
  auto [newCli, gen, cons] = decodeCli(result);

  // Structured binding `gen` cannot be captured in lambdas before C++20
  OpOperand *generator = gen;

  // Derive the CLI variable name from its generator:
  // * "canonloop" for omp.canonical_loop
  // * custom name for loop transformation generatees
  // * "cli" as fallback if no generator
  // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
  // at that level
  // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
  // the index of that region
  std::string cliName{"cli"};
  if (gen) {
    cliName =
            .Case([&](CanonicalLoopOp op) {
              return generateLoopNestingName("canonloop", op);
            })
            .Case([&](UnrollHeuristicOp op) -> std::string {
              llvm_unreachable("heuristic unrolling does not generate a loop");
            })
            .Case([&](FuseOp op) -> std::string {
              unsigned opnum = generator->getOperandNumber();
              // The position of the first loop to be fused is the same position
              // as the resulting fused loop
              // NOTE(review): `opnum` is an absolute, 0-based operand number,
              // while `first` appears to be 1-based (FuseOp::verify bounds
              // first + count - 1). Unless the generatees start at operand 1,
              // the "fused" name lands one position off — confirm.
              if (op.getFirst().has_value() && opnum != op.getFirst().value())
                return "canonloop_fuse";
              else
                return "fused";
            })
            .Case([&](TileOp op) -> std::string {
              // The first half of the generatees are the grid loops, the
              // second half the intratile loops.
              auto [generateesFirst, generateesCount] =
                  op.getGenerateesODSOperandIndexAndLength();
              unsigned firstGrid = generateesFirst;
              unsigned firstIntratile = generateesFirst + generateesCount / 2;
              unsigned end = generateesFirst + generateesCount;
              unsigned opnum = generator->getOperandNumber();
              // In the OpenMP apply and looprange clauses, indices are 1-based
              if (firstGrid <= opnum && opnum < firstIntratile) {
                unsigned gridnum = opnum - firstGrid + 1;
                return ("grid" + Twine(gridnum)).str();
              }
              if (firstIntratile <= opnum && opnum < end) {
                unsigned intratilenum = opnum - firstIntratile + 1;
                return ("intratile" + Twine(intratilenum)).str();
              }
              llvm_unreachable("Unexpected generatee argument");
            })
            .DefaultUnreachable("TODO: Custom name for this operation");
  }

  setNameFn(result, cliName);
}
3976
3977LogicalResult NewCliOp::verify() {
3978 Value cli = getResult();
3979
3980 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3981 "Unexpected type of cli");
3982
3983 // Check that the CLI is used in at most generator and one consumer
3984 OpOperand *gen = nullptr;
3985 OpOperand *cons = nullptr;
3986 for (mlir::OpOperand &use : cli.getUses()) {
3987 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3988
3989 unsigned opnum = use.getOperandNumber();
3990 if (op.isGeneratee(opnum)) {
3991 if (gen) {
3992 InFlightDiagnostic error =
3993 emitOpError("CLI must have at most one generator");
3994 error.attachNote(gen->getOwner()->getLoc())
3995 .append("first generator here:");
3996 error.attachNote(use.getOwner()->getLoc())
3997 .append("second generator here:");
3998 return error;
3999 }
4000
4001 gen = &use;
4002 } else if (op.isApplyee(opnum)) {
4003 if (cons) {
4004 InFlightDiagnostic error =
4005 emitOpError("CLI must have at most one consumer");
4006 error.attachNote(cons->getOwner()->getLoc())
4007 .append("first consumer here:")
4008 .appendOp(*cons->getOwner(),
4009 OpPrintingFlags().printGenericOpForm());
4010 error.attachNote(use.getOwner()->getLoc())
4011 .append("second consumer here:")
4012 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
4013 return error;
4014 }
4015
4016 cons = &use;
4017 } else {
4018 llvm_unreachable("Unexpected operand for a CLI");
4019 }
4020 }
4021
4022 // If the CLI is source of a transformation, it must have a generator
4023 if (cons && !gen) {
4024 InFlightDiagnostic error = emitOpError("CLI has no generator");
4025 error.attachNote(cons->getOwner()->getLoc())
4026 .append("see consumer here: ")
4027 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
4028 return error;
4029 }
4030
4031 return success();
4032}
4033
4034void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4035 Value tripCount) {
4036 odsState.addOperands(tripCount);
4037 odsState.addOperands(Value());
4038 (void)odsState.addRegion();
4039}
4040
4041void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4042 Value tripCount, ::mlir::Value cli) {
4043 odsState.addOperands(tripCount);
4044 odsState.addOperands(cli);
4045 (void)odsState.addRegion();
4046}
4047
4048void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
4049 setNameFn(&getRegion().front(), "body_entry");
4050}
4051
4052void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
4053 OpAsmSetValueNameFn setNameFn) {
4054 std::string ivName = generateLoopNestingName("iv", *this);
4055 setNameFn(region.getArgument(0), ivName);
4056}
4057
4058void CanonicalLoopOp::print(OpAsmPrinter &p) {
4059 if (getCli())
4060 p << '(' << getCli() << ')';
4061 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
4062 << " in range(" << getTripCount() << ") ";
4063
4064 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4065 /*printBlockTerminators=*/true);
4066
4067 p.printOptionalAttrDict((*this)->getAttrs());
4068}
4069
/// Parse the custom form produced by CanonicalLoopOp::print:
///   [`(` cli `)`] iv `:` type `in` `range` `(` tripcount `)` region attr-dict
/// The trip count takes the induction variable's type. The optional CLI
/// operand is appended after the trip count.
mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
  CanonicalLoopInfoType cliType =
      CanonicalLoopInfoType::get(parser.getContext());

  // Parse (optional) omp.cli identifier
  SmallVector<mlir::Value, 1> cliOperand;
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperand(cli) ||
        parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
      return failure();
  }

  // We derive the type of tripCount from inductionVariable. MLIR requires the
  // type of tripCount to be known when calling resolveOperand, so we have to
  // parse the type before processing the inductionVariable.
  OpAsmParser::Argument inductionVariable;
  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
      parser.parseKeyword("in") || parser.parseKeyword("range") ||
      parser.parseLParen() || parser.parseOperand(tripcount) ||
      parser.parseRParen() ||
      parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
    return failure();

  // Parse the loop body; the induction variable is its entry block argument.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, {inductionVariable}))
    return failure();

  // We parsed the cli operand first, but because it is optional, it must be
  // last in the operand list.
  result.operands.append(cliOperand);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
4111
4112LogicalResult CanonicalLoopOp::verify() {
4113 // The region's entry must accept the induction variable
4114 // It can also be empty if just created
4115 if (!getRegion().empty()) {
4116 Region &region = getRegion();
4117 if (region.getNumArguments() != 1)
4118 return emitOpError(
4119 "Canonical loop region must have exactly one argument");
4120
4121 if (getInductionVar().getType() != getTripCount().getType())
4122 return emitOpError(
4123 "Region argument must be the same type as the trip count");
4124 }
4125
4126 return success();
4127}
4128
4129Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
4130
4131std::pair<unsigned, unsigned>
4132CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
4133 // No applyees
4134 return {0, 0};
4135}
4136
4137std::pair<unsigned, unsigned>
4138CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
4139 return getODSOperandIndexAndLength(odsIndex_cli);
4140}
4141
4142//===----------------------------------------------------------------------===//
4143// UnrollHeuristicOp
4144//===----------------------------------------------------------------------===//
4145
4146void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
4147 ::mlir::OperationState &odsState,
4148 ::mlir::Value cli) {
4149 odsState.addOperands(cli);
4150}
4151
4152void UnrollHeuristicOp::print(OpAsmPrinter &p) {
4153 p << '(' << getApplyee() << ')';
4154
4155 p.printOptionalAttrDict((*this)->getAttrs());
4156}
4157
/// Parse the custom form:
///   `(` applyee `)` [`->` `(` `)`] attr-dict
/// The applyee is resolved as a CLI value. The optional `-> ()` is accepted
/// but contributes nothing to the operation state.
mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
  auto cliType = CanonicalLoopInfoType::get(parser.getContext());

  if (parser.parseLParen())
    return failure();

  if (parser.parseOperand(applyee) ||
      parser.resolveOperand(applyee, cliType, result.operands))
    return failure();

  if (parser.parseRParen())
    return failure();

  // Optional output loop (full unrolling has none)
  if (!parser.parseOptionalArrow()) {
    if (parser.parseLParen() || parser.parseRParen())
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
4185
4186std::pair<unsigned, unsigned>
4187UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
4188 return getODSOperandIndexAndLength(odsIndex_applyee);
4189}
4190
4191std::pair<unsigned, unsigned>
4192UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
4193 return {0, 0};
4194}
4195
4196//===----------------------------------------------------------------------===//
4197// TileOp
4198//===----------------------------------------------------------------------===//
4199
4200static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
4201 OperandRange generatees,
4202 OperandRange applyees) {
4203 if (!generatees.empty())
4204 p << '(' << llvm::interleaved(generatees) << ')';
4205
4206 if (!applyees.empty())
4207 p << " <- (" << llvm::interleaved(applyees) << ')';
4208}
4209
/// Parse the CLI operands of a loop transformation in one of two forms:
///   `(` generatees `)` `<-` `(` applyees `)`   -- generatees present
///   `<-` `(` applyees `)`                      -- generatees omitted
/// `<-` is consumed as the separate tokens `<` and `-`.
static ParseResult parseLoopTransformClis(
    OpAsmParser &parser,
  // parseOptionalLess() returns failure (converts to true) when the next
  // token is not `<`, i.e. when a generatee list comes first.
  if (parser.parseOptionalLess()) {
    // Syntax 1: generatees present

    if (parser.parseOperandList(generateesOperands,
      return failure();

    if (parser.parseLess())
      return failure();
  } else {
    // Syntax 2: generatees omitted
  }

  // Parse `<-` (`<` has already been parsed)
  if (parser.parseMinus())
    return failure();

  if (parser.parseOperandList(applyeesOperands,
    return failure();

  return success();
}
4237
/// Check properties of the loop nest consisting of the transformation's
/// applyees:
/// 1. They are nested inside each other
/// 2. They are perfectly nested
///    (no code with side-effects in-between the loops)
/// 3. They are rectangular
///    (loop bounds are invariant with respect to the outer loops)
///
/// Properties 1-3 are only checked when every applyee is generated by an
/// omp.canonical_loop; otherwise only the presence of a generator is
/// checked. Perfect nesting is checked structurally: the parent body must be
/// a single block holding exactly the nested loop followed by an
/// omp.terminator. Rectangularity is checked by rejecting a trip count that
/// is directly one of the enclosing loops' induction variables.
///
/// TODO: Generalize for LoopTransformationInterface.
static LogicalResult checkApplyeesNesting(TileOp op) {
  // Collect the loops from the nest
  bool isOnlyCanonLoops = true;
  for (Value applyee : op.getApplyees()) {
    auto [create, gen, cons] = decodeCli(applyee);

    if (!gen)
      return op.emitOpError() << "applyee CLI has no generator";

    // A null entry is recorded for applyees not produced by a canonical loop.
    auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
    canonLoops.push_back(loop);
    if (!loop)
      isOnlyCanonLoops = false;
  }

  // FIXME: We currently can only verify non-rectangularity and perfect nest of
  // omp.canonical_loop.
  if (!isOnlyCanonLoops)
    return success();

  // Induction variables of all loops enclosing the one being checked.
  DenseSet<Value> parentIVs;
  for (auto i : llvm::seq<int>(1, canonLoops.size())) {
    auto parentLoop = canonLoops[i - 1];
    auto loop = canonLoops[i];

    if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
      return op.emitOpError()
             << "tiled loop nest must be nested within each other";

    parentIVs.insert(parentLoop.getInductionVar());

    // Canonical loop must be perfectly nested, i.e. the body of the parent must
    // only contain the omp.canonical_loop of the nested loops, and
    // omp.terminator
    bool isPerfectlyNested = [&]() {
      auto &parentBody = parentLoop.getRegion();
      if (!parentBody.hasOneBlock())
        return false;
      auto &parentBlock = parentBody.getBlocks().front();

      auto nestedLoopIt = parentBlock.begin();
      if (nestedLoopIt == parentBlock.end() ||
          (&*nestedLoopIt != loop.getOperation()))
        return false;

      auto termIt = std::next(nestedLoopIt);
      if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
        return false;

      if (std::next(termIt) != parentBlock.end())
        return false;

      return true;
    }();
    if (!isPerfectlyNested)
      return op.emitOpError() << "tiled loop nest must be perfectly nested";

    if (parentIVs.contains(loop.getTripCount()))
      return op.emitOpError() << "tiled loop nest must be rectangular";
  }

  // TODO: The tile sizes must be computed before the loop, but checking this
  // requires dominance analysis. For instance:
  //
  //      %canonloop = omp.new_cli
  //      omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
  //        // write to %x
  //        omp.terminator
  //      }
  //      %ts = llvm.load %x
  //      omp.tile <- (%canonloop) sizes(%ts : i32)

  return success();
}
4322
4323LogicalResult TileOp::verify() {
4324 if (getApplyees().empty())
4325 return emitOpError() << "must apply to at least one loop";
4326
4327 if (getSizes().size() != getApplyees().size())
4328 return emitOpError() << "there must be one tile size for each applyee";
4329
4330 if (!getGeneratees().empty() &&
4331 2 * getSizes().size() != getGeneratees().size())
4332 return emitOpError()
4333 << "expecting two times the number of generatees than applyees";
4334
4335 return checkApplyeesNesting(*this);
4336}
4337
4338std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4339 return getODSOperandIndexAndLength(odsIndex_applyees);
4340}
4341
4342std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4343 return getODSOperandIndexAndLength(odsIndex_generatees);
4344}
4345
4346//===----------------------------------------------------------------------===//
4347// FuseOp
4348//===----------------------------------------------------------------------===//
4349
4350static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
4351 OperandRange generatees,
4352 OperandRange applyees) {
4353 if (!generatees.empty())
4354 p << '(' << llvm::interleaved(generatees) << ')';
4355
4356 if (!applyees.empty())
4357 p << " <- (" << llvm::interleaved(applyees) << ')';
4358}
4359
4360LogicalResult FuseOp::verify() {
4361 if (getApplyees().size() < 2)
4362 return emitOpError() << "must apply to at least two loops";
4363
4364 if (getFirst().has_value() && getCount().has_value()) {
4365 int64_t first = getFirst().value();
4366 int64_t count = getCount().value();
4367 if ((unsigned)(first + count - 1) > getApplyees().size())
4368 return emitOpError() << "the numbers of applyees must be at least first "
4369 "minus one plus count attributes";
4370 if (!getGeneratees().empty() &&
4371 getGeneratees().size() != getApplyees().size() + 1 - count)
4372 return emitOpError() << "the number of generatees must be the number of "
4373 "aplyees plus one minus count";
4374
4375 } else {
4376 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4377 return emitOpError()
4378 << "in a complete fuse the number of generatees must be exactly 1";
4379 }
4380 for (auto &&applyee : getApplyees()) {
4381 auto [create, gen, cons] = decodeCli(applyee);
4382
4383 if (!gen)
4384 return emitOpError() << "applyee CLI has no generator";
4385 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4386 if (!loop)
4387 return emitOpError()
4388 << "currently only supports omp.canonical_loop as applyee";
4389 }
4390 return success();
4391}
4392std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4393 return getODSOperandIndexAndLength(odsIndex_applyees);
4394}
4395
4396std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4397 return getODSOperandIndexAndLength(odsIndex_generatees);
4398}
4399
4400//===----------------------------------------------------------------------===//
4401// Critical construct (2.17.1)
4402//===----------------------------------------------------------------------===//
4403
4404void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
4405 const CriticalDeclareOperands &clauses) {
4406 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4407}
4408
4409LogicalResult CriticalDeclareOp::verify() {
4410 return verifySynchronizationHint(*this, getHint());
4411}
4412
4413LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4414 if (getNameAttr()) {
4415 SymbolRefAttr symbolRef = getNameAttr();
4416 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
4417 *this, symbolRef);
4418 if (!decl) {
4419 return emitOpError() << "expected symbol reference " << symbolRef
4420 << " to point to a critical declaration";
4421 }
4422 }
4423
4424 return success();
4425}
4426
4427//===----------------------------------------------------------------------===//
4428// Ordered construct
4429//===----------------------------------------------------------------------===//
4430
4431static LogicalResult verifyOrderedParent(Operation &op) {
4432 bool hasRegion = op.getNumRegions() > 0;
4433 auto loopOp = op.getParentOfType<LoopNestOp>();
4434 if (!loopOp) {
4435 if (hasRegion)
4436 return success();
4437
4438 // TODO: Consider if this needs to be the case only for the standalone
4439 // variant of the ordered construct.
4440 return op.emitOpError() << "must be nested inside of a loop";
4441 }
4442
4443 Operation *wrapper = loopOp->getParentOp();
4444 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4445 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4446 if (!orderedAttr)
4447 return op.emitOpError() << "the enclosing worksharing-loop region must "
4448 "have an ordered clause";
4449
4450 if (hasRegion && orderedAttr.getInt() != 0)
4451 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4452 "have a parameter present";
4453
4454 if (!hasRegion && orderedAttr.getInt() == 0)
4455 return op.emitOpError() << "the enclosing loop's ordered clause must "
4456 "have a parameter present";
4457 } else if (!isa<SimdOp>(wrapper)) {
4458 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4459 "or worksharing simd loop";
4460 }
4461 return success();
4462}
4463
4464void OrderedOp::build(OpBuilder &builder, OperationState &state,
4465 const OrderedOperands &clauses) {
4466 OrderedOp::build(builder, state, clauses.doacrossDependType,
4467 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4468}
4469
4470LogicalResult OrderedOp::verify() {
4471 if (failed(verifyOrderedParent(**this)))
4472 return failure();
4473
4474 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4475 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4476 return emitOpError() << "number of variables in depend clause does not "
4477 << "match number of iteration variables in the "
4478 << "doacross loop";
4479
4480 return success();
4481}
4482
4483void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4484 const OrderedRegionOperands &clauses) {
4485 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4486}
4487
4488LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4489
4490//===----------------------------------------------------------------------===//
4491// TaskwaitOp
4492//===----------------------------------------------------------------------===//
4493
/// Builds a `TaskwaitOp` from its clause bundle.
///
/// NOTE: `clauses` is currently not read. Every depend / iterated-depend /
/// nowait argument is passed as null or empty, so clause values supplied by
/// callers are dropped (see the TODO below).
void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
                       const TaskwaitOperands &clauses) {
  // TODO Store clauses in op: dependKinds, dependVars, nowait.
  TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
                    /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
                    /*depend_iterated=*/{}, /*nowait=*/nullptr);
}
4501
4502//===----------------------------------------------------------------------===//
4503// Verifier for AtomicReadOp
4504//===----------------------------------------------------------------------===//
4505
4506LogicalResult AtomicReadOp::verify() {
4507 if (verifyCommon().failed())
4508 return mlir::failure();
4509
4510 int64_t version = 50;
4511 if (auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4512 if (Attribute verAttr = moduleOp->getAttr("omp.version"))
4513 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4514
4515 if (auto mo = getMemoryOrder()) {
4516 if (*mo == ClauseMemoryOrderKind::Release) {
4517 return emitError("memory-order must not be release for atomic reads");
4518 }
4519 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4520 // acq_rel is prohibited on read only in OpenMP 5.0; allowed in 5.1+.
4521 if (version < 51)
4522 return emitError("memory-order must not be acq_rel for atomic reads");
4523 }
4524 }
4525 return verifySynchronizationHint(*this, getHint());
4526}
4527
4528//===----------------------------------------------------------------------===//
4529// Verifier for AtomicWriteOp
4530//===----------------------------------------------------------------------===//
4531
4532LogicalResult AtomicWriteOp::verify() {
4533 if (verifyCommon().failed())
4534 return mlir::failure();
4535
4536 int64_t version = 50;
4537 if (auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4538 if (Attribute verAttr = moduleOp->getAttr("omp.version"))
4539 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4540
4541 if (auto mo = getMemoryOrder()) {
4542 if (*mo == ClauseMemoryOrderKind::Acquire) {
4543 return emitError("memory-order must not be acquire for atomic writes");
4544 }
4545 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4546 // acq_rel is prohibited on write only in OpenMP 5.0; allowed in 5.1+.
4547 if (version < 51)
4548 return emitError("memory-order must not be acq_rel for atomic writes");
4549 }
4550 }
4551 return verifySynchronizationHint(*this, getHint());
4552}
4553
4554//===----------------------------------------------------------------------===//
4555// Verifier for AtomicUpdateOp
4556//===----------------------------------------------------------------------===//
4557
4558LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4559 PatternRewriter &rewriter) {
4560 if (op.isNoOp()) {
4561 rewriter.eraseOp(op);
4562 return success();
4563 }
4564 if (Value writeVal = op.getWriteOpVal()) {
4565 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4566 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4567 return success();
4568 }
4569 return failure();
4570}
4571
4572LogicalResult AtomicUpdateOp::verify() {
4573 if (verifyCommon().failed())
4574 return mlir::failure();
4575
4576 int64_t version = 50;
4577 if (auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4578 if (Attribute verAttr = moduleOp->getAttr("omp.version"))
4579 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4580
4581 if (auto mo = getMemoryOrder()) {
4582 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4583 *mo == ClauseMemoryOrderKind::Acquire) {
4584 // This restriction applies only to OpenMP 5.0; removed in 5.1.
4585 if (version < 51)
4586 return emitError(
4587 "memory-order must not be acq_rel or acquire for atomic updates");
4588 }
4589 }
4590
4591 return verifySynchronizationHint(*this, getHint());
4592}
4593
4594LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4595
4596//===----------------------------------------------------------------------===//
4597// Verifier for AtomicCaptureOp
4598//===----------------------------------------------------------------------===//
4599
4600AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4601 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4602 return op;
4603 return dyn_cast<AtomicReadOp>(getSecondOp());
4604}
4605
4606AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4607 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4608 return op;
4609 return dyn_cast<AtomicWriteOp>(getSecondOp());
4610}
4611
4612AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4613 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4614 return op;
4615 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4616}
4617
4618LogicalResult AtomicCaptureOp::verify() {
4619 return verifySynchronizationHint(*this, getHint());
4620}
4621
4622LogicalResult AtomicCaptureOp::verifyRegions() {
4623 if (verifyRegionsCommon().failed())
4624 return mlir::failure();
4625
4626 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4627 return emitOpError(
4628 "operations inside capture region must not have hint clause");
4629
4630 if (getFirstOp()->getAttr("memory_order") ||
4631 getSecondOp()->getAttr("memory_order"))
4632 return emitOpError(
4633 "operations inside capture region must not have memory_order clause");
4634 return success();
4635}
4636
4637//===----------------------------------------------------------------------===//
4638// AtomicCompareOp
4639//===----------------------------------------------------------------------===//
4640
4641LogicalResult AtomicCompareOp::verify() {
4642 if (verifyCommon().failed())
4643 return mlir::failure();
4644 return verifySynchronizationHint(*this, getHint());
4645}
4646
4647LogicalResult AtomicCompareOp::verifyRegions() {
4648 if (verifyRegionsCommon().failed())
4649 return mlir::failure();
4650
4651 if (verifyOperator().failed())
4652 return mlir::failure();
4653
4654 Block &block = getRegion().front();
4655
4656 Operation *terminator = block.getTerminator();
4657 if (!terminator || !isa<YieldOp>(terminator))
4658 return emitOpError("region must be terminated with omp.yield");
4659
4660 return success();
4661}
4662
4663//===----------------------------------------------------------------------===//
4664// CancelOp
4665//===----------------------------------------------------------------------===//
4666
4667void CancelOp::build(OpBuilder &builder, OperationState &state,
4668 const CancelOperands &clauses) {
4669 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4670}
4671
4673 Operation *parent = thisOp->getParentOp();
4674 while (parent) {
4675 if (parent->getDialect() == thisOp->getDialect())
4676 return parent;
4677 parent = parent->getParentOp();
4678 }
4679 return nullptr;
4680}
4681
4682LogicalResult CancelOp::verify() {
4683 ClauseCancellationConstructType cct = getCancelDirective();
4684 // The next OpenMP operation in the chain of parents
4685 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4686 if (!structuralParent)
4687 return emitOpError() << "Orphaned cancel construct";
4688
4689 if ((cct == ClauseCancellationConstructType::Parallel) &&
4690 !mlir::isa<ParallelOp>(structuralParent)) {
4691 return emitOpError() << "cancel parallel must appear "
4692 << "inside a parallel region";
4693 }
4694 if (cct == ClauseCancellationConstructType::Loop) {
4695 // structural parent will be omp.loop_nest, directly nested inside
4696 // omp.wsloop
4697 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4698
4699 if (!wsloopOp) {
4700 return emitOpError()
4701 << "cancel loop must appear inside a worksharing-loop region";
4702 }
4703 if (wsloopOp.getNowaitAttr()) {
4704 return emitError() << "A worksharing construct that is canceled "
4705 << "must not have a nowait clause";
4706 }
4707 if (wsloopOp.getOrderedAttr()) {
4708 return emitError() << "A worksharing construct that is canceled "
4709 << "must not have an ordered clause";
4710 }
4711
4712 } else if (cct == ClauseCancellationConstructType::Sections) {
4713 // structural parent will be an omp.section, directly nested inside
4714 // omp.sections
4715 auto sectionsOp =
4716 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4717 if (!sectionsOp) {
4718 return emitOpError() << "cancel sections must appear "
4719 << "inside a sections region";
4720 }
4721 if (sectionsOp.getNowait()) {
4722 return emitError() << "A sections construct that is canceled "
4723 << "must not have a nowait clause";
4724 }
4725 }
4726 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4727 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4728 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4729 return emitOpError() << "cancel taskgroup must appear "
4730 << "inside a task region";
4731 }
4732 return success();
4733}
4734
4735//===----------------------------------------------------------------------===//
4736// CancellationPointOp
4737//===----------------------------------------------------------------------===//
4738
4739void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4740 const CancellationPointOperands &clauses) {
4741 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4742}
4743
4744LogicalResult CancellationPointOp::verify() {
4745 ClauseCancellationConstructType cct = getCancelDirective();
4746 // The next OpenMP operation in the chain of parents
4747 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4748 if (!structuralParent)
4749 return emitOpError() << "Orphaned cancellation point";
4750
4751 if ((cct == ClauseCancellationConstructType::Parallel) &&
4752 !mlir::isa<ParallelOp>(structuralParent)) {
4753 return emitOpError() << "cancellation point parallel must appear "
4754 << "inside a parallel region";
4755 }
4756 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4757 // find the wsloop
4758 if ((cct == ClauseCancellationConstructType::Loop) &&
4759 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4760 return emitOpError() << "cancellation point loop must appear "
4761 << "inside a worksharing-loop region";
4762 }
4763 if ((cct == ClauseCancellationConstructType::Sections) &&
4764 !mlir::isa<omp::SectionOp>(structuralParent)) {
4765 return emitOpError() << "cancellation point sections must appear "
4766 << "inside a sections region";
4767 }
4768 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4769 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4770 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4771 return emitOpError() << "cancellation point taskgroup must appear "
4772 << "inside a task region";
4773 }
4774 return success();
4775}
4776
4777//===----------------------------------------------------------------------===//
4778// MapBoundsOp
4779//===----------------------------------------------------------------------===//
4780
4781LogicalResult MapBoundsOp::verify() {
4782 auto extent = getExtent();
4783 auto upperbound = getUpperBound();
4784 if (!extent && !upperbound)
4785 return emitError("expected extent or upperbound.");
4786 return success();
4787}
4788
4789void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4790 TypeRange /*result_types*/, StringAttr symName,
4791 TypeAttr type) {
4792 PrivateClauseOp::build(
4793 odsBuilder, odsState, symName, type,
4794 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4795 DataSharingClauseType::Private));
4796}
4797
/// Verifies the regions of a privatizer declaration, in this order:
///  1. every region argument has type `getArgType()`;
///  2. the `init` region, if non-empty, takes 2 arguments and yields one value
///     of the argument type;
///  3. `private` privatizers have no `copy` region; `firstprivate` ones must
///     have one, taking 2 arguments and yielding one value of the argument
///     type;
///  4. the `dealloc` region, if non-empty, takes 1 argument and yields nothing.
LogicalResult PrivateClauseOp::verifyRegions() {
  Type argType = getArgType();
  // Checks one block terminator. Terminators of blocks with successors are
  // not exit points and are skipped. Exit terminators must be `omp.yield`,
  // yielding exactly one `argType` value when `yieldsValue` is set and nothing
  // otherwise.
  auto verifyTerminator = [&](Operation *terminator,
                              bool yieldsValue) -> LogicalResult {
    if (!terminator->getBlock()->getSuccessors().empty())
      return success();

    if (!llvm::isa<YieldOp>(terminator))
      return mlir::emitError(terminator->getLoc())
             << "expected exit block terminator to be an `omp.yield` op.";

    YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
    TypeRange yieldedTypes = yieldOp.getResults().getTypes();

    if (!yieldsValue) {
      if (yieldedTypes.empty())
        return success();

      return mlir::emitError(terminator->getLoc())
             << "Did not expect any values to be yielded.";
    }

    if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
      return success();

    // Report the expected type alongside what was actually yielded ("None"
    // when the yield is empty).
    auto error = mlir::emitError(yieldOp.getLoc())
                 << "Invalid yielded value. Expected type: " << argType
                 << ", got: ";

    if (yieldedTypes.empty())
      error << "None";
    else
      error << yieldedTypes;

    return error;
  };

  // Checks a non-empty region's argument count, then the terminator of every
  // block that may have one.
  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
                          StringRef regionName,
                          bool yieldsValue) -> LogicalResult {
    assert(!region.empty());

    if (region.getNumArguments() != expectedNumArgs)
      return mlir::emitError(region.getLoc())
             << "`" << regionName << "`: "
             << "expected " << expectedNumArgs
             << " region arguments, got: " << region.getNumArguments();

    for (Block &block : region) {
      // MLIR will verify the absence of the terminator for us.
      if (!block.mightHaveTerminator())
        continue;

      if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
        return failure();
    }

    return success();
  };

  // Ensure all of the region arguments have the same type
  for (Region *region : getRegions())
    for (Type ty : region->getArgumentTypes())
      if (ty != argType)
        return emitError() << "Region argument type mismatch: got " << ty
                           << " expected " << argType << ".";

  // The `init` region is optional; it is only checked when present.
  mlir::Region &initRegion = getInitRegion();
  if (!initRegion.empty() &&
      failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
                          /*yieldsValue=*/true)))
    return failure();

  DataSharingClauseType dsType = getDataSharingType();

  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
    return emitError("`private` clauses do not require a `copy` region.");

  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
    return emitError(
        "`firstprivate` clauses require at least a `copy` region.");

  // Non-empty here: the previous check rejected an empty `copy` region, which
  // keeps verifyRegion's non-empty assertion satisfied.
  if (dsType == DataSharingClauseType::FirstPrivate &&
      failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
                          /*yieldsValue=*/true)))
    return failure();

  // The `dealloc` region is optional and must not yield a value.
  if (!getDeallocRegion().empty() &&
      failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
                          /*yieldsValue=*/false)))
    return failure();

  return success();
}
4892
4893//===----------------------------------------------------------------------===//
4894// Spec 5.2: Masked construct (10.5)
4895//===----------------------------------------------------------------------===//
4896
4897void MaskedOp::build(OpBuilder &builder, OperationState &state,
4898 const MaskedOperands &clauses) {
4899 MaskedOp::build(builder, state, clauses.filteredThreadId);
4900}
4901
4902//===----------------------------------------------------------------------===//
4903// Spec 5.2: Scan construct (5.6)
4904//===----------------------------------------------------------------------===//
4905
4906void ScanOp::build(OpBuilder &builder, OperationState &state,
4907 const ScanOperands &clauses) {
4908 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4909}
4910
4911LogicalResult ScanOp::verify() {
4912 if (hasExclusiveVars() == hasInclusiveVars())
4913 return emitError(
4914 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4915 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4916 if (parentWsLoopOp.getReductionModAttr() &&
4917 parentWsLoopOp.getReductionModAttr().getValue() ==
4918 ReductionModifier::inscan)
4919 return success();
4920 }
4921 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4922 if (parentSimdOp.getReductionModAttr() &&
4923 parentSimdOp.getReductionModAttr().getValue() ==
4924 ReductionModifier::inscan)
4925 return success();
4926 }
4927 return emitError("SCAN directive needs to be enclosed within a parent "
4928 "worksharing loop construct or SIMD construct with INSCAN "
4929 "reduction modifier");
4930}
4931
4932/// Verifies align clause in allocate directive
4933LogicalResult verifyAlignment(Operation &op,
4934 std::optional<uint64_t> alignment) {
4935 if (alignment.has_value()) {
4936 if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
4937 return op.emitError()
4938 << "ALIGN value : " << alignment.value() << " must be power of 2";
4939 }
4940 return success();
4941}
4942
4943LogicalResult AllocateDirOp::verify() {
4944 return verifyAlignment(*getOperation(), getAlign());
4945}
4946
4947//===----------------------------------------------------------------------===//
4948// AllocSharedMemOp
4949//===----------------------------------------------------------------------===//
4950
4951LogicalResult AllocSharedMemOp::verify() {
4952 return verifyAlignment(*getOperation(), getMemAlignment());
4953}
4954
4955//===----------------------------------------------------------------------===//
4956// FreeSharedMemOp
4957//===----------------------------------------------------------------------===//
4958
4959LogicalResult FreeSharedMemOp::verify() {
4960 return verifyAlignment(*getOperation(), getMemAlignment());
4961}
4962
4963//===----------------------------------------------------------------------===//
4964// WorkdistributeOp
4965//===----------------------------------------------------------------------===//
4966
4967LogicalResult WorkdistributeOp::verify() {
4968 // Check that region exists and is not empty
4969 Region &region = getRegion();
4970 if (region.empty())
4971 return emitOpError("region cannot be empty");
4972 // Verify single entry point.
4973 Block &entryBlock = region.front();
4974 if (entryBlock.empty())
4975 return emitOpError("region must contain a structured block");
4976 // Verify single exit point.
4977 bool hasTerminator = false;
4978 for (Block &block : region) {
4979 if (isa<TerminatorOp>(block.back())) {
4980 if (hasTerminator) {
4981 return emitOpError("region must have exactly one terminator");
4982 }
4983 hasTerminator = true;
4984 }
4985 }
4986 if (!hasTerminator) {
4987 return emitOpError("region must be terminated with omp.terminator");
4988 }
4989 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4990 // No implicit barrier at end
4991 if (isa<BarrierOp>(op)) {
4992 return emitOpError(
4993 "explicit barriers are not allowed in workdistribute region");
4994 }
4995 // Check for invalid nested constructs
4996 if (isa<ParallelOp>(op)) {
4997 return emitOpError(
4998 "nested parallel constructs not allowed in workdistribute");
4999 }
5000 if (isa<TeamsOp>(op)) {
5001 return emitOpError(
5002 "nested teams constructs not allowed in workdistribute");
5003 }
5004 return WalkResult::advance();
5005 });
5006 if (walkResult.wasInterrupted())
5007 return failure();
5008
5009 Operation *parentOp = (*this)->getParentOp();
5010 if (!llvm::dyn_cast<TeamsOp>(parentOp))
5011 return emitOpError("workdistribute must be nested under teams");
5012 return success();
5013}
5014
5015//===----------------------------------------------------------------------===//
5016// Declare simd [7.7]
5017//===----------------------------------------------------------------------===//
5018
5019LogicalResult DeclareSimdOp::verify() {
5020 // Must be nested inside a function-like op
5021 auto func =
5022 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
5023 if (!func)
5024 return emitOpError() << "must be nested inside a function";
5025
5026 if (getInbranch() && getNotinbranch())
5027 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
5028
5029 if (failed(verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars(),
5030 /*isDeclareSimd=*/true)))
5031 return failure();
5032
5033 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
5034}
5035
5036void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
5037 const DeclareSimdOperands &clauses) {
5038 MLIRContext *ctx = odsBuilder.getContext();
5039 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
5040 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
5041 clauses.linearVars, clauses.linearStepVars,
5042 clauses.linearVarTypes, clauses.linearModifiers,
5043 clauses.notinbranch, clauses.simdlen,
5044 clauses.uniformVars);
5045}
5046
5047//===----------------------------------------------------------------------===//
5048// Parser and printer for Uniform Clause
5049//===----------------------------------------------------------------------===//
5050
5051/// uniform ::= `uniform` `(` uniform-list `)`
5052/// uniform-list := uniform-val (`,` uniform-val)*
5053/// uniform-val := ssa-id `:` type
5054static ParseResult
5057 SmallVectorImpl<Type> &uniformTypes) {
5058 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
5059 if (parser.parseOperand(uniformVars.emplace_back()) ||
5060 parser.parseColonType(uniformTypes.emplace_back()))
5061 return mlir::failure();
5062 return mlir::success();
5063 });
5064}
5065
5066/// Print Uniform Clauses
5068 ValueRange uniformVars, TypeRange uniformTypes) {
5069 for (unsigned i = 0; i < uniformVars.size(); ++i) {
5070 if (i != 0)
5071 p << ", ";
5072 p << uniformVars[i] << " : " << uniformTypes[i];
5073 }
5074}
5075
5076//===----------------------------------------------------------------------===//
5077// Parser and printer for Affinity Clause
5078//===----------------------------------------------------------------------===//
5079
5080static ParseResult parseAffinityClause(
5081 OpAsmParser &parser,
5084 SmallVectorImpl<Type> &iteratedTypes,
5085 SmallVectorImpl<Type> &affinityVarTypes) {
5086 if (failed(parseSplitIteratedList(
5087 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
5088 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
5089 return failure();
5090 return success();
5091}
5092
5094 ValueRange iterated, ValueRange affinityVars,
5095 TypeRange iteratedTypes,
5096 TypeRange affinityVarTypes) {
5097 auto nop = [&](Value, Type) {};
5098 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
5099 affinityVarTypes,
5100 /*plain prefix*/ nop,
5101 /*iterated prefix*/ nop);
5102}
5103
5104//===----------------------------------------------------------------------===//
5105// Parser, printer, and verifier for Iterator modifier
5106//===----------------------------------------------------------------------===//
5107
5108static ParseResult
5113 SmallVectorImpl<Type> &lbTypes,
5114 SmallVectorImpl<Type> &ubTypes,
5115 SmallVectorImpl<Type> &stepTypes) {
5116
5117 llvm::SMLoc ivLoc = parser.getCurrentLocation();
5119
5120 // Parse induction variables: %i : i32, %j : i32
5121 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
5122 OpAsmParser::Argument &arg = ivArgs.emplace_back();
5123 if (parser.parseArgument(arg))
5124 return failure();
5125
5126 // Optional type, default to Index if not provided
5127 if (succeeded(parser.parseOptionalColon())) {
5128 if (parser.parseType(arg.type))
5129 return failure();
5130 } else {
5131 arg.type = parser.getBuilder().getIndexType();
5132 }
5133 return success();
5134 }))
5135 return failure();
5136
5137 // ) = (
5138 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
5139 return failure();
5140
5141 // Parse Ranges: (%lb to %ub step %st, ...)
5142 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
5143 OpAsmParser::UnresolvedOperand lb, ub, st;
5144 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
5145 parser.parseOperand(ub) || parser.parseKeyword("step") ||
5146 parser.parseOperand(st))
5147 return failure();
5148
5149 lbs.push_back(lb);
5150 ubs.push_back(ub);
5151 steps.push_back(st);
5152 return success();
5153 }))
5154 return failure();
5155
5156 if (parser.parseRParen())
5157 return failure();
5158
5159 if (ivArgs.size() != lbs.size())
5160 return parser.emitError(ivLoc)
5161 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
5162 << " ranges";
5163
5164 for (auto &arg : ivArgs) {
5165 lbTypes.push_back(arg.type);
5166 ubTypes.push_back(arg.type);
5167 stepTypes.push_back(arg.type);
5168 }
5169
5170 return parser.parseRegion(region, ivArgs);
5171}
5172
5174 ValueRange lbs, ValueRange ubs,
5176 TypeRange) {
5177 Block &entry = region.front();
5178
5179 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
5180 if (i != 0)
5181 p << ", ";
5182 p.printRegionArgument(entry.getArgument(i));
5183 }
5184 p << ") = (";
5185
5186 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
5187 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
5188 if (i)
5189 p << ", ";
5190 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
5191 }
5192 p << ") ";
5193
5194 p.printRegion(region, /*printEntryBlockArgs=*/false,
5195 /*printBlockTerminators=*/true);
5196}
5197
5198LogicalResult IteratorOp::verify() {
5199 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
5200 if (!iteratedTy)
5201 return emitOpError() << "result must be omp.iterated<entry_ty>";
5202
5203 for (auto [lb, ub, step] : llvm::zip_equal(
5204 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5205 if (matchPattern(step, m_Zero()))
5206 return emitOpError() << "loop step must not be zero";
5207
5208 IntegerAttr lbAttr;
5209 IntegerAttr ubAttr;
5210 IntegerAttr stepAttr;
5211 if (!matchPattern(lb, m_Constant(&lbAttr)) ||
5212 !matchPattern(ub, m_Constant(&ubAttr)) ||
5213 !matchPattern(step, m_Constant(&stepAttr)))
5214 continue;
5215
5216 const APInt &lbVal = lbAttr.getValue();
5217 const APInt &ubVal = ubAttr.getValue();
5218 const APInt &stepVal = stepAttr.getValue();
5219 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5220 return emitOpError() << "positive loop step requires lower bound to be "
5221 "less than or equal to upper bound";
5222 if (stepVal.isNegative() && lbVal.slt(ubVal))
5223 return emitOpError() << "negative loop step requires lower bound to be "
5224 "greater than or equal to upper bound";
5225 }
5226
5227 Block &b = getRegion().front();
5228 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
5229
5230 if (!yield)
5231 return emitOpError() << "region must be terminated by omp.yield";
5232
5233 if (yield.getNumOperands() != 1)
5234 return emitOpError()
5235 << "omp.yield in omp.iterator region must yield exactly one value";
5236
5237 mlir::Type yieldedTy = yield.getOperand(0).getType();
5238 mlir::Type elemTy = iteratedTy.getElementType();
5239
5240 if (yieldedTy != elemTy)
5241 return emitOpError() << "omp.iterated element type (" << elemTy
5242 << ") does not match omp.yield operand type ("
5243 << yieldedTy << ")";
5244
5245 return success();
5246}
5247
5248//===----------------------------------------------------------------------===//
5249// GroupprivateOp
5250//===----------------------------------------------------------------------===//
5251
5252LogicalResult
5253GroupprivateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5254 auto *symbol = symbolTable.lookupNearestSymbolFrom(*this, getSymNameAttr());
5255 if (!symbol)
5256 return emitOpError() << "expected symbol reference '" << getSymName()
5257 << "' to point to a global variable";
5258
5259 if (isa<FunctionOpInterface>(symbol))
5260 return emitOpError() << "expected symbol reference '" << getSymName()
5261 << "' to point to a global variable, not a function";
5262
5263 return success();
5264}
5265
5266#define GET_ATTRDEF_CLASSES
5267#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5268
5269#define GET_OP_CLASSES
5270#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5271
5272#define GET_TYPEDEF_CLASSES
5273#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1523
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, TypeAttr inType, ValueRange typeparams, TypeRange typeparamsTypes, ValueRange shape, TypeRange shapeTypes)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op, AccessGroupModifierAttr modifierFirst, FallbackModifierAttr modifierSecond, Value dynGroupprivateSize, Type sizeType)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseHeapAllocClause(OpAsmParser &parser, TypeAttr &inTypeAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &typeparams, SmallVectorImpl< Type > &typeparamsTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &shape, SmallVectorImpl< Type > &shapeTypes)
operation ::= $in_type ( ( $typeparams ) )? ( , $shape )?
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseDynGroupprivateClause(OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr, FallbackModifierAttr &fallbackAttr, std::optional< OpAsmParser::UnresolvedOperand > &dynGroupprivateSize, Type &sizeType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup, FallbackModifierAttr fallback, Value dynGroupprivateSize)
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
LogicalResult verifyAlignment(Operation &op, std::optional< uint64_t > alignment)
Verifies align clause in allocate directive.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
BlockArgListType getArguments()
Definition Block.h:97
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:143
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:161
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:237
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:247
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.