MLIR 22.0.0git
TransformOps.cpp
Go to the documentation of this file.
1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
24#include "mlir/IR/Verifier.h"
30#include "mlir/Transforms/CSE.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/DebugLog.h"
41#include "llvm/Support/ErrorHandling.h"
42#include "llvm/Support/InterleavedRange.h"
43#include <optional>
44
45#define DEBUG_TYPE "transform-dialect"
46#define DEBUG_TYPE_MATCHER "transform-matcher"
47
48using namespace mlir;
49
50static ParseResult parseApplyRegisteredPassOptions(
51 OpAsmParser &parser, DictionaryAttr &options,
54 Operation *op,
55 DictionaryAttr options,
56 ValueRange dynamicOptions);
57static ParseResult parseSequenceOpOperands(
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59 Type &rootType,
61 SmallVectorImpl<Type> &extraBindingTypes);
62static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
63 Value root, Type rootType,
64 ValueRange extraBindings,
65 TypeRange extraBindingTypes);
66static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
68static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
70 ArrayAttr &actions);
71
72/// Helper function to check if the given transform op is contained in (or
73/// equal to) the given payload target op. In that case, an error is returned.
74/// Transforming transform IR that is currently executing is generally unsafe.
76ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
77 Operation *payload) {
78 Operation *transformAncestor = transform.getOperation();
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
82 transform.emitDefiniteFailure()
83 << "cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->getLoc()) << "target payload op";
85 return diag;
86 }
87 transformAncestor = transformAncestor->getParentOp();
88 }
90}
91
92#define GET_OP_CLASSES
93#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
94
95//===----------------------------------------------------------------------===//
96// AlternativesOp
97//===----------------------------------------------------------------------===//
98
99OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
100 RegionSuccessor successor) {
101 if (!successor.isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
103 return OperandRange(getOperation()->operand_end(),
104 getOperation()->operand_end());
105}
106
107void transform::AlternativesOp::getSuccessorRegions(
109 for (Region &alternative : llvm::drop_begin(
110 getAlternatives(), point.isParent()
111 ? 0
114 ->getRegionNumber() +
115 1)) {
116 regions.emplace_back(&alternative, !getOperands().empty()
117 ? alternative.getArguments()
119 }
120 if (!point.isParent())
121 regions.emplace_back(getOperation(), getOperation()->getResults());
122}
123
124void transform::AlternativesOp::getRegionInvocationBounds(
126 (void)operands;
127 // The region corresponding to the first alternative is always executed, the
128 // remaining may or may not be executed.
129 bounds.reserve(getNumRegions());
130 bounds.emplace_back(1, 1);
131 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
132}
133
136 for (const auto &res : block->getParentOp()->getOpResults())
137 results.set(res, {});
138}
139
141transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
144 SmallVector<Operation *> originals;
145 if (Value scopeHandle = getScope())
146 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
147 else
148 originals.push_back(state.getTopLevel());
149
150 for (Operation *original : originals) {
151 if (original->isAncestor(getOperation())) {
153 << "scope must not contain the transforms being applied";
154 diag.attachNote(original->getLoc()) << "scope";
155 return diag;
156 }
157 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
159 << "only isolated-from-above ops can be alternative scopes";
160 diag.attachNote(original->getLoc()) << "scope";
161 return diag;
162 }
163 }
164
165 for (Region &reg : getAlternatives()) {
166 // Clone the scope operations and make the transforms in this alternative
167 // region apply to them by virtue of mapping the block argument (the only
168 // visible handle) to the cloned scope operations. This effectively prevents
169 // the transformation from accessing any IR outside the scope.
170 auto scope = state.make_region_scope(reg);
171 auto clones = llvm::to_vector(
172 llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
173 auto deleteClones = llvm::make_scope_exit([&] {
174 for (Operation *clone : clones)
175 clone->erase();
176 });
177 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
179
180 bool failed = false;
181 for (Operation &transform : reg.front().without_terminator()) {
183 state.applyTransform(cast<TransformOpInterface>(transform));
184 if (result.isSilenceableFailure()) {
185 LDBG() << "alternative failed: " << result.getMessage();
186 failed = true;
187 break;
188 }
189
190 if (::mlir::failed(result.silence()))
192 }
193
194 // If all operations in the given alternative succeeded, no need to consider
195 // the rest. Replace the original scoping operation with the clone on which
196 // the transformations were performed.
197 if (!failed) {
198 // We will be using the clones, so cancel their scheduled deletion.
199 deleteClones.release();
200 TrackingListener listener(state, *this);
201 IRRewriter rewriter(getContext(), &listener);
202 for (const auto &kvp : llvm::zip(originals, clones)) {
203 Operation *original = std::get<0>(kvp);
204 Operation *clone = std::get<1>(kvp);
205 original->getBlock()->getOperations().insert(original->getIterator(),
206 clone);
207 rewriter.replaceOp(original, clone->getResults());
208 }
209 detail::forwardTerminatorOperands(&reg.front(), state, results);
211 }
212 }
213 return emitSilenceableError() << "all alternatives failed";
214}
215
216void transform::AlternativesOp::getEffects(
218 consumesHandle(getOperation()->getOpOperands(), effects);
219 producesHandle(getOperation()->getOpResults(), effects);
220 for (Region *region : getRegions()) {
221 if (!region->empty())
222 producesHandle(region->front().getArguments(), effects);
223 }
224 modifiesPayload(effects);
225}
226
227LogicalResult transform::AlternativesOp::verify() {
228 for (Region &alternative : getAlternatives()) {
229 Block &block = alternative.front();
230 Operation *terminator = block.getTerminator();
231 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
233 << "expects terminator operands to have the "
234 "same type as results of the operation";
235 diag.attachNote(terminator->getLoc()) << "terminator";
236 return diag;
237 }
238 }
239
240 return success();
241}
242
243//===----------------------------------------------------------------------===//
244// AnnotateOp
245//===----------------------------------------------------------------------===//
246
248transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
252 llvm::to_vector(state.getPayloadOps(getTarget()));
253
254 Attribute attr = UnitAttr::get(getContext());
255 if (auto paramH = getParam()) {
256 ArrayRef<Attribute> params = state.getParams(paramH);
257 if (params.size() != 1) {
258 if (targets.size() != params.size()) {
259 return emitSilenceableError()
260 << "parameter and target have different payload lengths ("
261 << params.size() << " vs " << targets.size() << ")";
262 }
263 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
264 target->setAttr(getName(), attr);
266 }
267 attr = params[0];
268 }
269 for (auto *target : targets)
270 target->setAttr(getName(), attr);
272}
273
274void transform::AnnotateOp::getEffects(
276 onlyReadsHandle(getTargetMutable(), effects);
277 onlyReadsHandle(getParamMutable(), effects);
278 modifiesPayload(effects);
279}
280
281//===----------------------------------------------------------------------===//
282// ApplyCommonSubexpressionEliminationOp
283//===----------------------------------------------------------------------===//
284
286transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
288 ApplyToEachResultList &results, transform::TransformState &state) {
289 // Make sure that this transform is not applied to itself. Modifying the
290 // transform IR while it is being interpreted is generally dangerous.
291 DiagnosedSilenceableFailure payloadCheck =
293 if (!payloadCheck.succeeded())
294 return payloadCheck;
295
296 DominanceInfo domInfo;
299}
300
301void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
303 transform::onlyReadsHandle(getTargetMutable(), effects);
305}
306
307//===----------------------------------------------------------------------===//
308// ApplyDeadCodeEliminationOp
309//===----------------------------------------------------------------------===//
310
311DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
313 ApplyToEachResultList &results, transform::TransformState &state) {
314 // Make sure that this transform is not applied to itself. Modifying the
315 // transform IR while it is being interpreted is generally dangerous.
316 DiagnosedSilenceableFailure payloadCheck =
318 if (!payloadCheck.succeeded())
319 return payloadCheck;
320
321 // Maintain a worklist of potentially dead ops.
322 SetVector<Operation *> worklist;
323
324 // Helper function that adds all defining ops of used values (operands and
325 // operands of nested ops).
326 auto addDefiningOpsToWorklist = [&](Operation *op) {
327 op->walk([&](Operation *op) {
328 for (Value v : op->getOperands())
329 if (Operation *defOp = v.getDefiningOp())
330 if (target->isProperAncestor(defOp))
331 worklist.insert(defOp);
332 });
333 };
334
335 // Helper function that erases an op.
336 auto eraseOp = [&](Operation *op) {
337 // Remove op and nested ops from the worklist.
338 op->walk([&](Operation *op) {
339 const auto *it = llvm::find(worklist, op);
340 if (it != worklist.end())
341 worklist.erase(it);
342 });
343 rewriter.eraseOp(op);
344 };
345
346 // Initial walk over the IR.
347 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
348 if (op != target && isOpTriviallyDead(op)) {
349 addDefiningOpsToWorklist(op);
350 eraseOp(op);
351 }
352 });
353
354 // Erase all ops that have become dead.
355 while (!worklist.empty()) {
356 Operation *op = worklist.pop_back_val();
357 if (!isOpTriviallyDead(op))
358 continue;
359 addDefiningOpsToWorklist(op);
360 eraseOp(op);
361 }
362
364}
365
366void transform::ApplyDeadCodeEliminationOp::getEffects(
368 transform::onlyReadsHandle(getTargetMutable(), effects);
370}
371
372//===----------------------------------------------------------------------===//
373// ApplyPatternsOp
374//===----------------------------------------------------------------------===//
375
376DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
378 ApplyToEachResultList &results, transform::TransformState &state) {
379 // Make sure that this transform is not applied to itself. Modifying the
380 // transform IR while it is being interpreted is generally dangerous. Even
381 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
382 // performs many additional simplifications such as dead code elimination.
383 DiagnosedSilenceableFailure payloadCheck =
385 if (!payloadCheck.succeeded())
386 return payloadCheck;
387
388 // Gather all specified patterns.
389 MLIRContext *ctx = target->getContext();
391 if (!getRegion().empty()) {
392 for (Operation &op : getRegion().front()) {
393 cast<transform::PatternDescriptorOpInterface>(&op)
394 .populatePatternsWithState(patterns, state);
395 }
396 }
397
398 // Configure the GreedyPatternRewriteDriver.
400 config.setListener(
401 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
402 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
403
404 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
406 : getMaxIterations());
407 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
409 : getMaxNumRewrites());
410
411 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
412 // was requested, apply the greedy pattern rewrite only once. (The greedy
413 // pattern rewrite driver already iterates to a fixpoint internally.)
414 bool cseChanged = false;
415 // One or two iterations should be sufficient. Stop iterating after a certain
416 // threshold to make debugging easier.
417 static const int64_t kNumMaxIterations = 50;
418 int64_t iteration = 0;
419 do {
420 LogicalResult result = failure();
421 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
422 // Op is isolated from above. Apply patterns and also perform region
423 // simplification.
424 result = applyPatternsGreedily(target, frozenPatterns, config);
425 } else {
426 // Manually gather list of ops because the other
427 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
428 // from above. This way, patterns can be applied to ops that are not
429 // isolated from above. Regions are not being simplified. Furthermore,
430 // only a single greedy rewrite iteration is performed.
432 target->walk([&](Operation *nestedOp) {
433 if (target != nestedOp)
434 ops.push_back(nestedOp);
435 });
436 result = applyOpPatternsGreedily(ops, frozenPatterns, config);
437 }
438
439 // A failure typically indicates that the pattern application did not
440 // converge.
441 if (failed(result)) {
443 << "greedy pattern application failed";
444 }
445
446 if (getApplyCse()) {
447 DominanceInfo domInfo;
449 &cseChanged);
450 }
451 } while (cseChanged && ++iteration < kNumMaxIterations);
452
453 if (iteration == kNumMaxIterations)
454 return emitDefiniteFailure() << "fixpoint iteration did not converge";
455
457}
458
459LogicalResult transform::ApplyPatternsOp::verify() {
460 if (!getRegion().empty()) {
461 for (Operation &op : getRegion().front()) {
462 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
464 << "expected children ops to implement "
465 "PatternDescriptorOpInterface";
466 diag.attachNote(op.getLoc()) << "op without interface";
467 return diag;
468 }
469 }
470 }
471 return success();
472}
473
474void transform::ApplyPatternsOp::getEffects(
476 transform::onlyReadsHandle(getTargetMutable(), effects);
478}
479
480void transform::ApplyPatternsOp::build(
482 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
483 result.addOperands(target);
484
485 OpBuilder::InsertionGuard g(builder);
486 Region *region = result.addRegion();
487 builder.createBlock(region);
488 if (bodyBuilder)
489 bodyBuilder(builder, result.location);
490}
491
492//===----------------------------------------------------------------------===//
493// ApplyCanonicalizationPatternsOp
494//===----------------------------------------------------------------------===//
495
496void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
498 MLIRContext *ctx = patterns.getContext();
499 for (Dialect *dialect : ctx->getLoadedDialects())
500 dialect->getCanonicalizationPatterns(patterns);
502 op.getCanonicalizationPatterns(patterns, ctx);
503}
504
505//===----------------------------------------------------------------------===//
506// ApplyConversionPatternsOp
507//===----------------------------------------------------------------------===//
508
509DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
512 MLIRContext *ctx = getContext();
513
514 // Instantiate the default type converter if a type converter builder is
515 // specified.
516 std::unique_ptr<TypeConverter> defaultTypeConverter;
517 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
518 getDefaultTypeConverter();
519 if (typeConverterBuilder)
520 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
521
522 // Configure conversion target.
523 ConversionTarget conversionTarget(*getContext());
524 if (getLegalOps())
525 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
526 conversionTarget.addLegalOp(
527 OperationName(cast<StringAttr>(attr).getValue(), ctx));
528 if (getIllegalOps())
529 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
530 conversionTarget.addIllegalOp(
531 OperationName(cast<StringAttr>(attr).getValue(), ctx));
532 if (getLegalDialects())
533 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
534 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
535 if (getIllegalDialects())
536 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
537 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
538
539 // Gather all specified patterns.
541 // Need to keep the converters alive until after pattern application because
542 // the patterns take a reference to an object that would otherwise get out of
543 // scope.
545 if (!getPatterns().empty()) {
546 for (Operation &op : getPatterns().front()) {
547 auto descriptor =
548 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
549
550 // Check if this pattern set specifies a type converter.
551 std::unique_ptr<TypeConverter> typeConverter =
552 descriptor.getTypeConverter();
553 TypeConverter *converter = nullptr;
554 if (typeConverter) {
555 keepAliveConverters.emplace_back(std::move(typeConverter));
556 converter = keepAliveConverters.back().get();
557 } else {
558 // No type converter specified: Use the default type converter.
559 if (!defaultTypeConverter) {
561 << "pattern descriptor does not specify type "
562 "converter and apply_conversion_patterns op has "
563 "no default type converter";
564 diag.attachNote(op.getLoc()) << "pattern descriptor op";
565 return diag;
566 }
567 converter = defaultTypeConverter.get();
568 }
569
570 // Add descriptor-specific updates to the conversion target, which may
571 // depend on the final type converter. In structural converters, the
572 // legality of types dictates the dynamic legality of an operation.
573 descriptor.populateConversionTargetRules(*converter, conversionTarget);
574
575 descriptor.populatePatterns(*converter, patterns);
576 }
577 }
578
579 // Attach a tracking listener if handles should be preserved. We configure the
580 // listener to allow op replacements with different names, as conversion
581 // patterns typically replace ops with replacement ops that have a different
582 // name.
583 TrackingListenerConfig trackingConfig;
584 trackingConfig.requireMatchingReplacementOpName = false;
585 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
586 ConversionConfig conversionConfig;
587 if (getPreserveHandles())
588 conversionConfig.listener = &trackingListener;
589
590 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
591 for (Operation *target : state.getPayloadOps(getTarget())) {
592 // Make sure that this transform is not applied to itself. Modifying the
593 // transform IR while it is being interpreted is generally dangerous.
594 DiagnosedSilenceableFailure payloadCheck =
596 if (!payloadCheck.succeeded())
597 return payloadCheck;
598
599 LogicalResult status = failure();
600 if (getPartialConversion()) {
601 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
602 conversionConfig);
603 } else {
604 status = applyFullConversion(target, conversionTarget, frozenPatterns,
605 conversionConfig);
606 }
607
608 // Check dialect conversion state.
610 if (failed(status)) {
611 diag = emitSilenceableError() << "dialect conversion failed";
612 diag.attachNote(target->getLoc()) << "target op";
613 }
614
615 // Check tracking listener error state.
616 DiagnosedSilenceableFailure trackingFailure =
617 trackingListener.checkAndResetError();
618 if (!trackingFailure.succeeded()) {
619 if (diag.succeeded()) {
620 // Tracking failure is the only failure.
621 return trackingFailure;
622 }
623 diag.attachNote() << "tracking listener also failed: "
624 << trackingFailure.getMessage();
625 (void)trackingFailure.silence();
626 }
627
628 if (!diag.succeeded())
629 return diag;
630 }
631
633}
634
635LogicalResult transform::ApplyConversionPatternsOp::verify() {
636 if (getNumRegions() != 1 && getNumRegions() != 2)
637 return emitOpError() << "expected 1 or 2 regions";
638 if (!getPatterns().empty()) {
639 for (Operation &op : getPatterns().front()) {
640 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
642 emitOpError() << "expected pattern children ops to implement "
643 "ConversionPatternDescriptorOpInterface";
644 diag.attachNote(op.getLoc()) << "op without interface";
645 return diag;
646 }
647 }
648 }
649 if (getNumRegions() == 2) {
650 Region &typeConverterRegion = getRegion(1);
651 if (!llvm::hasSingleElement(typeConverterRegion.front()))
652 return emitOpError()
653 << "expected exactly one op in default type converter region";
654 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
655 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
656 maybeTypeConverter);
657 if (!typeConverterOp) {
659 << "expected default converter child op to "
660 "implement TypeConverterBuilderOpInterface";
661 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
662 return diag;
663 }
664 // Check default type converter type.
665 if (!getPatterns().empty()) {
666 for (Operation &op : getPatterns().front()) {
667 auto descriptor =
668 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
669 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
670 return failure();
671 }
672 }
673 }
674 return success();
675}
676
677void transform::ApplyConversionPatternsOp::getEffects(
679 if (!getPreserveHandles()) {
680 transform::consumesHandle(getTargetMutable(), effects);
681 } else {
682 transform::onlyReadsHandle(getTargetMutable(), effects);
683 }
685}
686
687void transform::ApplyConversionPatternsOp::build(
689 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
690 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
691 result.addOperands(target);
692
693 {
694 OpBuilder::InsertionGuard g(builder);
695 Region *region1 = result.addRegion();
696 builder.createBlock(region1);
697 if (patternsBodyBuilder)
698 patternsBodyBuilder(builder, result.location);
699 }
700 {
701 OpBuilder::InsertionGuard g(builder);
702 Region *region2 = result.addRegion();
703 builder.createBlock(region2);
704 if (typeConverterBodyBuilder)
705 typeConverterBodyBuilder(builder, result.location);
706 }
707}
708
709//===----------------------------------------------------------------------===//
710// ApplyToLLVMConversionPatternsOp
711//===----------------------------------------------------------------------===//
712
713void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
714 TypeConverter &typeConverter, RewritePatternSet &patterns) {
715 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
716 assert(dialect && "expected that dialect is loaded");
717 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
718 // ConversionTarget is currently ignored because the enclosing
719 // apply_conversion_patterns op sets up its own ConversionTarget.
721 iface->populateConvertToLLVMConversionPatterns(
722 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
723}
724
725LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
726 transform::TypeConverterBuilderOpInterface builder) {
727 if (builder.getTypeConverterType() != "LLVMTypeConverter")
728 return emitOpError("expected LLVMTypeConverter");
729 return success();
730}
731
732LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
733 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
734 if (!dialect)
735 return emitOpError("unknown dialect or dialect not loaded: ")
736 << getDialectName();
737 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
738 if (!iface)
739 return emitOpError(
740 "dialect does not implement ConvertToLLVMPatternInterface or "
741 "extension was not loaded: ")
742 << getDialectName();
743 return success();
744}
745
746//===----------------------------------------------------------------------===//
747// ApplyLoopInvariantCodeMotionOp
748//===----------------------------------------------------------------------===//
749
751transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
752 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
755 // Currently, LICM does not remove operations, so we don't need tracking.
756 // If this ever changes, add a LICM entry point that takes a rewriter.
759}
760
761void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
763 transform::onlyReadsHandle(getTargetMutable(), effects);
765}
766
767//===----------------------------------------------------------------------===//
768// ApplyRegisteredPassOp
769//===----------------------------------------------------------------------===//
770
771void transform::ApplyRegisteredPassOp::getEffects(
773 consumesHandle(getTargetMutable(), effects);
774 onlyReadsHandle(getDynamicOptionsMutable(), effects);
775 producesHandle(getOperation()->getOpResults(), effects);
776 modifiesPayload(effects);
777}
778
780transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
783 // Obtain a single options-string to pass to the pass(-pipeline) from options
784 // passed in as a dictionary of keys mapping to values which are either
785 // attributes or param-operands pointing to attributes.
786 OperandRange dynamicOptions = getDynamicOptions();
787
788 std::string options;
789 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
790
791 // A helper to convert an option's attribute value into a corresponding
792 // string representation, with the ability to obtain the attr(s) from a param.
793 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
794 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
795 // The corresponding value attribute(s) is/are passed in via a param.
796 // Obtain the param-operand via its specified index.
797 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
798 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
799 "the number of ParamOperandAttrs in the options DictionaryAttr"
800 "should be the same as the number of options passed as params");
801 ArrayRef<Attribute> attrsAssociatedToParam =
802 state.getParams(dynamicOptions[dynamicOptionIdx]);
803 // Recursive so as to append all attrs associated to the param.
804 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
805 ",");
806 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
807 // Recursive so as to append all nested attrs of the array.
808 llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
809 } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
810 // Convert to unquoted string.
811 optionsStream << strAttr.getValue().str();
812 } else {
813 // For all other attributes, ask the attr to print itself (without type).
814 valueAttr.print(optionsStream, /*elideType=*/true);
815 }
816 };
817
818 // Convert the options DictionaryAttr into a single string.
819 llvm::interleave(
820 getOptions(), optionsStream,
821 [&](auto namedAttribute) {
822 optionsStream << namedAttribute.getName().str(); // Append the key.
823 optionsStream << "="; // And the key-value separator.
824 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
825 },
826 " ");
827 optionsStream.flush();
828
829 // Get pass or pass pipeline from registry.
830 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
831 if (!info)
832 info = PassInfo::lookup(getPassName());
833 if (!info)
834 return emitDefiniteFailure()
835 << "unknown pass or pass pipeline: " << getPassName();
836
837 // Create pass manager and add the pass or pass pipeline.
839 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
840 emitError(msg);
841 return failure();
842 }))) {
843 return emitDefiniteFailure()
844 << "failed to add pass or pass pipeline to pipeline: "
845 << getPassName();
846 }
847
848 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
849 for (Operation *target : targets) {
850 // Make sure that this transform is not applied to itself. Modifying the
851 // transform IR while it is being interpreted is generally dangerous. Even
852 // more so when applying passes because they may perform a wide range of IR
853 // modifications.
854 DiagnosedSilenceableFailure payloadCheck =
856 if (!payloadCheck.succeeded())
857 return payloadCheck;
858
859 // Run the pass or pass pipeline on the current target operation.
860 if (failed(pm.run(target))) {
861 auto diag = emitSilenceableError() << "pass pipeline failed";
862 diag.attachNote(target->getLoc()) << "target op";
863 return diag;
864 }
865 }
866
867 // The applied pass will have directly modified the payload IR(s).
868 results.set(llvm::cast<OpResult>(getResult()), targets);
870}
871
873 OpAsmParser &parser, DictionaryAttr &options,
875 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
876 SmallVector<NamedAttribute> keyValuePairs;
877 size_t dynamicOptionsIdx = 0;
878
879 // Helper for allowing parsing of option values which can be of the form:
880 // - a normal attribute
881 // - an operand (which would be converted to an attr referring to the operand)
882 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
883 std::function<ParseResult(Attribute &)> parseValue =
884 [&](Attribute &valueAttr) -> ParseResult {
885 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
886 if (succeeded(parser.parseOptionalLSquare())) {
888
889 // Recursively parse the array's elements, which might be operands.
890 if (parser.parseCommaSeparatedList(
892 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
893 " in options dictionary") ||
894 parser.parseRSquare())
895 return failure(); // NB: Attempted parse should've output error message.
896
897 valueAttr = ArrayAttr::get(parser.getContext(), attrs);
898
899 return success();
900 }
901
902 // Parse the value, which can be either an attribute or an operand.
903 OptionalParseResult parsedValueAttr =
904 parser.parseOptionalAttribute(valueAttr);
905 if (!parsedValueAttr.has_value()) {
907 ParseResult parsedOperand = parser.parseOperand(operand);
908 if (failed(parsedOperand))
909 return failure(); // NB: Attempted parse should've output error message.
910 // To make use of the operand, we need to store it in the options dict.
911 // As SSA-values cannot occur in attributes, what we do instead is store
912 // an attribute in its place that contains the index of the param-operand,
913 // so that an attr-value associated to the param can be resolved later on.
914 dynamicOptions.push_back(operand);
915 auto wrappedIndex = IntegerAttr::get(
916 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
917 valueAttr =
918 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
919 } else if (failed(parsedValueAttr.value())) {
920 return failure(); // NB: Attempted parse should have output error message.
921 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
922 return parser.emitError(parser.getCurrentLocation())
923 << "the param_operand attribute is a marker reserved for "
924 << "indicating a value will be passed via params and is only used "
925 << "in the generic print format";
926 }
927
928 return success();
929 };
930
931 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
932 // string and `value` looks like either an attribute or an operand-in-an-attr.
933 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
934 std::string key;
935 Attribute valueAttr;
936
937 if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
938 return parser.emitError(parser.getCurrentLocation())
939 << "expected key to either be an identifier or a string";
940
941 if (failed(parser.parseEqual()))
942 return parser.emitError(parser.getCurrentLocation())
943 << "expected '=' after key in key-value pair";
944
945 if (failed(parseValue(valueAttr)))
946 return parser.emitError(parser.getCurrentLocation())
947 << "expected a valid attribute or operand as value associated "
948 << "to key '" << key << "'";
949
950 keyValuePairs.push_back(NamedAttribute(key, valueAttr));
951
952 return success();
953 };
954
957 " in options dictionary"))
958 return failure(); // NB: Attempted parse should have output error message.
959
960 if (DictionaryAttr::findDuplicate(
961 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
962 .has_value())
963 return parser.emitError(parser.getCurrentLocation())
964 << "duplicate keys found in options dictionary";
965
966 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
967
968 return success();
969}
970
                                            Operation *op,
                                            DictionaryAttr options,
                                            ValueRange dynamicOptions) {
  // An empty options dictionary prints nothing at all (not even `{}`).
  if (options.empty())
    return;

  // Prints a single option value. ParamOperandAttr markers are replaced by the
  // SSA operand they index into `dynamicOptions`; ArrayAttrs are printed
  // element by element (recursively) so nested markers are resolved as well;
  // any other attribute is printed as-is.
  std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
    if (auto paramOperandAttr =
            dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
      // Resolve index of param-operand to its actual SSA-value and print that.
      // NOTE(review): the index is not bounds-checked here; this presumably
      // relies on ApplyRegisteredPassOp::verify() having rejected
      // out-of-range indices before the custom printer is used.
      printer.printOperand(
          dynamicOptions[paramOperandAttr.getIndex().getInt()]);
    } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
      // This case is so that ArrayAttr-contained operands are pretty-printed.
      printer << "[";
      llvm::interleaveComma(arrayAttr, printer, printOptionValue);
      printer << "]";
    } else {
      printer.printAttribute(valueAttr);
    }
  };

  // Print the dictionary as `{key = value, ...}`, the same shape accepted by
  // parseApplyRegisteredPassOptions.
  printer << "{";
  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
    printer << namedAttribute.getName();
    printer << " = ";
    printOptionValue(namedAttribute.getValue());
  });
  printer << "}";
}
1002
1003LogicalResult transform::ApplyRegisteredPassOp::verify() {
1004 // Check that there is a one-to-one correspondence between param operands
1005 // and references to dynamic options in the options dictionary.
1006
1007 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1008
1009 // Helper for option values to mark seen operands as having been seen (once).
1010 std::function<LogicalResult(Attribute)> checkOptionValue =
1011 [&](Attribute valueAttr) -> LogicalResult {
1012 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1013 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1014 if (dynamicOptionIdx < 0 ||
1015 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1016 return emitOpError()
1017 << "dynamic option index " << dynamicOptionIdx
1018 << " is out of bounds for the number of dynamic options: "
1019 << dynamicOptions.size();
1020 if (dynamicOptions[dynamicOptionIdx] == nullptr)
1021 return emitOpError() << "dynamic option index " << dynamicOptionIdx
1022 << " is already used in options";
1023 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1024 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1025 // Recurse into ArrayAttrs as they may contain references to operands.
1026 for (auto eltAttr : arrayAttr)
1027 if (failed(checkOptionValue(eltAttr)))
1028 return failure();
1029 }
1030 return success();
1031 };
1032
1033 for (NamedAttribute namedAttr : getOptions())
1034 if (failed(checkOptionValue(namedAttr.getValue())))
1035 return failure();
1036
1037 // All dynamicOptions-params seen in the dict will have been set to null.
1038 for (Value dynamicOption : dynamicOptions)
1039 if (dynamicOption)
1040 return emitOpError() << "a param operand does not have a corresponding "
1041 << "param_operand attr in the options dict";
1042
1043 return success();
1044}
1045
1046//===----------------------------------------------------------------------===//
1047// CastOp
1048//===----------------------------------------------------------------------===//
1049
transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
                              Operation *target, ApplyToEachResultList &results,
  // A cast only re-types the handle: the payload op is forwarded unchanged.
  results.push_back(target);
}
1057
/// Declares the effects of CastOp: the payload is only read, the input handle
/// is only read (not consumed), and the result handle is produced.
void transform::CastOp::getEffects(
  onlyReadsPayload(effects);
  onlyReadsHandle(getInputMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
}
1064
1065bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1066 assert(inputs.size() == 1 && "expected one input");
1067 assert(outputs.size() == 1 && "expected one output");
1068 return llvm::all_of(
1069 std::initializer_list<Type>{inputs.front(), outputs.front()},
1070 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1071}
1072
1073//===----------------------------------------------------------------------===//
// Matcher helpers shared by CollectMatchingOp and ForeachMatchOp
1075//===----------------------------------------------------------------------===//
1076
/// Applies the matcher operations of `block`, using `blockArgumentMapping` to
/// initialize the block arguments, and updates `state` accordingly.
///
/// Every non-terminator operation must implement MatchOpInterface; otherwise a
/// definite failure is emitted at that operation. Matchers are applied in
/// order, and the first one that does not succeed stops matching. Its result
/// (silenceable or definite failure) is returned unchanged, and the caller
/// decides whether to log and skip it or to propagate it. If all matchers
/// succeed, `mappings` is cleared and then populated with the payload entities
/// associated with the block terminator operands.
           ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
  assert(block.getParent() && "cannot match using a detached block");
  // Block-argument mappings are created within a scope for the matcher's
  // region; they are presumably released when `matchScope` is destroyed.
  auto matchScope = state.make_region_scope(*block.getParent());
  if (failed(
          state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))

  for (Operation &match : block.without_terminator()) {
    if (!isa<transform::MatchOpInterface>(match)) {
      return emitDefiniteFailure(match.getLoc())
             << "expected operations in the match part to "
                "implement MatchOpInterface";
    }
        state.applyTransform(cast<transform::TransformOpInterface>(match));
    if (diag.succeeded())
      continue;

    // Stop at the first failing matcher and hand its result to the caller.
    return diag;
  }

  // Remember the values mapped to the terminator operands so we can
  // forward them to the action.
  ValueRange yieldedValues = block.getTerminator()->getOperands();
  // Our contract with the caller is that the mappings will contain only the
  // newly mapped values, clear the rest.
  mappings.clear();
  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
}
1119
1120/// Returns `true` if both types implement one of the interfaces provided as
1121/// template parameters.
1122template <typename... Tys>
1123static bool implementSameInterface(Type t1, Type t2) {
1124 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1125}
1126
/// Returns `true` if both types implement the same one of the transform
/// dialect type interfaces (operation handle, parameter, or value handle).
/// Used below to check that values can be forwarded between matchers, actions
/// and op operands/results.
  return implementSameInterface<transform::TransformHandleTypeInterface,
                                transform::TransformParamTypeInterface,
                                transform::TransformValueHandleTypeInterface>(
      t1, t2);
}
1135
1136//===----------------------------------------------------------------------===//
1137// CollectMatchingOp
1138//===----------------------------------------------------------------------===//
1139
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
      getOperation(), getMatcher());
  // A declaration without a body cannot be executed.
  if (matcher.isExternal()) {
    return emitDefiniteFailure()
           << "unresolved external symbol " << getMatcher();
  }

  // One accumulator per op result; every successful match appends exactly one
  // payload entity to each of them.
  rawResults.resize(getOperation()->getNumResults());
  // Set only when the walk is interrupted because a matcher result was not
  // associated with exactly one payload entity.
  std::optional<DiagnosedSilenceableFailure> maybeFailure;
  for (Operation *root : state.getPayloadOps(getRoot())) {
    // Visit every op nested in `root` (including `root` itself) and try the
    // matcher on it.
    WalkResult walkResult = root->walk([&](Operation *op) {
      LDBG(DEBUG_TYPE_MATCHER, 1)
          << "matching "
          << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
          << " @" << op;

      // Try matching.
      SmallVector<transform::MappedValue> inputMapping({op});
          matcher.getFunctionBody().front(),
          ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
          mappings);
      if (diag.isDefiniteFailure())
        return WalkResult::interrupt();
      // A silenceable failure just means "no match": log it and move on.
      if (diag.isSilenceableFailure()) {
        LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
                                    << " failed: " << diag.getMessage();
        return WalkResult::advance();
      }

      // If succeeded, collect results.
      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
        if (mapping.size() != 1) {
          maybeFailure.emplace(emitSilenceableError()
                               << "result #" << i << ", associated with "
                               << mapping.size()
                               << " payload objects, expected 1");
          return WalkResult::interrupt();
        }
        rawResults[i].push_back(mapping[0]);
      }
      return WalkResult::advance();
    });
    // NOTE(review): the walk is also interrupted when matchBlock() returns a
    // definite failure, and in that case `maybeFailure` is never set, so
    // `*maybeFailure` dereferences an empty std::optional (undefined
    // behavior). That path should most likely return a definite failure
    // instead; please confirm and fix.
    if (walkResult.wasInterrupted())
      return std::move(*maybeFailure);
    assert(!maybeFailure && "failure set but the walk was not interrupted");

    // NOTE(review): this runs once per root and re-sets every result with the
    // accumulated `rawResults`. With several roots, verify that
    // TransformResults tolerates setting the same result more than once. With
    // no roots at all, the results are never set here.
    for (auto &&[opResult, rawResult] :
         llvm::zip_equal(getOperation()->getResults(), rawResults)) {
      results.setMappedValues(opResult, rawResult);
    }
  }
}
1200
/// Declares the effects of CollectMatchingOp: the root handle is only read,
/// the result handles are produced, and the payload is only read.
void transform::CollectMatchingOp::getEffects(
  onlyReadsHandle(getRootMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  onlyReadsPayload(effects);
}
1207
1208LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1209 SymbolTableCollection &symbolTable) {
1210 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1211 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1212 if (!matcherSymbol ||
1213 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1214 return emitError() << "unresolved matcher symbol " << getMatcher();
1215
1216 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1217 if (argumentTypes.size() != 1 ||
1218 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1219 return emitError()
1220 << "expected the matcher to take one operation handle argument";
1221 }
1222 if (!matcherSymbol.getArgAttr(
1223 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1224 return emitError() << "expected the matcher argument to be marked readonly";
1225 }
1226
1227 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1228 if (resultTypes.size() != getOperation()->getNumResults()) {
1229 return emitError()
1230 << "expected the matcher to yield as many values as op has results ("
1231 << getOperation()->getNumResults() << "), got "
1232 << resultTypes.size();
1233 }
1234
1235 for (auto &&[i, matcherType, resultType] :
1236 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1237 if (implementSameTransformInterface(matcherType, resultType))
1238 continue;
1239
1240 return emitError()
1241 << "mismatching type interfaces for matcher result and op result #"
1242 << i;
1243 }
1244
1245 return success();
1246}
1247
1248//===----------------------------------------------------------------------===//
1249// ForeachMatchOp
1250//===----------------------------------------------------------------------===//
1251
1252// This is fine because nothing is actually consumed by this op.
1253bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1254
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  // Resolve every matcher/action symbol pair up front, so that the walk below
  // only deals with function bodies.
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    // Declarations without a body cannot be executed.
    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  // Accumulates silenceable failures of actions. The walk keeps going after
  // such failures, and all of them are reported together at the end.
  DiagnosedSilenceableFailure overallDiag =

  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
                                getForwardedInputs(), state);
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  actionResultMapping.resize(getForwardedOutputs().size());

  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      LDBG(DEBUG_TYPE_MATCHER, 1)
          << "matching "
          << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
          << " @" << op;

      // Bind the currently visited op to the matchers' first argument; the
      // forwarded inputs presumably occupy the remaining entries.
      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        if (diag.isSilenceableFailure()) {
          LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
                                      << " failed: " << diag.getMessage();
          continue;
        }

        // The matcher succeeded: bind its yielded values to the action's
        // block arguments.
        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          // A silenceable failure is recorded as notes on `overallDiag` and
          // silenced; the remaining transforms of the action still run.
          if (result.isSilenceableFailure()) {
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        // Append the action's yielded payloads to the forwarded outputs.
        // Unless getFlattenResults() is set, each yielded value must be
        // associated with a single payload entity.
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        // Only the first matching pair is applied to a given payload op.
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}
1379
1380void transform::ForeachMatchOp::getAsmResultNames(
1381 OpAsmSetValueNameFn setNameFn) {
1382 setNameFn(getUpdated(), "updated_root");
1383 for (Value v : getForwardedOutputs()) {
1384 setNameFn(v, "yielded");
1385 }
1386}
1387
/// Declares the effects of ForeachMatchOp:
/// - the root handle is consumed, so that handles to operations nested in it
///   are invalidated;
/// - the forwarded inputs are only read;
/// - all results are produced;
/// - the payload may be modified by the actions.
void transform::ForeachMatchOp::getEffects(
  // Bail if invalid: effects may be queried on a malformed op (e.g. before or
  // without verification). In that case, conservatively report only a payload
  // modification instead of touching the operand/result accessors.
  if (getOperation()->getNumOperands() < 1 ||
      getOperation()->getNumResults() < 1) {
    return modifiesPayload(effects);
  }

  consumesHandle(getRootMutable(), effects);
  onlyReadsHandle(getForwardedInputsMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
1401
/// Parses the comma-separated list of symbol reference pairs of the format
/// `@matcher -> @action`. At least one pair is required. The symbols are
/// returned as two parallel ArrayAttrs of flat SymbolRefAttrs in `matchers`
/// and `actions`.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &actions) {
  StringAttr matcher;
  StringAttr action;
  SmallVector<Attribute> matcherList;
  SmallVector<Attribute> actionList;
  // do/while: the first pair is mandatory, further pairs follow commas.
  do {
    if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
        parser.parseSymbolName(action)) {
      return failure();
    }
    matcherList.push_back(SymbolRefAttr::get(matcher));
    actionList.push_back(SymbolRefAttr::get(action));
  } while (parser.parseOptionalComma().succeeded());

  matchers = parser.getBuilder().getArrayAttr(matcherList);
  actions = parser.getBuilder().getArrayAttr(actionList);
  return success();
}
1424
/// Prints the comma-separated list of symbol reference pairs of the format
/// `@matcher -> @action`. Each pair goes on its own line, indented two levels
/// deeper than the surrounding op, with `, ` after every pair but the last.
                                     ArrayAttr matchers, ArrayAttr actions) {
  printer.increaseIndent();
  printer.increaseIndent();
  for (auto &&[matcher, action, idx] : llvm::zip_equal(
           matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
    printer.printNewline();
    printer << cast<SymbolRefAttr>(matcher) << " -> "
            << cast<SymbolRefAttr>(action);
    if (idx != matchers.size() - 1)
      printer << ", ";
  }
  printer.decreaseIndent();
  printer.decreaseIndent();
}
1442
LogicalResult transform::ForeachMatchOp::verify() {
  // Matchers and actions are stored as two parallel arrays and must pair up.
  if (getMatchers().size() != getActions().size())
    return emitOpError() << "expected the same number of matchers and actions";
  if (getMatchers().empty())
    return emitOpError() << "expected at least one match/action pair";

  // Reusing a matcher is legal and only warned about. As the warning says,
  // apply() stops at the first successful pair, so only the first pair using a
  // given matcher can take effect.
  for (Attribute name : getMatchers()) {
    if (matcherNames.insert(name).second)
      continue;
    emitWarning() << "matcher " << name
                  << " is used more than once, only the first match will apply";
  }

  return success();
}
1459
/// Checks that the arguments of the function-like transform operation `op`
/// carry consistent consumed/readonly annotations:
/// - no argument may be both readonly and consumed;
/// - arguments of external functions (and, if `alsoVerifyInternal`, of
///   functions with a body too) must carry one of the two annotations, even
///   where it could be inferred from the body;
/// - for functions with a body, an argument marked readonly must not be
///   consumed by the body.
/// If `emitWarnings` is set, a warning is also emitted for arguments marked
/// consumed that the body does not consume. Errors are returned as silenceable
/// failures.
verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
                                     bool alsoVerifyInternal = false) {
  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
  // Indices of the block arguments actually consumed by the body (if any).
  llvm::SmallDenseSet<unsigned> consumedArguments;
  if (!op.isExternal()) {
    transform::getConsumedBlockArguments(op.getFunctionBody().front(),
                                         consumedArguments);
  }
  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
    bool isConsumed =
        op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
        nullptr;
    bool isReadOnly =
        op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
        nullptr;
    if (isConsumed && isReadOnly) {
      return transformOp.emitSilenceableError()
             << "argument #" << i << " cannot be both readonly and consumed";
    }
    if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
      return transformOp.emitSilenceableError()
             << "must provide consumed/readonly status for arguments of "
                "external or called ops";
    }
    // Without a body there is nothing to compare the annotations against.
    if (op.isExternal())
      continue;

    // NOTE(review): `!isConsumed` is implied by `isReadOnly` here (both being
    // set was rejected above), so this only fires for readonly-marked
    // arguments. An unannotated argument consumed by the body passes when
    // `alsoVerifyInternal` is false; confirm that this is intended.
    if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
      return transformOp.emitSilenceableError()
             << "argument #" << i
             << " is consumed in the body but is not marked as such";
    }
    if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
      // Cannot use op.emitWarning() here as it would attempt to verify the op
      // before printing, resulting in infinite recursion.
      emitWarning(op->getLoc())
          << "op argument #" << i
          << " is not consumed in the body but is marked as consumed";
    }
  }
}
1506
/// Verifies every matcher/action pair:
/// - both symbols resolve to function-like transform ops with consistent
///   consumption annotations;
/// - the op operands forward into the matcher arguments, which must not be
///   consumed;
/// - the matcher results forward into the action arguments;
/// - the action results forward into the op results after `updated`.
/// At each forwarding step, co-indexed types must implement the same transform
/// type interface.
LogicalResult transform::ForeachMatchOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  assert(getMatchers().size() == getActions().size());
  auto consumedAttr =
      StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    // Presence and typing.
    auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(matcher)));
    auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(action)));
    if (!matcherSymbol ||
        !isa<TransformOpInterface>(matcherSymbol.getOperation()))
      return emitError() << "unresolved matcher symbol " << matcher;
    if (!actionSymbol ||
        !isa<TransformOpInterface>(actionSymbol.getOperation()))
      return emitError() << "unresolved action symbol " << action;

    // Consumption annotations must be explicit for both callees.
                                             /*emitWarnings=*/false,
                                             /*alsoVerifyInternal=*/true)
                       .checkAndReport())) {
      return failure();
    }
                                             /*emitWarnings=*/false,
                                             /*alsoVerifyInternal=*/true)
                       .checkAndReport())) {
      return failure();
    }

    // Input -> matcher forwarding. All op operands (the root followed by the
    // forwarded inputs) are checked against the matcher's arguments.
    TypeRange operandTypes = getOperandTypes();
    TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
    if (operandTypes.size() != matcherArguments.size()) {
          emitError() << "the number of operands (" << operandTypes.size()
                      << ") doesn't match the number of matcher arguments ("
                      << matcherArguments.size() << ") for " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, operand, argument] :
         llvm::enumerate(operandTypes, matcherArguments)) {
      if (matcherSymbol.getArgArgAttrPlaceholder_DO_NOT_USE) {
      }
    }
  }
  return success();
}
1619
1620//===----------------------------------------------------------------------===//
1621// ForeachOp
1622//===----------------------------------------------------------------------===//
1623
1625transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1628 // We store the payloads before executing the body as ops may be removed from
1629 // the mapping by the TrackingRewriter while iteration is in progress.
1631 detail::prepareValueMappings(payloads, getTargets(), state);
1632 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1633 bool withZipShortest = getWithZipShortest();
1634
1635 // In case of `zip_shortest`, set the number of iterations to the
1636 // smallest payload in the targets.
1637 if (withZipShortest) {
1638 numIterations =
1639 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1640 const SmallVector<MappedValue> &b) {
1641 return a.size() < b.size();
1642 })->size();
1643
1644 for (auto &payload : payloads)
1645 payload.resize(numIterations);
1646 }
1647
1648 // As we will be "zipping" over them, check all payloads have the same size.
1649 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1650 // when true.
1651 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1652 argIdx++) {
1653 if (payloads[argIdx].size() != numIterations) {
1654 return emitSilenceableError()
1655 << "prior targets' payload size (" << numIterations
1656 << ") differs from payload size (" << payloads[argIdx].size()
1657 << ") of target " << getTargets()[argIdx];
1658 }
1659 }
1660
1661 // Start iterating, indexing into payloads to obtain the right arguments to
1662 // call the body with - each slice of payloads at the same argument index
1663 // corresponding to a tuple to use as the body's block arguments.
1664 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1665 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1666 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1667 auto scope = state.make_region_scope(getBody());
1668 // Set up arguments to the region's block.
1669 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1670 MappedValue argument = payloads[argIdx][iterIdx];
1671 // Note that each blockArg's handle gets associated with just a single
1672 // element from the corresponding target's payload.
1673 if (failed(state.mapBlockArgument(blockArg, {argument})))
1675 }
1676
1677 // Execute loop body.
1678 for (Operation &transform : getBody().front().without_terminator()) {
1680 llvm::cast<transform::TransformOpInterface>(transform));
1681 if (!result.succeeded())
1682 return result;
1683 }
1684
1685 // Append yielded payloads to corresponding results from prior iterations.
1686 OperandRange yieldOperands = getYieldOp().getOperands();
1687 for (auto &&[result, yieldOperand, resTuple] :
1688 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1689 // NB: each iteration we add any number of ops/vals/params to a result.
1690 if (isa<TransformHandleTypeInterface>(result.getType()))
1691 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1692 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1693 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1694 else if (isa<TransformParamTypeInterface>(result.getType()))
1695 llvm::append_range(resTuple, state.getParams(yieldOperand));
1696 else
1697 assert(false && "unhandled handle type");
1698 }
1699
1700 // Associate the accumulated result payloads to the op's actual results.
1701 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1702 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1703
1705}
1706
1707void transform::ForeachOp::getEffects(
1709 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1710 // arity errors, this method might get called before/in absence of `verify()`.
1711 for (auto &&[target, blockArg] :
1712 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1713 BlockArgument blockArgument = blockArg;
1714 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1715 return isHandleConsumed(blockArgument,
1716 cast<TransformOpInterface>(&op));
1717 })) {
1718 consumesHandle(target, effects);
1719 } else {
1720 onlyReadsHandle(target, effects);
1721 }
1722 }
1723
1724 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1725 return doesModifyPayload(cast<TransformOpInterface>(&op));
1726 })) {
1727 modifiesPayload(effects);
1728 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1729 return doesReadPayload(cast<TransformOpInterface>(&op));
1730 })) {
1731 onlyReadsPayload(effects);
1732 }
1733
1734 producesHandle(getOperation()->getOpResults(), effects);
1735}
1736
1737void transform::ForeachOp::getSuccessorRegions(
1739 Region *bodyRegion = &getBody();
1740 if (point.isParent()) {
1741 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1742 return;
1743 }
1744
1745 // Branch back to the region or the parent.
1747 &getBody() &&
1748 "unexpected region index");
1749 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1750 regions.emplace_back(getOperation(), getOperation()->getResults());
1751}
1752
1754transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1755 // Each block argument handle is mapped to a subset (one op to be precise)
1756 // of the payload of the corresponding `targets` operand of ForeachOp.
1757 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1758 return getOperation()->getOperands();
1759}
1760
1761transform::YieldOp transform::ForeachOp::getYieldOp() {
1762 return cast<transform::YieldOp>(getBody().front().getTerminator());
1763}
1764
1765LogicalResult transform::ForeachOp::verify() {
1766 for (auto [targetOpt, bodyArgOpt] :
1767 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1768 if (!targetOpt || !bodyArgOpt)
1769 return emitOpError() << "expects the same number of targets as the body "
1770 "has block arguments";
1771 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1772 return emitOpError(
1773 "expects co-indexed targets and the body's "
1774 "block arguments to have the same op/value/param type");
1775 }
1776
1777 for (auto [resultOpt, yieldOperandOpt] :
1778 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1779 if (!resultOpt || !yieldOperandOpt)
1780 return emitOpError() << "expects the same number of results as the "
1781 "yield terminator has operands";
1782 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1783 return emitOpError("expects co-indexed results and yield "
1784 "operands to have the same op/value/param type");
1785 }
1786
1787 return success();
1788}
1789
1790//===----------------------------------------------------------------------===//
1791// GetParentOp
1792//===----------------------------------------------------------------------===//
1793
1795transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1799 DenseSet<Operation *> resultSet;
1800 for (Operation *target : state.getPayloadOps(getTarget())) {
1801 Operation *parent = target;
1802 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1803 parent = parent->getParentOp();
1804 while (parent) {
1805 bool checkIsolatedFromAbove =
1806 !getIsolatedFromAbove() ||
1808 bool checkOpName = !getOpName().has_value() ||
1809 parent->getName().getStringRef() == *getOpName();
1810 if (checkIsolatedFromAbove && checkOpName)
1811 break;
1812 parent = parent->getParentOp();
1813 }
1814 if (!parent) {
1815 if (getAllowEmptyResults()) {
1816 results.set(llvm::cast<OpResult>(getResult()), parents);
1818 }
1820 emitSilenceableError()
1821 << "could not find a parent op that matches all requirements";
1822 diag.attachNote(target->getLoc()) << "target op";
1823 return diag;
1824 }
1825 }
1826 if (getDeduplicate()) {
1827 if (resultSet.insert(parent).second)
1828 parents.push_back(parent);
1829 } else {
1830 parents.push_back(parent);
1831 }
1832 }
1833 results.set(llvm::cast<OpResult>(getResult()), parents);
1835}
1836
1837//===----------------------------------------------------------------------===//
1838// GetConsumersOfResult
1839//===----------------------------------------------------------------------===//
1840
1842transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1845 int64_t resultNumber = getResultNumber();
1846 auto payloadOps = state.getPayloadOps(getTarget());
1847 if (std::empty(payloadOps)) {
1848 results.set(cast<OpResult>(getResult()), {});
1850 }
1851 if (!llvm::hasSingleElement(payloadOps))
1852 return emitDefiniteFailure()
1853 << "handle must be mapped to exactly one payload op";
1854
1855 Operation *target = *payloadOps.begin();
1856 if (target->getNumResults() <= resultNumber)
1857 return emitDefiniteFailure() << "result number overflow";
1858 results.set(llvm::cast<OpResult>(getResult()),
1859 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1861}
1862
1863//===----------------------------------------------------------------------===//
1864// GetDefiningOp
1865//===----------------------------------------------------------------------===//
1866
1868transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1871 SmallVector<Operation *> definingOps;
1872 for (Value v : state.getPayloadValues(getTarget())) {
1873 if (llvm::isa<BlockArgument>(v)) {
1875 emitSilenceableError() << "cannot get defining op of block argument";
1876 diag.attachNote(v.getLoc()) << "target value";
1877 return diag;
1878 }
1879 definingOps.push_back(v.getDefiningOp());
1880 }
1881 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1883}
1884
1885//===----------------------------------------------------------------------===//
1886// GetProducerOfOperand
1887//===----------------------------------------------------------------------===//
1888
1890transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1893 int64_t operandNumber = getOperandNumber();
1894 SmallVector<Operation *> producers;
1895 for (Operation *target : state.getPayloadOps(getTarget())) {
1896 Operation *producer =
1897 target->getNumOperands() <= operandNumber
1898 ? nullptr
1899 : target->getOperand(operandNumber).getDefiningOp();
1900 if (!producer) {
1902 emitSilenceableError()
1903 << "could not find a producer for operand number: " << operandNumber
1904 << " of " << *target;
1905 diag.attachNote(target->getLoc()) << "target op";
1906 return diag;
1907 }
1908 producers.push_back(producer);
1909 }
1910 results.set(llvm::cast<OpResult>(getResult()), producers);
1912}
1913
1914//===----------------------------------------------------------------------===//
1915// GetOperandOp
1916//===----------------------------------------------------------------------===//
1917
1919transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1922 SmallVector<Value> operands;
1923 for (Operation *target : state.getPayloadOps(getTarget())) {
1924 SmallVector<int64_t> operandPositions;
1926 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1927 target->getNumOperands(), operandPositions);
1928 if (diag.isSilenceableFailure()) {
1929 diag.attachNote(target->getLoc())
1930 << "while considering positions of this payload operation";
1931 return diag;
1932 }
1933 llvm::append_range(operands,
1934 llvm::map_range(operandPositions, [&](int64_t pos) {
1935 return target->getOperand(pos);
1936 }));
1937 }
1938 results.setValues(cast<OpResult>(getResult()), operands);
1940}
1941
1942LogicalResult transform::GetOperandOp::verify() {
1943 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1944 getIsInverted(), getIsAll());
1945}
1946
1947//===----------------------------------------------------------------------===//
1948// GetResultOp
1949//===----------------------------------------------------------------------===//
1950
1952transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1955 SmallVector<Value> opResults;
1956 for (Operation *target : state.getPayloadOps(getTarget())) {
1957 SmallVector<int64_t> resultPositions;
1959 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1960 target->getNumResults(), resultPositions);
1961 if (diag.isSilenceableFailure()) {
1962 diag.attachNote(target->getLoc())
1963 << "while considering positions of this payload operation";
1964 return diag;
1965 }
1966 llvm::append_range(opResults,
1967 llvm::map_range(resultPositions, [&](int64_t pos) {
1968 return target->getResult(pos);
1969 }));
1970 }
1971 results.setValues(cast<OpResult>(getResult()), opResults);
1973}
1974
1975LogicalResult transform::GetResultOp::verify() {
1976 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1977 getIsInverted(), getIsAll());
1978}
1979
1980//===----------------------------------------------------------------------===//
1981// GetTypeOp
1982//===----------------------------------------------------------------------===//
1983
1984void transform::GetTypeOp::getEffects(
1986 onlyReadsHandle(getValueMutable(), effects);
1987 producesHandle(getOperation()->getOpResults(), effects);
1988 onlyReadsPayload(effects);
1989}
1990
1992transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1996 for (Value value : state.getPayloadValues(getValue())) {
1997 Type type = value.getType();
1998 if (getElemental()) {
1999 if (auto shaped = dyn_cast<ShapedType>(type)) {
2000 type = shaped.getElementType();
2001 }
2002 }
2003 params.push_back(TypeAttr::get(type));
2004 }
2005 results.setParams(cast<OpResult>(getResult()), params);
2007}
2008
2009//===----------------------------------------------------------------------===//
2010// IncludeOp
2011//===----------------------------------------------------------------------===//
2012
2013/// Applies the transform ops contained in `block`. Maps `results` to the same
2014/// values as the operands of the block terminator.
2016applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2018 transform::TransformResults &results) {
2019 // Apply the sequenced ops one by one.
2020 for (Operation &transform : block.without_terminator()) {
2022 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2023 if (result.isDefiniteFailure())
2024 return result;
2025
2026 if (result.isSilenceableFailure()) {
2027 if (mode == transform::FailurePropagationMode::Propagate) {
2028 // Propagate empty results in case of early exit.
2029 forwardEmptyOperands(&block, state, results);
2030 return result;
2031 }
2032 (void)result.silence();
2033 }
2034 }
2035
2036 // Forward the operation mapping for values yielded from the sequence to the
2037 // values produced by the sequence op.
2038 transform::detail::forwardTerminatorOperands(&block, state, results);
2040}
2041
2043transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2047 getOperation(), getTarget());
2048 assert(callee && "unverified reference to unknown symbol");
2049
2050 if (callee.isExternal())
2051 return emitDefiniteFailure() << "unresolved external named sequence";
2052
2053 // Map operands to block arguments.
2055 detail::prepareValueMappings(mappings, getOperands(), state);
2056 auto scope = state.make_region_scope(callee.getBody());
2057 for (auto &&[arg, map] :
2058 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2059 if (failed(state.mapBlockArgument(arg, map)))
2061 }
2062
2064 callee.getBody().front(), getFailurePropagationMode(), state, results);
2065
2066 if (!result.succeeded())
2067 return result;
2068
2069 mappings.clear();
2070 detail::prepareValueMappings(
2071 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2072 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2073 results.setMappedValues(result, mapping);
2074 return result;
2075}
2076
2078verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2079
2080void transform::IncludeOp::getEffects(
2082 // Always mark as modifying the payload.
2083 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2084 // only read, the payload may still be modified, so we currently stay on the
2085 // conservative side and always indicate modification. This may prevent some
2086 // code reordering.
2087 modifiesPayload(effects);
2088
2089 // Results are always produced.
2090 producesHandle(getOperation()->getOpResults(), effects);
2091
2092 // Adds default effects to operands and results. This will be added if
2093 // preconditions fail so the trait verifier doesn't complain about missing
2094 // effects and the real precondition failure is reported later on.
2095 auto defaultEffects = [&] {
2096 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2097 };
2098
2099 // Bail if the callee is unknown. This may run as part of the verification
2100 // process before we verified the validity of the callee or of this op.
2101 auto target =
2102 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2103 if (!target)
2104 return defaultEffects();
2106 getOperation(), getTarget());
2107 if (!callee)
2108 return defaultEffects();
2109
2110 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2111 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2112 consumesHandle(getOperation()->getOpOperand(i), effects);
2113 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2114 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2115 }
2116}
2117
2118LogicalResult
2119transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2120 // Access through indirection and do additional checking because this may be
2121 // running before the main op verifier.
2122 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2123 if (!targetAttr)
2124 return emitOpError() << "expects a 'target' symbol reference attribute";
2125
2126 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2127 *this, targetAttr);
2128 if (!target)
2129 return emitOpError() << "does not reference a named transform sequence";
2130
2131 FunctionType fnType = target.getFunctionType();
2132 if (fnType.getNumInputs() != getNumOperands())
2133 return emitError("incorrect number of operands for callee");
2134
2135 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2136 if (getOperand(i).getType() != fnType.getInput(i)) {
2137 return emitOpError("operand type mismatch: expected operand type ")
2138 << fnType.getInput(i) << ", but provided "
2139 << getOperand(i).getType() << " for operand number " << i;
2140 }
2141 }
2142
2143 if (fnType.getNumResults() != getNumResults())
2144 return emitError("incorrect number of results for callee");
2145
2146 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2147 Type resultType = getResult(i).getType();
2148 Type funcType = fnType.getResult(i);
2149 if (!implementSameTransformInterface(resultType, funcType)) {
2150 return emitOpError() << "type of result #" << i
2151 << " must implement the same transform dialect "
2152 "interface as the corresponding callee result";
2153 }
2154 }
2155
2157 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2158 /*alsoVerifyInternal=*/true)
2159 .checkAndReport();
2160}
2161
2162//===----------------------------------------------------------------------===//
2163// MatchOperationEmptyOp
2164//===----------------------------------------------------------------------===//
2165
2166DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2167 ::std::optional<::mlir::Operation *> maybeCurrent,
2169 if (!maybeCurrent.has_value()) {
2170 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2172 }
2173 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2174 return emitSilenceableError() << "operation is not empty";
2175}
2176
2177//===----------------------------------------------------------------------===//
2178// MatchOperationNameOp
2179//===----------------------------------------------------------------------===//
2180
2181DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2182 Operation *current, transform::TransformResults &results,
2184 StringRef currentOpName = current->getName().getStringRef();
2185 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2186 if (acceptedAttr.getValue() == currentOpName)
2188 }
2189 return emitSilenceableError() << "wrong operation name";
2190}
2191
2192//===----------------------------------------------------------------------===//
2193// MatchParamCmpIOp
2194//===----------------------------------------------------------------------===//
2195
2197transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
2200 auto signedAPIntAsString = [&](const APInt &value) {
2201 std::string str;
2202 llvm::raw_string_ostream os(str);
2203 value.print(os, /*isSigned=*/true);
2204 return str;
2205 };
2206
2207 ArrayRef<Attribute> params = state.getParams(getParam());
2208 ArrayRef<Attribute> references = state.getParams(getReference());
2209
2210 if (params.size() != references.size()) {
2211 return emitSilenceableError()
2212 << "parameters have different payload lengths (" << params.size()
2213 << " vs " << references.size() << ")";
2214 }
2215
2216 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2217 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2218 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2219 if (!intAttr || !refAttr) {
2220 return emitDefiniteFailure()
2221 << "non-integer parameter value not expected";
2222 }
2223 if (intAttr.getType() != refAttr.getType()) {
2224 return emitDefiniteFailure()
2225 << "mismatching integer attribute types in parameter #" << i;
2226 }
2227 APInt value = intAttr.getValue();
2228 APInt refValue = refAttr.getValue();
2229
2230 // TODO: this copy will not be necessary in C++20.
2231 int64_t position = i;
2232 auto reportError = [&](StringRef direction) {
2234 emitSilenceableError() << "expected parameter to be " << direction
2235 << " " << signedAPIntAsString(refValue)
2236 << ", got " << signedAPIntAsString(value);
2237 diag.attachNote(getParam().getLoc())
2238 << "value # " << position
2239 << " associated with the parameter defined here";
2240 return diag;
2241 };
2242
2243 switch (getPredicate()) {
2244 case MatchCmpIPredicate::eq:
2245 if (value.eq(refValue))
2246 break;
2247 return reportError("equal to");
2248 case MatchCmpIPredicate::ne:
2249 if (value.ne(refValue))
2250 break;
2251 return reportError("not equal to");
2252 case MatchCmpIPredicate::lt:
2253 if (value.slt(refValue))
2254 break;
2255 return reportError("less than");
2256 case MatchCmpIPredicate::le:
2257 if (value.sle(refValue))
2258 break;
2259 return reportError("less than or equal to");
2260 case MatchCmpIPredicate::gt:
2261 if (value.sgt(refValue))
2262 break;
2263 return reportError("greater than");
2264 case MatchCmpIPredicate::ge:
2265 if (value.sge(refValue))
2266 break;
2267 return reportError("greater than or equal to");
2268 }
2269 }
2271}
2272
2273void transform::MatchParamCmpIOp::getEffects(
2275 onlyReadsHandle(getParamMutable(), effects);
2276 onlyReadsHandle(getReferenceMutable(), effects);
2277}
2278
2279//===----------------------------------------------------------------------===//
2280// ParamConstantOp
2281//===----------------------------------------------------------------------===//
2282
2284transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2287 results.setParams(cast<OpResult>(getParam()), {getValue()});
2289}
2290
2291//===----------------------------------------------------------------------===//
2292// MergeHandlesOp
2293//===----------------------------------------------------------------------===//
2294
2296transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2299 ValueRange handles = getHandles();
2300 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2301 SmallVector<Operation *> operations;
2302 for (Value operand : handles)
2303 llvm::append_range(operations, state.getPayloadOps(operand));
2304 if (!getDeduplicate()) {
2305 results.set(llvm::cast<OpResult>(getResult()), operations);
2307 }
2308
2309 SetVector<Operation *> uniqued(llvm::from_range, operations);
2310 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2312 }
2313
2314 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2316 for (Value attribute : handles)
2317 llvm::append_range(attrs, state.getParams(attribute));
2318 if (!getDeduplicate()) {
2319 results.setParams(cast<OpResult>(getResult()), attrs);
2321 }
2322
2323 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2324 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2326 }
2327
2328 assert(
2329 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2330 "expected value handle type");
2331 SmallVector<Value> payloadValues;
2332 for (Value value : handles)
2333 llvm::append_range(payloadValues, state.getPayloadValues(value));
2334 if (!getDeduplicate()) {
2335 results.setValues(cast<OpResult>(getResult()), payloadValues);
2337 }
2338
2339 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2340 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2342}
2343
2344bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2345 // Handles may be the same if deduplicating is enabled.
2346 return getDeduplicate();
2347}
2348
2349void transform::MergeHandlesOp::getEffects(
2351 onlyReadsHandle(getHandlesMutable(), effects);
2352 producesHandle(getOperation()->getOpResults(), effects);
2353
2354 // There are no effects on the Payload IR as this is only a handle
2355 // manipulation.
2356}
2357
2358OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2359 if (getDeduplicate() || getHandles().size() != 1)
2360 return {};
2361
2362 // If deduplication is not required and there is only one operand, it can be
2363 // used directly instead of merging.
2364 return getHandles().front();
2365}
2366
2367//===----------------------------------------------------------------------===//
2368// NamedSequenceOp
2369//===----------------------------------------------------------------------===//
2370
2372transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2375 if (isExternal())
2376 return emitDefiniteFailure() << "unresolved external named sequence";
2377
2378 // Map the entry block argument to the list of operations.
2379 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2380 // without attaching the interface / trait since that is tailored to a
2381 // dangling top-level op that does not get "called".
2382 auto scope = state.make_region_scope(getBody());
2383 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2384 state, this->getOperation(), getBody())))
2386
2387 return applySequenceBlock(getBody().front(),
2388 FailurePropagationMode::Propagate, state, results);
2389}
2390
2391void transform::NamedSequenceOp::getEffects(
2393
2394ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2397 parser, result, /*allowVariadic=*/false,
2398 getFunctionTypeAttrName(result.name),
2399 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2401 std::string &) { return builder.getFunctionType(inputs, results); },
2402 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2403}
2404
2405void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2407 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2408 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2409 getResAttrsAttrName());
2410}
2411
2412/// Verifies that a symbol function-like transform dialect operation has the
2413/// signature and the terminator that have conforming types, i.e., types
2414/// implementing the same transform dialect type interface. If `allowExternal`
2415/// is set, allow external symbols (declarations) and don't check the terminator
2416/// as it may not exist.
2418verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2419 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2422 << "cannot be defined inside another transform op";
2423 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2424 return diag;
2425 }
2426
2427 if (op.isExternal() || op.getFunctionBody().empty()) {
2428 if (allowExternal)
2430
2431 return emitSilenceableFailure(op) << "cannot be external";
2432 }
2433
2434 if (op.getFunctionBody().front().empty())
2435 return emitSilenceableFailure(op) << "expected a non-empty body block";
2436
2437 Operation *terminator = &op.getFunctionBody().front().back();
2438 if (!isa<transform::YieldOp>(terminator)) {
2440 << "expected '"
2441 << transform::YieldOp::getOperationName()
2442 << "' as terminator";
2443 diag.attachNote(terminator->getLoc()) << "terminator";
2444 return diag;
2445 }
2446
2447 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2448 return emitSilenceableFailure(terminator)
2449 << "expected terminator to have as many operands as the parent op "
2450 "has results";
2451 }
2452 for (auto [i, operandType, resultType] : llvm::zip_equal(
2453 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2454 terminator->getOperands().getType(), op.getResultTypes())) {
2455 if (operandType == resultType)
2456 continue;
2457 return emitSilenceableFailure(terminator)
2458 << "the type of the terminator operand #" << i
2459 << " must match the type of the corresponding parent op result ("
2460 << operandType << " vs " << resultType << ")";
2461 }
2462
2464}
2465
2466/// Verification of a NamedSequenceOp. This does not report the error
2467/// immediately, so it can be used to check for op's well-formedness before the
2468/// verifier runs, e.g., during trait verification.
2470verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2471 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2472 if (!parent->getAttr(
2473 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2476 << "expects the parent symbol table to have the '"
2477 << transform::TransformDialect::kWithNamedSequenceAttrName
2478 << "' attribute";
2479 diag.attachNote(parent->getLoc()) << "symbol table operation";
2480 return diag;
2481 }
2482 }
2483
2484 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2487 << "cannot be defined inside another transform op";
2488 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2489 return diag;
2490 }
2491
2492 if (op.isExternal() || op.getBody().empty())
2493 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2494 emitWarnings);
2495
2496 if (op.getBody().front().empty())
2497 return emitSilenceableFailure(op) << "expected a non-empty body block";
2498
2499 Operation *terminator = &op.getBody().front().back();
2500 if (!isa<transform::YieldOp>(terminator)) {
2502 << "expected '"
2503 << transform::YieldOp::getOperationName()
2504 << "' as terminator";
2505 diag.attachNote(terminator->getLoc()) << "terminator";
2506 return diag;
2507 }
2508
2509 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2510 return emitSilenceableFailure(terminator)
2511 << "expected terminator to have as many operands as the parent op "
2512 "has results";
2513 }
2514 for (auto [i, operandType, resultType] :
2515 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2516 terminator->getOperands().getType(),
2517 op.getFunctionType().getResults())) {
2518 if (operandType == resultType)
2519 continue;
2520 return emitSilenceableFailure(terminator)
2521 << "the type of the terminator operand #" << i
2522 << " must match the type of the corresponding parent op result ("
2523 << operandType << " vs " << resultType << ")";
2524 }
2525
2526 auto funcOp = cast<FunctionOpInterface>(*op);
2528 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2529 if (!diag.succeeded())
2530 return diag;
2531
2532 return verifyYieldingSingleBlockOp(funcOp,
2533 /*allowExternal=*/true);
2534}
2535
2536LogicalResult transform::NamedSequenceOp::verify() {
2537 // Actual verification happens in a separate function for reusability.
2538 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2539}
2540
2541template <typename FnTy>
2542static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2543 Type bbArgType, TypeRange extraBindingTypes,
2544 FnTy bodyBuilder) {
2545 SmallVector<Type> types;
2546 types.reserve(1 + extraBindingTypes.size());
2547 types.push_back(bbArgType);
2548 llvm::append_range(types, extraBindingTypes);
2549
2550 OpBuilder::InsertionGuard guard(builder);
2551 Region *region = state.regions.back().get();
2552 Block *bodyBlock =
2553 builder.createBlock(region, region->begin(), types,
2554 SmallVector<Location>(types.size(), state.location));
2555
2556 // Populate body.
2557 builder.setInsertionPointToStart(bodyBlock);
2558 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2559 bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2560 } else {
2561 bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2562 bodyBlock->getArguments().drop_front());
2563 }
2564}
2565
2566void transform::NamedSequenceOp::build(OpBuilder &builder,
2567 OperationState &state, StringRef symName,
2568 Type rootType, TypeRange resultTypes,
2569 SequenceBodyBuilderFn bodyBuilder,
2571 ArrayRef<DictionaryAttr> argAttrs) {
2573 builder.getStringAttr(symName));
2574 state.addAttribute(getFunctionTypeAttrName(state.name),
2575 TypeAttr::get(FunctionType::get(builder.getContext(),
2576 rootType, resultTypes)));
2577 state.attributes.append(attrs.begin(), attrs.end());
2578 state.addRegion();
2579
2580 buildSequenceBody(builder, state, rootType,
2581 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2582}
2583
2584//===----------------------------------------------------------------------===//
2585// NumAssociationsOp
2586//===----------------------------------------------------------------------===//
2587
2589transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2592 size_t numAssociations =
2594 .Case([&](TransformHandleTypeInterface opHandle) {
2595 return llvm::range_size(state.getPayloadOps(getHandle()));
2596 })
2597 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2598 return llvm::range_size(state.getPayloadValues(getHandle()));
2599 })
2600 .Case([&](TransformParamTypeInterface param) {
2601 return llvm::range_size(state.getParams(getHandle()));
2602 })
2603 .DefaultUnreachable("unknown kind of transform dialect type");
2604 results.setParams(cast<OpResult>(getNum()),
2605 rewriter.getI64IntegerAttr(numAssociations));
2607}
2608
2609LogicalResult transform::NumAssociationsOp::verify() {
2610 // Verify that the result type accepts an i64 attribute as payload.
2611 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2612 return resultType
2613 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2614 .checkAndReport();
2615}
2616
2617//===----------------------------------------------------------------------===//
2618// SelectOp
2619//===----------------------------------------------------------------------===//
2620
2622transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2626 auto payloadOps = state.getPayloadOps(getTarget());
2627 for (Operation *op : payloadOps) {
2628 if (op->getName().getStringRef() == getOpName())
2629 result.push_back(op);
2630 }
2631 results.set(cast<OpResult>(getResult()), result);
2633}
2634
2635//===----------------------------------------------------------------------===//
2636// SplitHandleOp
2637//===----------------------------------------------------------------------===//
2638
2639void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2640 Value target, int64_t numResultHandles) {
2641 result.addOperands(target);
2642 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2643}
2644
2646transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2649 int64_t numPayloads =
2651 .Case<TransformHandleTypeInterface>([&](auto x) {
2652 return llvm::range_size(state.getPayloadOps(getHandle()));
2653 })
2654 .Case<TransformValueHandleTypeInterface>([&](auto x) {
2655 return llvm::range_size(state.getPayloadValues(getHandle()));
2656 })
2657 .Case<TransformParamTypeInterface>([&](auto x) {
2658 return llvm::range_size(state.getParams(getHandle()));
2659 })
2660 .DefaultUnreachable("unknown transform dialect type interface");
2661
2662 auto produceNumOpsError = [&]() {
2663 return emitSilenceableError()
2664 << getHandle() << " expected to contain " << this->getNumResults()
2665 << " payloads but it contains " << numPayloads << " payloads";
2666 };
2667
2668 // Fail if there are more payload ops than results and no overflow result was
2669 // specified.
2670 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2671 return produceNumOpsError();
2672
2673 // Fail if there are more results than payload ops. Unless:
2674 // - "fail_on_payload_too_small" is set to "false", or
2675 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2676 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2677 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2678 return produceNumOpsError();
2679
2680 // Distribute payloads.
2681 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2682 if (getOverflowResult())
2683 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2684
2685 auto container = [&]() {
2686 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2687 return llvm::map_to_vector(
2688 state.getPayloadOps(getHandle()),
2689 [](Operation *op) -> MappedValue { return op; });
2690 }
2691 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2692 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2693 [](Value v) -> MappedValue { return v; });
2694 }
2695 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2696 "unsupported kind of transform dialect type");
2697 return llvm::map_to_vector(state.getParams(getHandle()),
2698 [](Attribute a) -> MappedValue { return a; });
2699 }();
2700
2701 for (auto &&en : llvm::enumerate(container)) {
2702 int64_t resultNum = en.index();
2703 if (resultNum >= getNumResults())
2704 resultNum = *getOverflowResult();
2705 resultHandles[resultNum].push_back(en.value());
2706 }
2707
2708 // Set transform op results.
2709 for (auto &&it : llvm::enumerate(resultHandles))
2710 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2711 it.value());
2712
2714}
2715
/// Declares the effects of transform.split_handle on transform IR: the input
/// handle is only read (not consumed) and every result handle is produced.
2716void transform::SplitHandleOp::getEffects(
2718 onlyReadsHandle(getHandleMutable(), effects);
2719 producesHandle(getOperation()->getOpResults(), effects);
2720 // There are no effects on the Payload IR as this is only a handle
2721 // manipulation.
2722}
2723
2724LogicalResult transform::SplitHandleOp::verify() {
2725 if (getOverflowResult().has_value() &&
2726 !(*getOverflowResult() < getNumResults()))
2727 return emitOpError("overflow_result is not a valid result index");
2728
2729 for (Type resultType : getResultTypes()) {
2730 if (implementSameTransformInterface(getHandle().getType(), resultType))
2731 continue;
2732
2733 return emitOpError("expects result types to implement the same transform "
2734 "interface as the operand type");
2735 }
2736
2737 return success();
2738}
2739
2740//===----------------------------------------------------------------------===//
2741// ReplicateOp
2742//===----------------------------------------------------------------------===//
2743
// Maps each result #i to the payload of handle #i repeated N times, where N is
// the number of payload ops currently associated with the `pattern` handle.
// Op handles and parameters are supported; any other handle kind reaches the
// else-branch and trips its "expected param type" assertion.
2745transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2748 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
// Handles are processed positionally: handle #i feeds getReplicated()[i].
2749 for (const auto &en : llvm::enumerate(getHandles())) {
2750 Value handle = en.value();
2751 if (isa<TransformHandleTypeInterface>(handle.getType())) {
2752 SmallVector<Operation *> current =
2753 llvm::to_vector(state.getPayloadOps(handle));
2755 payload.reserve(numRepetitions * current.size());
2756 for (unsigned i = 0; i < numRepetitions; ++i)
2757 llvm::append_range(payload, current);
2758 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2759 } else {
2760 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2761 "expected param type");
2762 ArrayRef<Attribute> current = state.getParams(handle);
2764 params.reserve(numRepetitions * current.size());
2765 for (unsigned i = 0; i < numRepetitions; ++i)
2766 llvm::append_range(params, current);
2767 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2768 params);
2769 }
2770 }
2772}
2773
/// Declares the effects of transform.replicate: the pattern handle and the
/// handles being replicated are only read, and all results are produced.
2774void transform::ReplicateOp::getEffects(
2776 onlyReadsHandle(getPatternMutable(), effects);
2777 onlyReadsHandle(getHandlesMutable(), effects);
2778 producesHandle(getOperation()->getOpResults(), effects);
2779}
2780
2781//===----------------------------------------------------------------------===//
2782// SequenceOp
2783//===----------------------------------------------------------------------===//
2784
// Applies transform.sequence. It first opens a region scope for the body on
// the transform state and maps the body's entry block arguments. It then runs
// the body ops through applySequenceBlock, honoring the op's failure
// propagation mode.
2786transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2789 // Map the entry block argument to the list of operations.
2790 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2791 if (failed(mapBlockArguments(state)))
2793
// Execute the body; its results are forwarded into `results`.
2794 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2795 results);
2796}
2797
/// Parses the optional operand clause of transform.sequence:
///   [ %root [, %extra, ...] : [(] root-type [, extra-type, ...] [)] ]
/// If no root operand is present, nothing further is parsed, `root` is set to
/// std::nullopt and parsing succeeds. Otherwise the colon and the root type
/// are mandatory, and exactly one type must be supplied per extra binding.
/// NOTE(review): the opening and closing parentheses are each accepted
/// independently, so unbalanced input such as `: (!t1, !t2` is not rejected
/// here. printSequenceOpOperands always emits them as a matched pair, and only
/// when there are extra bindings.
2798static ParseResult parseSequenceOpOperands(
2799 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2800 Type &rootType,
2802 SmallVectorImpl<Type> &extraBindingTypes) {
2804 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
// No root operand: the whole clause is absent (top-level sequence).
2805 if (!hasRoot.has_value()) {
2806 root = std::nullopt;
2807 return success();
2808 }
2809 if (failed(hasRoot.value()))
2810 return failure();
2811 root = rootOperand;
2812
// Optional comma-separated extra operands after the root.
2813 if (succeeded(parser.parseOptionalComma())) {
2814 if (failed(parser.parseOperandList(extraBindings)))
2815 return failure();
2816 }
2817 if (failed(parser.parseColon()))
2818 return failure();
2819
2820 // The paren is truly optional.
2821 (void)parser.parseOptionalLParen();
2822
2823 if (failed(parser.parseType(rootType))) {
2824 return failure();
2825 }
2826
// Extra operands require a matching comma-separated type list.
2827 if (!extraBindings.empty()) {
2828 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2829 return failure();
2830 }
2831
2832 if (extraBindingTypes.size() != extraBindings.size()) {
2833 return parser.emitError(parser.getNameLoc(),
2834 "expected types to be provided for all operands");
2835 }
2836
2837 // The paren is truly optional.
2838 (void)parser.parseOptionalRParen();
2839 return success();
2840}
2841
2843 Value root, Type rootType,
2844 ValueRange extraBindings,
2845 TypeRange extraBindingTypes) {
// Inverse of parseSequenceOpOperands. Prints nothing when there is no root.
// Otherwise prints `%root : root-type`, or, when extra bindings exist,
// `%root, %e0, %e1 : (root-type, t0, t1)` with the type list parenthesized.
2846 if (!root)
2847 return;
2848
2849 printer << root;
2850 bool hasExtras = !extraBindings.empty();
2851 if (hasExtras) {
2852 printer << ", ";
2853 printer.printOperands(extraBindings);
2854 }
2855
2856 printer << " : ";
2857 if (hasExtras)
2858 printer << "(";
2859
2860 printer << rootType;
2861 if (hasExtras)
2862 printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2863}
2864
2865/// Returns `true` if the given op operand may be consuming the handle value in
2866/// the Transform IR. That is, if it may have a Free effect on it.
/// Owners that do not implement TransformOpInterface are conservatively
/// treated as consumers.
2868 // Conservatively assume the effect is present in the absence of the interface.
2869 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2870 if (!iface)
2871 return true;
2872
// Otherwise ask isHandleConsumed whether the owning transform op consumes it.
2873 return isHandleConsumed(use.get(), iface);
2874}
2875
/// Verifies that `value` has at most one use that may consume it. The per-use
/// filter at the top of the loop is presumably isValueUsePotentialConsumer
/// (TODO confirm).
/// On the second potential consumer, this function:
///  - starts a diagnostic with `reportError`;
///  - appends " has more than one potential consumer";
///  - attaches notes at the first and the second consuming use;
///  - returns that diagnostic, which converts to a failed LogicalResult.
/// Returns success if at most one potential consumer exists.
2876LogicalResult
2878 function_ref<InFlightDiagnostic()> reportError) {
2879 OpOperand *potentialConsumer = nullptr;
2880 for (OpOperand &use : value.getUses()) {
2882 continue;
2883
// Remember the first potential consumer; a second one is an error.
2884 if (!potentialConsumer) {
2885 potentialConsumer = &use;
2886 continue;
2887 }
2888
2889 InFlightDiagnostic diag = reportError()
2890 << " has more than one potential consumer";
2891 diag.attachNote(potentialConsumer->getOwner()->getLoc())
2892 << "used here as operand #" << potentialConsumer->getOperandNumber();
2893 diag.attachNote(use.getOwner()->getLoc())
2894 << "used here as operand #" << use.getOperandNumber();
2895 return diag;
2896 }
2897
2898 return success();
2899}
2900
/// Verifies transform.sequence:
///  - extra operands are only allowed when a root operand is present;
///  - no body block argument and no result of a body op has more than one
///    potentially consuming use (see checkDoubleConsume);
///  - every body op except the last one implements TransformOpInterface;
///  - the body has a terminator whose operand types equal the op's result
///    types.
2901LogicalResult transform::SequenceOp::verify() {
// NOTE(review): the assertion message says "more than 1", but the condition
// is `>= 1`, i.e. at least one entry block argument is required.
2902 assert(getBodyBlock()->getNumArguments() >= 1 &&
2903 "the number of arguments must have been verified to be more than 1 by "
2904 "PossibleTopLevelTransformOpTrait");
2905
2906 if (!getRoot() && !getExtraBindings().empty()) {
2907 return emitOpError()
2908 << "does not expect extra operands when used as top-level";
2909 }
2910
2911 // Check if a block argument has more than one consuming use.
2912 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2913 if (failed(checkDoubleConsume(arg, [this, arg]() {
2914 return (emitOpError() << "block argument #" << arg.getArgNumber());
2915 }))) {
2916 return failure();
2917 }
2918 }
2919
2920 // Check properties of the nested operations that they cannot check themselves.
2921 for (Operation &child : *getBodyBlock()) {
// The last op (presumably the terminator, checked below) is exempt from the
// interface requirement.
2922 if (!isa<TransformOpInterface>(child) &&
2923 &child != &getBodyBlock()->back()) {
2925 emitOpError()
2926 << "expected children ops to implement TransformOpInterface";
2927 diag.attachNote(child.getLoc()) << "op without interface";
2928 return diag;
2929 }
2930
2931 for (OpResult result : child.getResults()) {
2932 auto report = [&]() {
2933 return (child.emitError() << "result #" << result.getResultNumber());
2934 };
2935 if (failed(checkDoubleConsume(result, report)))
2936 return failure();
2937 }
2938 }
2939
2940 if (!getBodyBlock()->mightHaveTerminator())
2941 return emitOpError() << "expects to have a terminator in the body";
2942
2943 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2944 getOperation()->getResultTypes()) {
2946 << "expects the types of the terminator operands "
2947 "to match the types of the result";
2948 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2949 return diag;
2950 }
2951 return success();
2952}
2953
/// Declares the memory effects of transform.sequence.
/// NOTE(review): the statements of this method are not visible in this
/// listing; consult the upstream source before relying on specifics.
2954void transform::SequenceOp::getEffects(
2957}
2958
2960transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
// The body region is the only entry successor. All of the op's operands (the
// root and any extra bindings) are forwarded to it. When there are none, an
// empty range positioned at the end of the operand list is returned instead.
2961 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2962 if (getOperation()->getNumOperands() > 0)
2963 return getOperation()->getOperands();
2964 return OperandRange(getOperation()->operand_end(),
2965 getOperation()->operand_end());
2966}
2967
/// Region control flow of transform.sequence. From the parent op, control
/// enters the body region; its block arguments are the forwarded values when
/// the op has operands. From the body, control returns to the parent op and
/// its results.
2968void transform::SequenceOp::getSuccessorRegions(
2970 if (point.isParent()) {
2971 Region *bodyRegion = &getBody();
2972 regions.emplace_back(bodyRegion, getNumOperands() != 0
2973 ? bodyRegion->getArguments()
2975 return;
2976 }
2977
2979 &getBody() &&
2980 "unexpected region index");
// Leaving the body always transfers control back to the parent op.
2981 regions.emplace_back(getOperation(), getOperation()->getResults());
2982}
2983
/// The body region is invoked exactly once, independently of `operands`.
2984void transform::SequenceOp::getRegionInvocationBounds(
2986 (void)operands;
2987 bounds.emplace_back(1, 1);
2988}
2989
2990void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2991 TypeRange resultTypes,
2992 FailurePropagationMode failurePropagationMode,
2993 Value root,
2994 SequenceBodyBuilderFn bodyBuilder) {
2995 build(builder, state, resultTypes, failurePropagationMode, root,
2996 /*extra_bindings=*/ValueRange());
2997 Type bbArgType = root.getType();
2998 buildSequenceBody(builder, state, bbArgType,
2999 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3000}
3001
3002void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3003 TypeRange resultTypes,
3004 FailurePropagationMode failurePropagationMode,
3005 Value root, ValueRange extraBindings,
3006 SequenceBodyBuilderArgsFn bodyBuilder) {
3007 build(builder, state, resultTypes, failurePropagationMode, root,
3008 extraBindings);
3009 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3010 bodyBuilder);
3011}
3012
3013void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3014 TypeRange resultTypes,
3015 FailurePropagationMode failurePropagationMode,
3016 Type bbArgType,
3017 SequenceBodyBuilderFn bodyBuilder) {
3018 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3019 /*extra_bindings=*/ValueRange());
3020 buildSequenceBody(builder, state, bbArgType,
3021 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3022}
3023
3024void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3025 TypeRange resultTypes,
3026 FailurePropagationMode failurePropagationMode,
3027 Type bbArgType, TypeRange extraBindingTypes,
3028 SequenceBodyBuilderArgsFn bodyBuilder) {
3029 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3030 /*extra_bindings=*/ValueRange());
3031 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3032}
3033
3034//===----------------------------------------------------------------------===//
3035// PrintOp
3036//===----------------------------------------------------------------------===//
3037
3038void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3039 StringRef name) {
3040 if (!name.empty())
3041 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3042}
3043
3044void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3045 Value target, StringRef name) {
3046 result.addOperands({target});
3047 build(builder, result, name);
3048}
3049
// Prints IR to llvm::outs() (stdout). The output starts with the header
// "[[[ IR printer: ", followed by "<name> " when a name is set.
//  - Without a target handle, "top-level ]]]" follows and the whole top-level
//    payload op is printed.
//  - Otherwise "]]]" follows and each payload op associated with the target
//    handle is printed, each followed by a newline.
// The optional assumeVerified / useLocalScope / skipRegions settings (absent
// means false) map onto the matching OpPrintingFlags.
3051transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3054 llvm::outs() << "[[[ IR printer: ";
3055 if (getName().has_value())
3056 llvm::outs() << *getName() << " ";
3057
3058 OpPrintingFlags printFlags;
3059 if (getAssumeVerified().value_or(false))
3060 printFlags.assumeVerified();
3061 if (getUseLocalScope().value_or(false))
3062 printFlags.useLocalScope();
3063 if (getSkipRegions().value_or(false))
3064 printFlags.skipRegions();
3065
// No target handle: dump the entire top-level payload op.
3066 if (!getTarget()) {
3067 llvm::outs() << "top-level ]]]\n";
3068 state.getTopLevel()->print(llvm::outs(), printFlags);
3069 llvm::outs() << "\n";
3070 llvm::outs().flush();
3072 }
3073
// Otherwise dump every payload op the target handle maps to.
3074 llvm::outs() << "]]]\n";
3075 for (Operation *target : state.getPayloadOps(getTarget())) {
3076 target->print(llvm::outs(), printFlags);
3077 llvm::outs() << "\n";
3078 }
3079
3080 llvm::outs().flush();
3082}
3083
/// Declares the effects of transform.print: the optional target handle and
/// the payload IR are only read, and a write on the default resource models
/// the printed output.
3084void transform::PrintOp::getEffects(
3086 // We don't really care about mutability here, but `getTarget` now
3087 // unconditionally casts to a specific type before verification could run
3088 // here.
3089 if (!getTargetMutable().empty())
3090 onlyReadsHandle(getTargetMutable()[0], effects);
3091 onlyReadsPayload(effects);
3092
3093 // There is no resource for the stdout file descriptor (llvm::outs()), so
3094 // just declare print writes into the default resource.
3095 effects.emplace_back(MemoryEffects::Write::get());
3096}
3097
3098//===----------------------------------------------------------------------===//
3099// VerifyOp
3100//===----------------------------------------------------------------------===//
3101
// Reports a failure for a payload op that does not verify. The diagnostic
// reads "failed to verify payload op" and carries a note at the payload op's
// location.
// NOTE(review): the verification call itself is not visible in this listing.
// It is presumably mlir::verify (mlir/IR/Verifier.h is included); confirm.
3103transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3109 << "failed to verify payload op";
3110 diag.attachNote(target->getLoc()) << "payload op";
3111 return diag;
3112 }
3114}
3115
/// transform.verify declares only a read of its target handle; no payload
/// effects are declared here.
3116void transform::VerifyOp::getEffects(
3118 transform::onlyReadsHandle(getTargetMutable(), effects);
3119}
3120
3121//===----------------------------------------------------------------------===//
3122// YieldOp
3123//===----------------------------------------------------------------------===//
3124
/// transform.yield only reads the handles it forwards.
3125void transform::YieldOp::getEffects(
3127 onlyReadsHandle(getOperandsMutable(), effects);
3128}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
Definition DLTI.cpp:38
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue > > blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
BlockArgument getArgument(unsigned i)
Definition Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getOpResults()
Definition Operation.h:420
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
iterator begin()
Definition Region.h:55
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.