MLIR 23.0.0git
TransformOps.cpp
Go to the documentation of this file.
1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
24#include "mlir/IR/Verifier.h"
30#include "mlir/Transforms/CSE.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/DebugLog.h"
41#include "llvm/Support/ErrorHandling.h"
42#include "llvm/Support/InterleavedRange.h"
43#include <optional>
44
45#define DEBUG_TYPE "transform-dialect"
46#define DEBUG_TYPE_MATCHER "transform-matcher"
47
48using namespace mlir;
49
50static ParseResult parseApplyRegisteredPassOptions(
51 OpAsmParser &parser, DictionaryAttr &options,
54 Operation *op,
55 DictionaryAttr options,
56 ValueRange dynamicOptions);
57static ParseResult parseSequenceOpOperands(
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59 Type &rootType,
61 SmallVectorImpl<Type> &extraBindingTypes);
62static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
63 Value root, Type rootType,
64 ValueRange extraBindings,
65 TypeRange extraBindingTypes);
66static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
68static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
70 ArrayAttr &actions);
71
72/// Helper function to check if the given transform op is contained in (or
73/// equal to) the given payload target op. In that case, an error is returned.
74/// Transforming transform IR that is currently executing is generally unsafe.
76ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
77 Operation *payload) {
78 Operation *transformAncestor = transform.getOperation();
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
82 transform.emitDefiniteFailure()
83 << "cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->getLoc()) << "target payload op";
85 return diag;
86 }
87 transformAncestor = transformAncestor->getParentOp();
88 }
90}
91
92#define GET_OP_CLASSES
93#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
94
95//===----------------------------------------------------------------------===//
96// AlternativesOp
97//===----------------------------------------------------------------------===//
98
99OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
100 RegionSuccessor successor) {
101 if (!successor.isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
103 return OperandRange(getOperation()->operand_end(),
104 getOperation()->operand_end());
105}
106
107void transform::AlternativesOp::getSuccessorRegions(
109 for (Region &alternative : llvm::drop_begin(
110 getAlternatives(), point.isParent()
111 ? 0
113 ->getParentRegion()
114 ->getRegionNumber() +
115 1)) {
116 regions.emplace_back(&alternative);
117 }
118 if (!point.isParent())
119 regions.push_back(RegionSuccessor::parent());
120}
121
123transform::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) {
124 if (successor.isParent())
125 return getOperation()->getResults();
126 return successor.getSuccessor()->getArguments();
127}
128
129void transform::AlternativesOp::getRegionInvocationBounds(
131 (void)operands;
132 // The region corresponding to the first alternative is always executed, the
133 // remaining may or may not be executed.
134 bounds.reserve(getNumRegions());
135 bounds.emplace_back(1, 1);
136 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
137}
138
141 for (const auto &res : block->getParentOp()->getOpResults())
142 results.set(res, {});
143}
144
146transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
149 SmallVector<Operation *> originals;
150 if (Value scopeHandle = getScope())
151 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
152 else
153 originals.push_back(state.getTopLevel());
154
155 for (Operation *original : originals) {
156 if (original->isAncestor(getOperation())) {
158 << "scope must not contain the transforms being applied";
159 diag.attachNote(original->getLoc()) << "scope";
160 return diag;
161 }
162 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
164 << "only isolated-from-above ops can be alternative scopes";
165 diag.attachNote(original->getLoc()) << "scope";
166 return diag;
167 }
168 }
169
170 for (Region &reg : getAlternatives()) {
171 // Clone the scope operations and make the transforms in this alternative
172 // region apply to them by virtue of mapping the block argument (the only
173 // visible handle) to the cloned scope operations. This effectively prevents
174 // the transformation from accessing any IR outside the scope.
175 auto scope = state.make_region_scope(reg);
176 auto clones = llvm::to_vector(
177 llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
178 llvm::scope_exit deleteClones([&] {
179 for (Operation *clone : clones)
180 clone->erase();
181 });
182 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
184
185 bool failed = false;
186 for (Operation &transform : reg.front().without_terminator()) {
188 state.applyTransform(cast<TransformOpInterface>(transform));
189 if (result.isSilenceableFailure()) {
190 LDBG() << "alternative failed: " << result.getMessage();
191 failed = true;
192 break;
193 }
194
195 if (::mlir::failed(result.silence()))
197 }
198
199 // If all operations in the given alternative succeeded, no need to consider
200 // the rest. Replace the original scoping operation with the clone on which
201 // the transformations were performed.
202 if (!failed) {
203 // We will be using the clones, so cancel their scheduled deletion.
204 deleteClones.release();
205 TrackingListener listener(state, *this);
206 IRRewriter rewriter(getContext(), &listener);
207 for (const auto &kvp : llvm::zip(originals, clones)) {
208 Operation *original = std::get<0>(kvp);
209 Operation *clone = std::get<1>(kvp);
210 original->getBlock()->getOperations().insert(original->getIterator(),
211 clone);
212 rewriter.replaceOp(original, clone->getResults());
213 }
214 detail::forwardTerminatorOperands(&reg.front(), state, results);
216 }
217 }
218 return emitSilenceableError() << "all alternatives failed";
219}
220
221void transform::AlternativesOp::getEffects(
223 consumesHandle(getOperation()->getOpOperands(), effects);
224 producesHandle(getOperation()->getOpResults(), effects);
225 for (Region *region : getRegions()) {
226 if (!region->empty())
227 producesHandle(region->front().getArguments(), effects);
228 }
229 modifiesPayload(effects);
230}
231
232LogicalResult transform::AlternativesOp::verify() {
233 for (Region &alternative : getAlternatives()) {
234 Block &block = alternative.front();
235 Operation *terminator = block.getTerminator();
236 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
238 << "expects terminator operands to have the "
239 "same type as results of the operation";
240 diag.attachNote(terminator->getLoc()) << "terminator";
241 return diag;
242 }
243 }
244
245 return success();
246}
247
248//===----------------------------------------------------------------------===//
249// AnnotateOp
250//===----------------------------------------------------------------------===//
251
253transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
257 llvm::to_vector(state.getPayloadOps(getTarget()));
258
259 Attribute attr = UnitAttr::get(getContext());
260 if (auto paramH = getParam()) {
261 ArrayRef<Attribute> params = state.getParams(paramH);
262 if (params.size() != 1) {
263 if (targets.size() != params.size()) {
264 return emitSilenceableError()
265 << "parameter and target have different payload lengths ("
266 << params.size() << " vs " << targets.size() << ")";
267 }
268 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
269 target->setAttr(getName(), attr);
271 }
272 attr = params[0];
273 }
274 for (auto *target : targets)
275 target->setAttr(getName(), attr);
277}
278
279void transform::AnnotateOp::getEffects(
281 onlyReadsHandle(getTargetMutable(), effects);
282 onlyReadsHandle(getParamMutable(), effects);
283 modifiesPayload(effects);
284}
285
286//===----------------------------------------------------------------------===//
287// ApplyCommonSubexpressionEliminationOp
288//===----------------------------------------------------------------------===//
289
291transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
293 ApplyToEachResultList &results, transform::TransformState &state) {
294 // Make sure that this transform is not applied to itself. Modifying the
295 // transform IR while it is being interpreted is generally dangerous.
296 DiagnosedSilenceableFailure payloadCheck =
298 if (!payloadCheck.succeeded())
299 return payloadCheck;
300
301 DominanceInfo domInfo;
304}
305
306void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
308 transform::onlyReadsHandle(getTargetMutable(), effects);
310}
311
312//===----------------------------------------------------------------------===//
313// ApplyDeadCodeEliminationOp
314//===----------------------------------------------------------------------===//
315
316DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
318 ApplyToEachResultList &results, transform::TransformState &state) {
319 // Make sure that this transform is not applied to itself. Modifying the
320 // transform IR while it is being interpreted is generally dangerous.
321 DiagnosedSilenceableFailure payloadCheck =
323 if (!payloadCheck.succeeded())
324 return payloadCheck;
325
326 // Maintain a worklist of potentially dead ops.
327 SetVector<Operation *> worklist;
328
329 // Helper function that adds all defining ops of used values (operands and
330 // operands of nested ops).
331 auto addDefiningOpsToWorklist = [&](Operation *op) {
332 op->walk([&](Operation *op) {
333 for (Value v : op->getOperands())
334 if (Operation *defOp = v.getDefiningOp())
335 if (target->isProperAncestor(defOp))
336 worklist.insert(defOp);
337 });
338 };
339
340 // Helper function that erases an op.
341 auto eraseOp = [&](Operation *op) {
342 // Remove op and nested ops from the worklist.
343 op->walk([&](Operation *op) {
344 const auto *it = llvm::find(worklist, op);
345 if (it != worklist.end())
346 worklist.erase(it);
347 });
348 rewriter.eraseOp(op);
349 };
350
351 // Initial walk over the IR.
352 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
353 if (op != target && isOpTriviallyDead(op)) {
354 addDefiningOpsToWorklist(op);
355 eraseOp(op);
356 }
357 });
358
359 // Erase all ops that have become dead.
360 while (!worklist.empty()) {
361 Operation *op = worklist.pop_back_val();
362 if (!isOpTriviallyDead(op))
363 continue;
364 addDefiningOpsToWorklist(op);
365 eraseOp(op);
366 }
367
369}
370
371void transform::ApplyDeadCodeEliminationOp::getEffects(
373 transform::onlyReadsHandle(getTargetMutable(), effects);
375}
376
377//===----------------------------------------------------------------------===//
378// ApplyPatternsOp
379//===----------------------------------------------------------------------===//
380
381DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
383 ApplyToEachResultList &results, transform::TransformState &state) {
384 // Make sure that this transform is not applied to itself. Modifying the
385 // transform IR while it is being interpreted is generally dangerous. Even
386 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
387 // performs many additional simplifications such as dead code elimination.
388 DiagnosedSilenceableFailure payloadCheck =
390 if (!payloadCheck.succeeded())
391 return payloadCheck;
392
393 // Gather all specified patterns.
394 MLIRContext *ctx = target->getContext();
396 if (!getRegion().empty()) {
397 for (Operation &op : getRegion().front()) {
398 cast<transform::PatternDescriptorOpInterface>(&op)
399 .populatePatternsWithState(patterns, state);
400 }
401 }
402
403 // Configure the GreedyPatternRewriteDriver.
405 config.setListener(
406 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
407 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
408
409 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
411 : getMaxIterations());
412 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
414 : getMaxNumRewrites());
415
416 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
417 // was requested, apply the greedy pattern rewrite only once. (The greedy
418 // pattern rewrite driver already iterates to a fixpoint internally.)
419 bool cseChanged = false;
420 // One or two iterations should be sufficient. Stop iterating after a certain
421 // threshold to make debugging easier.
422 static const int64_t kNumMaxIterations = 50;
423 int64_t iteration = 0;
424 do {
425 LogicalResult result = failure();
426 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
427 // Op is isolated from above. Apply patterns and also perform region
428 // simplification.
429 result = applyPatternsGreedily(target, frozenPatterns, config);
430 } else {
431 // Manually gather list of ops because the other
432 // GreedyPatternRewriteDriver overloads only accept ops that are isolated
433 // from above. This way, patterns can be applied to ops that are not
434 // isolated from above. Regions are not being simplified. Furthermore,
435 // only a single greedy rewrite iteration is performed.
437 target->walk([&](Operation *nestedOp) {
438 if (target != nestedOp)
439 ops.push_back(nestedOp);
440 });
441 result = applyOpPatternsGreedily(ops, frozenPatterns, config);
442 }
443
444 // A failure typically indicates that the pattern application did not
445 // converge.
446 if (failed(result)) {
448 << "greedy pattern application failed";
449 }
450
451 if (getApplyCse()) {
452 DominanceInfo domInfo;
454 &cseChanged);
455 }
456 } while (cseChanged && ++iteration < kNumMaxIterations);
457
458 if (iteration == kNumMaxIterations)
459 return emitDefiniteFailure() << "fixpoint iteration did not converge";
460
462}
463
464LogicalResult transform::ApplyPatternsOp::verify() {
465 if (!getRegion().empty()) {
466 for (Operation &op : getRegion().front()) {
467 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
469 << "expected children ops to implement "
470 "PatternDescriptorOpInterface";
471 diag.attachNote(op.getLoc()) << "op without interface";
472 return diag;
473 }
474 }
475 }
476 return success();
477}
478
479void transform::ApplyPatternsOp::getEffects(
481 transform::onlyReadsHandle(getTargetMutable(), effects);
483}
484
485void transform::ApplyPatternsOp::build(
487 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
488 result.addOperands(target);
489
490 OpBuilder::InsertionGuard g(builder);
491 Region *region = result.addRegion();
492 builder.createBlock(region);
493 if (bodyBuilder)
494 bodyBuilder(builder, result.location);
495}
496
497//===----------------------------------------------------------------------===//
498// ApplyCanonicalizationPatternsOp
499//===----------------------------------------------------------------------===//
500
501void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
503 MLIRContext *ctx = patterns.getContext();
504 for (Dialect *dialect : ctx->getLoadedDialects())
505 dialect->getCanonicalizationPatterns(patterns);
507 op.getCanonicalizationPatterns(patterns, ctx);
508}
509
510//===----------------------------------------------------------------------===//
511// ApplyConversionPatternsOp
512//===----------------------------------------------------------------------===//
513
514DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
517 MLIRContext *ctx = getContext();
518
519 // Instantiate the default type converter if a type converter builder is
520 // specified.
521 std::unique_ptr<TypeConverter> defaultTypeConverter;
522 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
523 getDefaultTypeConverter();
524 if (typeConverterBuilder)
525 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
526
527 // Configure conversion target.
528 ConversionTarget conversionTarget(*getContext());
529 if (getLegalOps())
530 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
531 conversionTarget.addLegalOp(
532 OperationName(cast<StringAttr>(attr).getValue(), ctx));
533 if (getIllegalOps())
534 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
535 conversionTarget.addIllegalOp(
536 OperationName(cast<StringAttr>(attr).getValue(), ctx));
537 if (getLegalDialects())
538 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
539 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
540 if (getIllegalDialects())
541 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
542 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
543
544 // Gather all specified patterns.
546 // Need to keep the converters alive until after pattern application because
547 // the patterns take a reference to an object that would otherwise get out of
548 // scope.
550 if (!getPatterns().empty()) {
551 for (Operation &op : getPatterns().front()) {
552 auto descriptor =
553 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
554
555 // Check if this pattern set specifies a type converter.
556 std::unique_ptr<TypeConverter> typeConverter =
557 descriptor.getTypeConverter();
558 TypeConverter *converter = nullptr;
559 if (typeConverter) {
560 keepAliveConverters.emplace_back(std::move(typeConverter));
561 converter = keepAliveConverters.back().get();
562 } else {
563 // No type converter specified: Use the default type converter.
564 if (!defaultTypeConverter) {
566 << "pattern descriptor does not specify type "
567 "converter and apply_conversion_patterns op has "
568 "no default type converter";
569 diag.attachNote(op.getLoc()) << "pattern descriptor op";
570 return diag;
571 }
572 converter = defaultTypeConverter.get();
573 }
574
575 // Add descriptor-specific updates to the conversion target, which may
576 // depend on the final type converter. In structural converters, the
577 // legality of types dictates the dynamic legality of an operation.
578 descriptor.populateConversionTargetRules(*converter, conversionTarget);
579
580 descriptor.populatePatterns(*converter, patterns);
581 }
582 }
583
584 // Attach a tracking listener if handles should be preserved. We configure the
585 // listener to allow op replacements with different names, as conversion
586 // patterns typically replace ops with replacement ops that have a different
587 // name.
588 TrackingListenerConfig trackingConfig;
589 trackingConfig.requireMatchingReplacementOpName = false;
590 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
591 ConversionConfig conversionConfig;
592 if (getPreserveHandles())
593 conversionConfig.listener = &trackingListener;
594
595 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
596 for (Operation *target : state.getPayloadOps(getTarget())) {
597 // Make sure that this transform is not applied to itself. Modifying the
598 // transform IR while it is being interpreted is generally dangerous.
599 DiagnosedSilenceableFailure payloadCheck =
601 if (!payloadCheck.succeeded())
602 return payloadCheck;
603
604 LogicalResult status = failure();
605 if (getPartialConversion()) {
606 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
607 conversionConfig);
608 } else {
609 status = applyFullConversion(target, conversionTarget, frozenPatterns,
610 conversionConfig);
611 }
612
613 // Check dialect conversion state.
615 if (failed(status)) {
616 diag = emitSilenceableError() << "dialect conversion failed";
617 diag.attachNote(target->getLoc()) << "target op";
618 }
619
620 // Check tracking listener error state.
621 DiagnosedSilenceableFailure trackingFailure =
622 trackingListener.checkAndResetError();
623 if (!trackingFailure.succeeded()) {
624 if (diag.succeeded()) {
625 // Tracking failure is the only failure.
626 return trackingFailure;
627 }
628 diag.attachNote() << "tracking listener also failed: "
629 << trackingFailure.getMessage();
630 (void)trackingFailure.silence();
631 }
632
633 if (!diag.succeeded())
634 return diag;
635 }
636
638}
639
640LogicalResult transform::ApplyConversionPatternsOp::verify() {
641 if (getNumRegions() != 1 && getNumRegions() != 2)
642 return emitOpError() << "expected 1 or 2 regions";
643 if (!getPatterns().empty()) {
644 for (Operation &op : getPatterns().front()) {
645 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
647 emitOpError() << "expected pattern children ops to implement "
648 "ConversionPatternDescriptorOpInterface";
649 diag.attachNote(op.getLoc()) << "op without interface";
650 return diag;
651 }
652 }
653 }
654 if (getNumRegions() == 2) {
655 Region &typeConverterRegion = getRegion(1);
656 if (!llvm::hasSingleElement(typeConverterRegion.front()))
657 return emitOpError()
658 << "expected exactly one op in default type converter region";
659 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
660 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
661 maybeTypeConverter);
662 if (!typeConverterOp) {
664 << "expected default converter child op to "
665 "implement TypeConverterBuilderOpInterface";
666 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
667 return diag;
668 }
669 // Check default type converter type.
670 if (!getPatterns().empty()) {
671 for (Operation &op : getPatterns().front()) {
672 auto descriptor =
673 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
674 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
675 return failure();
676 }
677 }
678 }
679 return success();
680}
681
682void transform::ApplyConversionPatternsOp::getEffects(
684 if (!getPreserveHandles()) {
685 transform::consumesHandle(getTargetMutable(), effects);
686 } else {
687 transform::onlyReadsHandle(getTargetMutable(), effects);
688 }
690}
691
692void transform::ApplyConversionPatternsOp::build(
694 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
695 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
696 result.addOperands(target);
697
698 {
699 OpBuilder::InsertionGuard g(builder);
700 Region *region1 = result.addRegion();
701 builder.createBlock(region1);
702 if (patternsBodyBuilder)
703 patternsBodyBuilder(builder, result.location);
704 }
705 {
706 OpBuilder::InsertionGuard g(builder);
707 Region *region2 = result.addRegion();
708 builder.createBlock(region2);
709 if (typeConverterBodyBuilder)
710 typeConverterBodyBuilder(builder, result.location);
711 }
712}
713
714//===----------------------------------------------------------------------===//
715// ApplyToLLVMConversionPatternsOp
716//===----------------------------------------------------------------------===//
717
718void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
719 TypeConverter &typeConverter, RewritePatternSet &patterns) {
720 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
721 assert(dialect && "expected that dialect is loaded");
722 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
723 // ConversionTarget is currently ignored because the enclosing
724 // apply_conversion_patterns op sets up its own ConversionTarget.
726 iface->populateConvertToLLVMConversionPatterns(
727 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
728}
729
730LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
731 transform::TypeConverterBuilderOpInterface builder) {
732 if (builder.getTypeConverterType() != "LLVMTypeConverter")
733 return emitOpError("expected LLVMTypeConverter");
734 return success();
735}
736
737LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
738 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
739 if (!dialect)
740 return emitOpError("unknown dialect or dialect not loaded: ")
741 << getDialectName();
742 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
743 if (!iface)
744 return emitOpError(
745 "dialect does not implement ConvertToLLVMPatternInterface or "
746 "extension was not loaded: ")
747 << getDialectName();
748 return success();
749}
750
751//===----------------------------------------------------------------------===//
752// ApplyLoopInvariantCodeMotionOp
753//===----------------------------------------------------------------------===//
754
756transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
757 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
760 // Currently, LICM does not remove operations, so we don't need tracking.
761 // If this ever changes, add a LICM entry point that takes a rewriter.
764}
765
766void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
768 transform::onlyReadsHandle(getTargetMutable(), effects);
770}
771
772//===----------------------------------------------------------------------===//
773// ApplyRegisteredPassOp
774//===----------------------------------------------------------------------===//
775
776void transform::ApplyRegisteredPassOp::getEffects(
778 consumesHandle(getTargetMutable(), effects);
779 onlyReadsHandle(getDynamicOptionsMutable(), effects);
780 producesHandle(getOperation()->getOpResults(), effects);
781 modifiesPayload(effects);
782}
783
785transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
788 // Obtain a single options-string to pass to the pass(-pipeline) from options
789 // passed in as a dictionary of keys mapping to values which are either
790 // attributes or param-operands pointing to attributes.
791 OperandRange dynamicOptions = getDynamicOptions();
792
793 std::string options;
794 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
795
796 // A helper to convert an option's attribute value into a corresponding
797 // string representation, with the ability to obtain the attr(s) from a param.
798 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
799 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
800 // The corresponding value attribute(s) is/are passed in via a param.
801 // Obtain the param-operand via its specified index.
802 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
803 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
804 "the number of ParamOperandAttrs in the options DictionaryAttr"
805 "should be the same as the number of options passed as params");
806 ArrayRef<Attribute> attrsAssociatedToParam =
807 state.getParams(dynamicOptions[dynamicOptionIdx]);
808 // Recursive so as to append all attrs associated to the param.
809 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
810 ",");
811 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
812 // Recursive so as to append all nested attrs of the array.
813 llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
814 } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
815 // Convert to unquoted string.
816 optionsStream << strAttr.getValue().str();
817 } else {
818 // For all other attributes, ask the attr to print itself (without type).
819 valueAttr.print(optionsStream, /*elideType=*/true);
820 }
821 };
822
823 // Convert the options DictionaryAttr into a single string.
824 llvm::interleave(
825 getOptions(), optionsStream,
826 [&](auto namedAttribute) {
827 optionsStream << namedAttribute.getName().str(); // Append the key.
828 optionsStream << "="; // And the key-value separator.
829 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
830 },
831 " ");
832 optionsStream.flush();
833
834 // Get pass or pass pipeline from registry.
835 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
836 if (!info)
837 info = PassInfo::lookup(getPassName());
838 if (!info)
839 return emitDefiniteFailure()
840 << "unknown pass or pass pipeline: " << getPassName();
841
842 // Create pass manager and add the pass or pass pipeline.
844 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
845 emitError(msg);
846 return failure();
847 }))) {
848 return emitDefiniteFailure()
849 << "failed to add pass or pass pipeline to pipeline: "
850 << getPassName();
851 }
852
853 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
854 for (Operation *target : targets) {
855 // Make sure that this transform is not applied to itself. Modifying the
856 // transform IR while it is being interpreted is generally dangerous. Even
857 // more so when applying passes because they may perform a wide range of IR
858 // modifications.
859 DiagnosedSilenceableFailure payloadCheck =
861 if (!payloadCheck.succeeded())
862 return payloadCheck;
863
864 // Run the pass or pass pipeline on the current target operation.
865 if (failed(pm.run(target))) {
866 auto diag = emitSilenceableError() << "pass pipeline failed";
867 diag.attachNote(target->getLoc()) << "target op";
868 return diag;
869 }
870 }
871
872 // The applied pass will have directly modified the payload IR(s).
873 results.set(llvm::cast<OpResult>(getResult()), targets);
875}
876
878 OpAsmParser &parser, DictionaryAttr &options,
880 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
881 SmallVector<NamedAttribute> keyValuePairs;
882 size_t dynamicOptionsIdx = 0;
883
884 // Helper for allowing parsing of option values which can be of the form:
885 // - a normal attribute
886 // - an operand (which would be converted to an attr referring to the operand)
887 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
888 std::function<ParseResult(Attribute &)> parseValue =
889 [&](Attribute &valueAttr) -> ParseResult {
890 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
891 if (succeeded(parser.parseOptionalLSquare())) {
893
894 // Recursively parse the array's elements, which might be operands.
895 if (parser.parseCommaSeparatedList(
897 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
898 " in options dictionary") ||
899 parser.parseRSquare())
900 return failure(); // NB: Attempted parse should've output error message.
901
902 valueAttr = ArrayAttr::get(parser.getContext(), attrs);
903
904 return success();
905 }
906
907 // Parse the value, which can be either an attribute or an operand.
908 OptionalParseResult parsedValueAttr =
909 parser.parseOptionalAttribute(valueAttr);
910 if (!parsedValueAttr.has_value()) {
912 ParseResult parsedOperand = parser.parseOperand(operand);
913 if (failed(parsedOperand))
914 return failure(); // NB: Attempted parse should've output error message.
915 // To make use of the operand, we need to store it in the options dict.
916 // As SSA-values cannot occur in attributes, what we do instead is store
917 // an attribute in its place that contains the index of the param-operand,
918 // so that an attr-value associated to the param can be resolved later on.
919 dynamicOptions.push_back(operand);
920 auto wrappedIndex = IntegerAttr::get(
921 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
922 valueAttr =
923 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
924 } else if (failed(parsedValueAttr.value())) {
925 return failure(); // NB: Attempted parse should have output error message.
926 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
927 return parser.emitError(parser.getCurrentLocation())
928 << "the param_operand attribute is a marker reserved for "
929 << "indicating a value will be passed via params and is only used "
930 << "in the generic print format";
931 }
932
933 return success();
934 };
935
936 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
937 // string and `value` looks like either an attribute or an operand-in-an-attr.
938 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
939 std::string key;
940 Attribute valueAttr;
941
942 if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
943 return parser.emitError(parser.getCurrentLocation())
944 << "expected key to either be an identifier or a string";
945
946 if (failed(parser.parseEqual()))
947 return parser.emitError(parser.getCurrentLocation())
948 << "expected '=' after key in key-value pair";
949
950 if (failed(parseValue(valueAttr)))
951 return parser.emitError(parser.getCurrentLocation())
952 << "expected a valid attribute or operand as value associated "
953 << "to key '" << key << "'";
954
955 keyValuePairs.push_back(NamedAttribute(key, valueAttr));
956
957 return success();
958 };
959
962 " in options dictionary"))
963 return failure(); // NB: Attempted parse should have output error message.
964
965 if (DictionaryAttr::findDuplicate(
966 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
967 .has_value())
968 return parser.emitError(parser.getCurrentLocation())
969 << "duplicate keys found in options dictionary";
970
971 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
972
973 return success();
974}
975
977 Operation *op,
978 DictionaryAttr options,
979 ValueRange dynamicOptions) {
980 if (options.empty())
981 return;
982
983 std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
984 if (auto paramOperandAttr =
985 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
986 // Resolve index of param-operand to its actual SSA-value and print that.
987 printer.printOperand(
988 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
989 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990 // This case is so that ArrayAttr-contained operands are pretty-printed.
991 printer << "[";
992 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
993 printer << "]";
994 } else {
995 printer.printAttribute(valueAttr);
996 }
997 };
998
999 printer << "{";
1000 llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
1001 printer << namedAttribute.getName();
1002 printer << " = ";
1003 printOptionValue(namedAttribute.getValue());
1004 });
1005 printer << "}";
1006}
1007
1008LogicalResult transform::ApplyRegisteredPassOp::verify() {
1009 // Check that there is a one-to-one correspondence between param operands
1010 // and references to dynamic options in the options dictionary.
1011
1012 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1013
1014 // Helper for option values to mark seen operands as having been seen (once).
1015 std::function<LogicalResult(Attribute)> checkOptionValue =
1016 [&](Attribute valueAttr) -> LogicalResult {
1017 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1018 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1019 if (dynamicOptionIdx < 0 ||
1020 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1021 return emitOpError()
1022 << "dynamic option index " << dynamicOptionIdx
1023 << " is out of bounds for the number of dynamic options: "
1024 << dynamicOptions.size();
1025 if (dynamicOptions[dynamicOptionIdx] == nullptr)
1026 return emitOpError() << "dynamic option index " << dynamicOptionIdx
1027 << " is already used in options";
1028 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1029 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1030 // Recurse into ArrayAttrs as they may contain references to operands.
1031 for (auto eltAttr : arrayAttr)
1032 if (failed(checkOptionValue(eltAttr)))
1033 return failure();
1034 }
1035 return success();
1036 };
1037
1038 for (NamedAttribute namedAttr : getOptions())
1039 if (failed(checkOptionValue(namedAttr.getValue())))
1040 return failure();
1041
1042 // All dynamicOptions-params seen in the dict will have been set to null.
1043 for (Value dynamicOption : dynamicOptions)
1044 if (dynamicOption)
1045 return emitOpError() << "a param operand does not have a corresponding "
1046 << "param_operand attr in the options dict";
1047
1048 return success();
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// CastOp
1053//===----------------------------------------------------------------------===//
1054
1056transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1057 Operation *target, ApplyToEachResultList &results,
1059 results.push_back(target);
1061}
1062
1063void transform::CastOp::getEffects(
1065 onlyReadsPayload(effects);
1066 onlyReadsHandle(getInputMutable(), effects);
1067 producesHandle(getOperation()->getOpResults(), effects);
1068}
1069
1070bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1071 assert(inputs.size() == 1 && "expected one input");
1072 assert(outputs.size() == 1 && "expected one output");
1073 return llvm::all_of(
1074 std::initializer_list<Type>{inputs.front(), outputs.front()},
1075 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1076}
1077
1078//===----------------------------------------------------------------------===//
1079// CollectMatchingOp
1080//===----------------------------------------------------------------------===//
1081
1082/// Applies matcher operations from the given `block` using
1083/// `blockArgumentMapping` to initialize block arguments. Updates `state`
1084/// accordingly. If any of the matcher produces a silenceable failure, discards
1085/// it (printing the content to the debug output stream) and returns failure. If
1086/// any of the matchers produces a definite failure, reports it and returns
1087/// failure. If all matchers in the block succeed, populates `mappings` with the
1088/// payload entities associated with the block terminator operands. Note that
1089/// `mappings` will be cleared before that.
1092 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1095 assert(block.getParent() && "cannot match using a detached block");
1096 auto matchScope = state.make_region_scope(*block.getParent());
1097 if (failed(
1098 state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1100
1101 for (Operation &match : block.without_terminator()) {
1102 if (!isa<transform::MatchOpInterface>(match)) {
1103 return emitDefiniteFailure(match.getLoc())
1104 << "expected operations in the match part to "
1105 "implement MatchOpInterface";
1106 }
1108 state.applyTransform(cast<transform::TransformOpInterface>(match));
1109 if (diag.succeeded())
1110 continue;
1111
1112 return diag;
1113 }
1114
1115 // Remember the values mapped to the terminator operands so we can
1116 // forward them to the action.
1117 ValueRange yieldedValues = block.getTerminator()->getOperands();
1118 // Our contract with the caller is that the mappings will contain only the
1119 // newly mapped values, clear the rest.
1120 mappings.clear();
1121 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1123}
1124
1125/// Returns `true` if both types implement one of the interfaces provided as
1126/// template parameters.
1127template <typename... Tys>
1128static bool implementSameInterface(Type t1, Type t2) {
1129 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1130}
1131
1132/// Returns `true` if both types implement one of the transform dialect
1133/// interfaces.
1135 return implementSameInterface<transform::TransformHandleTypeInterface,
1136 transform::TransformParamTypeInterface,
1137 transform::TransformValueHandleTypeInterface>(
1138 t1, t2);
1139}
1140
1141//===----------------------------------------------------------------------===//
1142// CollectMatchingOp
1143//===----------------------------------------------------------------------===//
1144
1146transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1150 getOperation(), getMatcher());
1151 if (matcher.isExternal()) {
1152 return emitDefiniteFailure()
1153 << "unresolved external symbol " << getMatcher();
1154 }
1155
1157 rawResults.resize(getOperation()->getNumResults());
1158 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1159 for (Operation *root : state.getPayloadOps(getRoot())) {
1160 WalkResult walkResult = root->walk([&](Operation *op) {
1161 LDBG(DEBUG_TYPE_MATCHER, 1)
1162 << "matching "
1163 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1164 << " @" << op;
1165
1166 // Try matching.
1168 SmallVector<transform::MappedValue> inputMapping({op});
1170 matcher.getFunctionBody().front(),
1171 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1172 mappings);
1173 if (diag.isDefiniteFailure())
1174 return WalkResult::interrupt();
1175 if (diag.isSilenceableFailure()) {
1176 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1177 << " failed: " << diag.getMessage();
1178 return WalkResult::advance();
1179 }
1180
1181 // If succeeded, collect results.
1182 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1183 if (mapping.size() != 1) {
1184 maybeFailure.emplace(emitSilenceableError()
1185 << "result #" << i << ", associated with "
1186 << mapping.size()
1187 << " payload objects, expected 1");
1188 return WalkResult::interrupt();
1189 }
1190 rawResults[i].push_back(mapping[0]);
1191 }
1192 return WalkResult::advance();
1193 });
1194 if (walkResult.wasInterrupted())
1195 return std::move(*maybeFailure);
1196 assert(!maybeFailure && "failure set but the walk was not interrupted");
1197
1198 for (auto &&[opResult, rawResult] :
1199 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1200 results.setMappedValues(opResult, rawResult);
1201 }
1202 }
1204}
1205
1206void transform::CollectMatchingOp::getEffects(
1208 onlyReadsHandle(getRootMutable(), effects);
1209 producesHandle(getOperation()->getOpResults(), effects);
1210 onlyReadsPayload(effects);
1211}
1212
1213LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1214 SymbolTableCollection &symbolTable) {
1215 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1216 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1217 if (!matcherSymbol ||
1218 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1219 return emitError() << "unresolved matcher symbol " << getMatcher();
1220
1221 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1222 if (argumentTypes.size() != 1 ||
1223 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1224 return emitError()
1225 << "expected the matcher to take one operation handle argument";
1226 }
1227 if (!matcherSymbol.getArgAttr(
1228 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1229 return emitError() << "expected the matcher argument to be marked readonly";
1230 }
1231
1232 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1233 if (resultTypes.size() != getOperation()->getNumResults()) {
1234 return emitError()
1235 << "expected the matcher to yield as many values as op has results ("
1236 << getOperation()->getNumResults() << "), got "
1237 << resultTypes.size();
1238 }
1239
1240 for (auto &&[i, matcherType, resultType] :
1241 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1242 if (implementSameTransformInterface(matcherType, resultType))
1243 continue;
1244
1245 return emitError()
1246 << "mismatching type interfaces for matcher result and op result #"
1247 << i;
1248 }
1249
1250 return success();
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// ForeachMatchOp
1255//===----------------------------------------------------------------------===//
1256
1257// This is fine because nothing is actually consumed by this op.
1258bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1259
1261transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1265 matchActionPairs;
1266 matchActionPairs.reserve(getMatchers().size());
1267 SymbolTableCollection symbolTable;
1268 for (auto &&[matcher, action] :
1269 llvm::zip_equal(getMatchers(), getActions())) {
1270 auto matcherSymbol =
1271 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1272 getOperation(), cast<SymbolRefAttr>(matcher));
1273 auto actionSymbol =
1274 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1275 getOperation(), cast<SymbolRefAttr>(action));
1276 assert(matcherSymbol && actionSymbol &&
1277 "unresolved symbols not caught by the verifier");
1278
1279 if (matcherSymbol.isExternal())
1280 return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1281 if (actionSymbol.isExternal())
1282 return emitDefiniteFailure() << "unresolved external symbol " << action;
1283
1284 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1285 }
1286
1287 DiagnosedSilenceableFailure overallDiag =
1289
1290 SmallVector<SmallVector<MappedValue>> matchInputMapping;
1291 SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1292 SmallVector<SmallVector<MappedValue>> actionResultMapping;
1293 // Explicitly add the mapping for the first block argument (the op being
1294 // matched).
1295 matchInputMapping.emplace_back();
1297 getForwardedInputs(), state);
1298 SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1299 actionResultMapping.resize(getForwardedOutputs().size());
1300
1301 for (Operation *root : state.getPayloadOps(getRoot())) {
1302 WalkResult walkResult = root->walk([&](Operation *op) {
1303 // If getRestrictRoot is not present, skip over the root op itself so we
1304 // don't invalidate it.
1305 if (!getRestrictRoot() && op == root)
1306 return WalkResult::advance();
1307
1308 LDBG(DEBUG_TYPE_MATCHER, 1)
1309 << "matching "
1310 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1311 << " @" << op;
1312
1313 firstMatchArgument.clear();
1314 firstMatchArgument.push_back(op);
1315
1316 // Try all the match/action pairs until the first successful match.
1317 for (auto [matcher, action] : matchActionPairs) {
1319 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1320 state, matchOutputMapping);
1321 if (diag.isDefiniteFailure())
1322 return WalkResult::interrupt();
1323 if (diag.isSilenceableFailure()) {
1324 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1325 << " failed: " << diag.getMessage();
1326 continue;
1327 }
1328
1329 auto scope = state.make_region_scope(action.getFunctionBody());
1330 if (failed(state.mapBlockArguments(
1331 action.getFunctionBody().front().getArguments(),
1332 matchOutputMapping))) {
1333 return WalkResult::interrupt();
1334 }
1335
1336 for (Operation &transform :
1337 action.getFunctionBody().front().without_terminator()) {
1339 state.applyTransform(cast<TransformOpInterface>(transform));
1340 if (result.isDefiniteFailure())
1341 return WalkResult::interrupt();
1342 if (result.isSilenceableFailure()) {
1343 if (overallDiag.succeeded()) {
1344 overallDiag = emitSilenceableError() << "actions failed";
1345 }
1346 overallDiag.attachNote(action->getLoc())
1347 << "failed action: " << result.getMessage();
1348 overallDiag.attachNote(op->getLoc())
1349 << "when applied to this matching payload";
1350 (void)result.silence();
1351 continue;
1352 }
1353 }
1354 if (failed(detail::appendValueMappings(
1355 MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1356 action.getFunctionBody().front().getTerminator()->getOperands(),
1357 state, getFlattenResults()))) {
1359 << "action @" << action.getName()
1360 << " has results associated with multiple payload entities, "
1361 "but flattening was not requested";
1362 return WalkResult::interrupt();
1363 }
1364 break;
1365 }
1366 return WalkResult::advance();
1367 });
1368 if (walkResult.wasInterrupted())
1370 }
1371
1372 // The root operation should not have been affected, so we can just reassign
1373 // the payload to the result. Note that we need to consume the root handle to
1374 // make sure any handles to operations inside, that could have been affected
1375 // by actions, are invalidated.
1376 results.set(llvm::cast<OpResult>(getUpdated()),
1377 state.getPayloadOps(getRoot()));
1378 for (auto &&[result, mapping] :
1379 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1380 results.setMappedValues(result, mapping);
1381 }
1382 return overallDiag;
1383}
1384
1385void transform::ForeachMatchOp::getAsmResultNames(
1386 OpAsmSetValueNameFn setNameFn) {
1387 setNameFn(getUpdated(), "updated_root");
1388 for (Value v : getForwardedOutputs()) {
1389 setNameFn(v, "yielded");
1390 }
1391}
1392
1393void transform::ForeachMatchOp::getEffects(
1395 // Bail if invalid.
1396 if (getOperation()->getNumOperands() < 1 ||
1397 getOperation()->getNumResults() < 1) {
1398 return modifiesPayload(effects);
1399 }
1400
1401 consumesHandle(getRootMutable(), effects);
1402 onlyReadsHandle(getForwardedInputsMutable(), effects);
1403 producesHandle(getOperation()->getOpResults(), effects);
1404 modifiesPayload(effects);
1405}
1406
1407/// Parses the comma-separated list of symbol reference pairs of the format
1408/// `@matcher -> @action`.
1409static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1411 ArrayAttr &actions) {
1412 StringAttr matcher;
1413 StringAttr action;
1414 SmallVector<Attribute> matcherList;
1415 SmallVector<Attribute> actionList;
1416 do {
1417 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1418 parser.parseSymbolName(action)) {
1419 return failure();
1420 }
1421 matcherList.push_back(SymbolRefAttr::get(matcher));
1422 actionList.push_back(SymbolRefAttr::get(action));
1423 } while (parser.parseOptionalComma().succeeded());
1424
1425 matchers = parser.getBuilder().getArrayAttr(matcherList);
1426 actions = parser.getBuilder().getArrayAttr(actionList);
1427 return success();
1428}
1429
1430/// Prints the comma-separated list of symbol reference pairs of the format
1431/// `@matcher -> @action`.
1433 ArrayAttr matchers, ArrayAttr actions) {
1434 printer.increaseIndent();
1435 printer.increaseIndent();
1436 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1437 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1438 printer.printNewline();
1439 printer << cast<SymbolRefAttr>(matcher) << " -> "
1440 << cast<SymbolRefAttr>(action);
1441 if (idx != matchers.size() - 1)
1442 printer << ", ";
1443 }
1444 printer.decreaseIndent();
1445 printer.decreaseIndent();
1446}
1447
1448LogicalResult transform::ForeachMatchOp::verify() {
1449 if (getMatchers().size() != getActions().size())
1450 return emitOpError() << "expected the same number of matchers and actions";
1451 if (getMatchers().empty())
1452 return emitOpError() << "expected at least one match/action pair";
1453
1455 for (Attribute name : getMatchers()) {
1456 if (matcherNames.insert(name).second)
1457 continue;
1458 emitWarning() << "matcher " << name
1459 << " is used more than once, only the first match will apply";
1460 }
1461
1462 return success();
1463}
1464
1465/// Checks that the attributes of the function-like operation have correct
1466/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1467/// annotations being present even if they can be inferred from the body.
1469verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1470 bool alsoVerifyInternal = false) {
1471 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1472 llvm::SmallDenseSet<unsigned> consumedArguments;
1473 if (!op.isExternal()) {
1474 transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1475 consumedArguments);
1476 }
1477 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1478 bool isConsumed =
1479 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1480 nullptr;
1481 bool isReadOnly =
1482 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1483 nullptr;
1484 if (isConsumed && isReadOnly) {
1485 return transformOp.emitSilenceableError()
1486 << "argument #" << i << " cannot be both readonly and consumed";
1487 }
1488 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1489 return transformOp.emitSilenceableError()
1490 << "must provide consumed/readonly status for arguments of "
1491 "external or called ops";
1492 }
1493 if (op.isExternal())
1494 continue;
1495
1496 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1497 return transformOp.emitSilenceableError()
1498 << "argument #" << i
1499 << " is consumed in the body but is not marked as such";
1500 }
1501 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1502 // Cannot use op.emitWarning() here as it would attempt to verify the op
1503 // before printing, resulting in infinite recursion.
1504 emitWarning(op->getLoc())
1505 << "op argument #" << i
1506 << " is not consumed in the body but is marked as consumed";
1507 }
1508 }
1510}
1511
1512LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1513 SymbolTableCollection &symbolTable) {
1514 assert(getMatchers().size() == getActions().size());
1515 auto consumedAttr =
1516 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1517 for (auto &&[matcher, action] :
1518 llvm::zip_equal(getMatchers(), getActions())) {
1519 // Presence and typing.
1520 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1521 symbolTable.lookupNearestSymbolFrom(getOperation(),
1522 cast<SymbolRefAttr>(matcher)));
1523 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1524 symbolTable.lookupNearestSymbolFrom(getOperation(),
1525 cast<SymbolRefAttr>(action)));
1526 if (!matcherSymbol ||
1527 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1528 return emitError() << "unresolved matcher symbol " << matcher;
1529 if (!actionSymbol ||
1530 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1531 return emitError() << "unresolved action symbol " << action;
1532
1534 /*emitWarnings=*/false,
1535 /*alsoVerifyInternal=*/true)
1536 .checkAndReport())) {
1537 return failure();
1538 }
1540 /*emitWarnings=*/false,
1541 /*alsoVerifyInternal=*/true)
1542 .checkAndReport())) {
1543 return failure();
1544 }
1545
1546 // Input -> matcher forwarding.
1547 TypeRange operandTypes = getOperandTypes();
1548 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1549 if (operandTypes.size() != matcherArguments.size()) {
1551 emitError() << "the number of operands (" << operandTypes.size()
1552 << ") doesn't match the number of matcher arguments ("
1553 << matcherArguments.size() << ") for " << matcher;
1554 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1555 return diag;
1556 }
1557 for (auto &&[i, operand, argument] :
1558 llvm::enumerate(operandTypes, matcherArguments)) {
1559 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1561 emitOpError()
1562 << "does not expect matcher symbol to consume its operand #" << i;
1563 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1564 return diag;
1565 }
1566
1567 if (implementSameTransformInterface(operand, argument))
1568 continue;
1569
1571 emitError()
1572 << "mismatching type interfaces for operand and matcher argument #"
1573 << i << " of matcher " << matcher;
1574 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1575 return diag;
1576 }
1577
1578 // Matcher -> action forwarding.
1579 TypeRange matcherResults = matcherSymbol.getResultTypes();
1580 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1581 if (matcherResults.size() != actionArguments.size()) {
1582 return emitError() << "mismatching number of matcher results and "
1583 "action arguments between "
1584 << matcher << " (" << matcherResults.size() << ") and "
1585 << action << " (" << actionArguments.size() << ")";
1586 }
1587 for (auto &&[i, matcherType, actionType] :
1588 llvm::enumerate(matcherResults, actionArguments)) {
1589 if (implementSameTransformInterface(matcherType, actionType))
1590 continue;
1591
1592 return emitError() << "mismatching type interfaces for matcher result "
1593 "and action argument #"
1594 << i << "of matcher " << matcher << " and action "
1595 << action;
1596 }
1597
1598 // Action -> result forwarding.
1599 TypeRange actionResults = actionSymbol.getResultTypes();
1600 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1601 if (actionResults.size() != resultTypes.size()) {
1603 emitError() << "the number of action results ("
1604 << actionResults.size() << ") for " << action
1605 << " doesn't match the number of extra op results ("
1606 << resultTypes.size() << ")";
1607 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1608 return diag;
1609 }
1610 for (auto &&[i, resultType, actionType] :
1611 llvm::enumerate(resultTypes, actionResults)) {
1612 if (implementSameTransformInterface(resultType, actionType))
1613 continue;
1614
1616 emitError() << "mismatching type interfaces for action result #" << i
1617 << " of action " << action << " and op result";
1618 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1619 return diag;
1620 }
1621 }
1622 return success();
1623}
1624
1625//===----------------------------------------------------------------------===//
1626// ForeachOp
1627//===----------------------------------------------------------------------===//
1628
1630transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1633 // We store the payloads before executing the body as ops may be removed from
1634 // the mapping by the TrackingRewriter while iteration is in progress.
1636 detail::prepareValueMappings(payloads, getTargets(), state);
1637 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1638 bool withZipShortest = getWithZipShortest();
1639
1640 // In case of `zip_shortest`, set the number of iterations to the
1641 // smallest payload in the targets.
1642 if (withZipShortest) {
1643 numIterations =
1644 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1645 const SmallVector<MappedValue> &b) {
1646 return a.size() < b.size();
1647 })->size();
1648
1649 for (auto &payload : payloads)
1650 payload.resize(numIterations);
1651 }
1652
1653 // As we will be "zipping" over them, check all payloads have the same size.
1654 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1655 // when true.
1656 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1657 argIdx++) {
1658 if (payloads[argIdx].size() != numIterations) {
1659 return emitSilenceableError()
1660 << "prior targets' payload size (" << numIterations
1661 << ") differs from payload size (" << payloads[argIdx].size()
1662 << ") of target " << getTargets()[argIdx];
1663 }
1664 }
1665
1666 // Start iterating, indexing into payloads to obtain the right arguments to
1667 // call the body with - each slice of payloads at the same argument index
1668 // corresponding to a tuple to use as the body's block arguments.
1669 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1670 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1671 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1672 auto scope = state.make_region_scope(getBody());
1673 // Set up arguments to the region's block.
1674 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1675 MappedValue argument = payloads[argIdx][iterIdx];
1676 // Note that each blockArg's handle gets associated with just a single
1677 // element from the corresponding target's payload.
1678 if (failed(state.mapBlockArgument(blockArg, {argument})))
1680 }
1681
1682 // Execute loop body.
1683 for (Operation &transform : getBody().front().without_terminator()) {
1685 llvm::cast<transform::TransformOpInterface>(transform));
1686 if (!result.succeeded())
1687 return result;
1688 }
1689
1690 // Append yielded payloads to corresponding results from prior iterations.
1691 OperandRange yieldOperands = getYieldOp().getOperands();
1692 for (auto &&[result, yieldOperand, resTuple] :
1693 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1694 // NB: each iteration we add any number of ops/vals/params to a result.
1695 if (isa<TransformHandleTypeInterface>(result.getType()))
1696 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1697 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1698 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1699 else if (isa<TransformParamTypeInterface>(result.getType()))
1700 llvm::append_range(resTuple, state.getParams(yieldOperand));
1701 else
1702 assert(false && "unhandled handle type");
1703 }
1704
1705 // Associate the accumulated result payloads to the op's actual results.
1706 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1707 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1708
1710}
1711
1712void transform::ForeachOp::getEffects(
1714 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1715 // arity errors, this method might get called before/in absence of `verify()`.
1716 for (auto &&[target, blockArg] :
1717 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1718 BlockArgument blockArgument = blockArg;
1719 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1720 return isHandleConsumed(blockArgument,
1721 cast<TransformOpInterface>(&op));
1722 })) {
1723 consumesHandle(target, effects);
1724 } else {
1725 onlyReadsHandle(target, effects);
1726 }
1727 }
1728
1729 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1730 return doesModifyPayload(cast<TransformOpInterface>(&op));
1731 })) {
1732 modifiesPayload(effects);
1733 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1734 return doesReadPayload(cast<TransformOpInterface>(&op));
1735 })) {
1736 onlyReadsPayload(effects);
1737 }
1738
1739 producesHandle(getOperation()->getOpResults(), effects);
1740}
1741
1742void transform::ForeachOp::getSuccessorRegions(
1744 Region *bodyRegion = &getBody();
1745 if (point.isParent()) {
1746 regions.emplace_back(bodyRegion);
1747 return;
1748 }
1749
1750 // Branch back to the region or the parent.
1751 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
1752 &getBody() &&
1753 "unexpected region index");
1754 regions.emplace_back(bodyRegion);
1755 regions.push_back(RegionSuccessor::parent());
1756}
1757
1758ValueRange transform::ForeachOp::getSuccessorInputs(RegionSuccessor successor) {
1759 return successor.isParent() ? ValueRange(getResults())
1760 : ValueRange(getBody().getArguments());
1761}
1762
1764transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1765 // Each block argument handle is mapped to a subset (one op to be precise)
1766 // of the payload of the corresponding `targets` operand of ForeachOp.
1767 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1768 return getOperation()->getOperands();
1769}
1770
1771transform::YieldOp transform::ForeachOp::getYieldOp() {
1772 return cast<transform::YieldOp>(getBody().front().getTerminator());
1773}
1774
1775LogicalResult transform::ForeachOp::verify() {
1776 for (auto [targetOpt, bodyArgOpt] :
1777 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1778 if (!targetOpt || !bodyArgOpt)
1779 return emitOpError() << "expects the same number of targets as the body "
1780 "has block arguments";
1781 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1782 return emitOpError(
1783 "expects co-indexed targets and the body's "
1784 "block arguments to have the same op/value/param type");
1785 }
1786
1787 for (auto [resultOpt, yieldOperandOpt] :
1788 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1789 if (!resultOpt || !yieldOperandOpt)
1790 return emitOpError() << "expects the same number of results as the "
1791 "yield terminator has operands";
1792 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1793 return emitOpError("expects co-indexed results and yield "
1794 "operands to have the same op/value/param type");
1795 }
1796
1797 return success();
1798}
1799
1800//===----------------------------------------------------------------------===//
1801// GetParentOp
1802//===----------------------------------------------------------------------===//
1803
// For every payload op of the target handle, climbs the parent chain
// `getNthParent()` times, each time stopping at the closest ancestor that
// satisfies the optional isolated-from-above and op-name filters, and maps the
// result handle to the ancestors found.
// NOTE(review): this listing omits several source lines (the embedded numbers
// skip 1804, 1806-1808, 1817, 1827, 1829 and 1844), including the return type,
// the remaining parameters, the declaration of `parents`, and the return
// statements.
1805transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
// Used only to drop repeated parents when deduplication is requested.
1809 DenseSet<Operation *> resultSet;
1810 for (Operation *target : state.getPayloadOps(getTarget())) {
1811 Operation *parent = target;
1812 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
// Always move at least one level up before testing the filters.
1813 parent = parent->getParentOp();
1814 while (parent) {
// NOTE(review): the second operand of this `||` (listing line 1817) is not
// visible here.
1815 bool checkIsolatedFromAbove =
1816 !getIsolatedFromAbove() ||
1818 bool checkOpName = !getOpName().has_value() ||
1819 parent->getName().getStringRef() == *getOpName();
1820 if (checkIsolatedFromAbove && checkOpName)
1821 break;
1822 parent = parent->getParentOp();
1823 }
// Ran out of ancestors without a match.
1824 if (!parent) {
1825 if (getAllowEmptyResults()) {
// The result is set to whatever has been collected in `parents` so far;
// presumably followed by an early return on the omitted line.
1826 results.set(llvm::cast<OpResult>(getResult()), parents);
1828 }
1830 emitSilenceableError()
1831 << "could not find a parent op that matches all requirements";
1832 diag.attachNote(target->getLoc()) << "target op";
1833 return diag;
1834 }
1835 }
1836 if (getDeduplicate()) {
// `insert(...).second` is true only on first insertion.
1837 if (resultSet.insert(parent).second)
1838 parents.push_back(parent);
1839 } else {
1840 parents.push_back(parent);
1841 }
1842 }
1843 results.set(llvm::cast<OpResult>(getResult()), parents);
1845}
1846
1847//===----------------------------------------------------------------------===//
1848// GetConsumersOfResult
1849//===----------------------------------------------------------------------===//
1850
// Maps the result handle to the users of result number `getResultNumber()` of
// the single payload op associated with the target handle. An empty target
// yields an empty result; more than one payload op or an out-of-range result
// number is a definite failure.
// NOTE(review): the return type, remaining parameters and the return
// statements (listing lines 1851, 1853-1854, 1859, 1870) are not visible here.
1852transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1855 int64_t resultNumber = getResultNumber();
1856 auto payloadOps = state.getPayloadOps(getTarget());
1857 if (std::empty(payloadOps)) {
1858 results.set(cast<OpResult>(getResult()), {});
1860 }
1861 if (!llvm::hasSingleElement(payloadOps))
1862 return emitDefiniteFailure()
1863 << "handle must be mapped to exactly one payload op";
1864
1865 Operation *target = *payloadOps.begin();
// Guards the `getResult(resultNumber)` access below.
1866 if (target->getNumResults() <= resultNumber)
1867 return emitDefiniteFailure() << "result number overflow";
1868 results.set(llvm::cast<OpResult>(getResult()),
1869 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1871}
1872
1873//===----------------------------------------------------------------------===//
1874// GetDefiningOp
1875//===----------------------------------------------------------------------===//
1876
// Maps the result handle to the defining ops of the payload values of the
// target handle, in order. A block argument among the payload values produces
// a silenceable error with a note at that value's location.
// NOTE(review): the return type, remaining parameters, the declaration of
// `diag` and the final return (listing lines 1877, 1879-1880, 1884, 1892) are
// not visible here.
1878transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1881 SmallVector<Operation *> definingOps;
1882 for (Value v : state.getPayloadValues(getTarget())) {
1883 if (llvm::isa<BlockArgument>(v)) {
1885 emitSilenceableError() << "cannot get defining op of block argument";
1886 diag.attachNote(v.getLoc()) << "target value";
1887 return diag;
1888 }
1889 definingOps.push_back(v.getDefiningOp());
1890 }
1891 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1893}
1894
1895//===----------------------------------------------------------------------===//
1896// GetProducerOfOperand
1897//===----------------------------------------------------------------------===//
1898
// For each payload op of the target handle, looks up the op defining operand
// number `getOperandNumber()` and maps the result handle to these producers.
// A missing operand or an operand without a defining op yields a silenceable
// error.
// NOTE(review): the return type, remaining parameters, the declaration of
// `diag` and the final return (listing lines 1899, 1901-1902, 1911, 1921) are
// not visible here.
1900transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1903 int64_t operandNumber = getOperandNumber();
1904 SmallVector<Operation *> producers;
1905 for (Operation *target : state.getPayloadOps(getTarget())) {
// An out-of-range operand number is folded into the "no producer" case.
1906 Operation *producer =
1907 target->getNumOperands() <= operandNumber
1908 ? nullptr
1909 : target->getOperand(operandNumber).getDefiningOp();
1910 if (!producer) {
1912 emitSilenceableError()
1913 << "could not find a producer for operand number: " << operandNumber
1914 << " of " << *target;
1915 diag.attachNote(target->getLoc()) << "target op";
1916 return diag;
1917 }
1918 producers.push_back(producer);
1919 }
1920 results.set(llvm::cast<OpResult>(getResult()), producers);
1922}
1923
1924//===----------------------------------------------------------------------===//
1925// GetOperandOp
1926//===----------------------------------------------------------------------===//
1927
// Collects, for each payload op of the target handle, the operands at the
// positions computed into `operandPositions`, and maps the result handle to
// the concatenation of these values.
// NOTE(review): the return type, remaining parameters, the call that fills
// `operandPositions` and defines `diag`, and the final return (listing lines
// 1928, 1930-1931, 1935, 1949) are not visible here.
1929transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1932 SmallVector<Value> operands;
1933 for (Operation *target : state.getPayloadOps(getTarget())) {
1934 SmallVector<int64_t> operandPositions;
1936 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1937 target->getNumOperands(), operandPositions);
// Only the silenceable case is handled here; the failure is annotated with the
// payload op that was being processed.
1938 if (diag.isSilenceableFailure()) {
1939 diag.attachNote(target->getLoc())
1940 << "while considering positions of this payload operation";
1941 return diag;
1942 }
1943 llvm::append_range(operands,
1944 llvm::map_range(operandPositions, [&](int64_t pos) {
1945 return target->getOperand(pos);
1946 }));
1947 }
1948 results.setValues(cast<OpResult>(getResult()), operands);
1950}
1951
1952LogicalResult transform::GetOperandOp::verify() {
1953 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1954 getIsInverted(), getIsAll());
1955}
1956
1957//===----------------------------------------------------------------------===//
1958// GetResultOp
1959//===----------------------------------------------------------------------===//
1960
// Collects, for each payload op of the target handle, the results at the
// positions computed into `resultPositions`, and maps the result handle to
// the concatenation of these values.
// NOTE(review): the return type, remaining parameters, the call that fills
// `resultPositions` and defines `diag`, and the final return (listing lines
// 1961, 1963-1964, 1968, 1982) are not visible here.
1962transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1965 SmallVector<Value> opResults;
1966 for (Operation *target : state.getPayloadOps(getTarget())) {
1967 SmallVector<int64_t> resultPositions;
1969 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1970 target->getNumResults(), resultPositions);
// Only the silenceable case is handled here; the failure is annotated with the
// payload op that was being processed.
1971 if (diag.isSilenceableFailure()) {
1972 diag.attachNote(target->getLoc())
1973 << "while considering positions of this payload operation";
1974 return diag;
1975 }
1976 llvm::append_range(opResults,
1977 llvm::map_range(resultPositions, [&](int64_t pos) {
1978 return target->getResult(pos);
1979 }));
1980 }
1981 results.setValues(cast<OpResult>(getResult()), opResults);
1983}
1984
1985LogicalResult transform::GetResultOp::verify() {
1986 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1987 getIsInverted(), getIsAll());
1988}
1989
1990//===----------------------------------------------------------------------===//
1991// GetTypeOp
1992//===----------------------------------------------------------------------===//
1993
// Declares the effects of this op: the `value` operand is only read, the
// results are produced handles, and the payload is only read.
// NOTE(review): the line declaring the `effects` parameter (listing line 1995)
// is not visible here.
1994void transform::GetTypeOp::getEffects(
1996 onlyReadsHandle(getValueMutable(), effects);
1997 producesHandle(getOperation()->getOpResults(), effects);
1998 onlyReadsPayload(effects);
1999}
2000
// Maps the result parameter to one `TypeAttr` per payload value of the
// `value` handle. With `getElemental()` set, shaped types contribute their
// element type; other types are used unchanged.
// NOTE(review): the return type, remaining parameters, the declaration of
// `params` and the final return (listing lines 2001, 2003-2005, 2016) are not
// visible here.
2002transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
2006 for (Value value : state.getPayloadValues(getValue())) {
2007 Type type = value.getType();
2008 if (getElemental()) {
2009 if (auto shaped = dyn_cast<ShapedType>(type)) {
2010 type = shaped.getElementType();
2011 }
2012 }
2013 params.push_back(TypeAttr::get(type));
2014 }
2015 results.setParams(cast<OpResult>(getResult()), params);
2017}
2018
2019//===----------------------------------------------------------------------===//
2020// IncludeOp
2021//===----------------------------------------------------------------------===//
2022
2023/// Applies the transform ops contained in `block`. Maps `results` to the same
2024/// values as the operands of the block terminator.
// A definite failure aborts immediately. A silenceable failure aborts only in
// `Propagate` mode (after forwarding empty results); otherwise it is silenced
// and the remaining ops are still applied.
// NOTE(review): the return type, the `state` parameter, the declaration of
// `result` and the final return (listing lines 2025, 2027, 2031, 2049) are not
// visible here.
2026applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2028 transform::TransformResults &results) {
2029 // Apply the sequenced ops one by one.
2030 for (Operation &transform : block.without_terminator()) {
2032 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2033 if (result.isDefiniteFailure())
2034 return result;
2035
2036 if (result.isSilenceableFailure()) {
2037 if (mode == transform::FailurePropagationMode::Propagate) {
2038 // Propagate empty results in case of early exit.
2039 forwardEmptyOperands(&block, state, results);
2040 return result;
2041 }
// Suppress mode: explicitly discard the failure and keep going.
2042 (void)result.silence();
2043 }
2044 }
2045
2046 // Forward the operation mapping for values yielded from the sequence to the
2047 // values produced by the sequence op.
2048 transform::detail::forwardTerminatorOperands(&block, state, results);
2050}
2051
// Executes the named sequence referenced by this op: binds the operands to the
// callee's entry block arguments, applies the callee body via
// `applySequenceBlock` using this op's failure propagation mode, and on
// success maps this op's results to the values yielded by the callee.
// NOTE(review): several source lines are missing from this listing (2052,
// 2054-2056, 2064, 2070, 2073), including the return type, the lookup that
// defines `callee`, the declaration of `mappings`, the statement executed when
// mapping a block argument fails, and the declaration of `result`.
2053transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2057 getOperation(), getTarget());
2058 assert(callee && "unverified reference to unknown symbol");
2059
// A declaration without a body cannot be executed.
2060 if (callee.isExternal())
2061 return emitDefiniteFailure() << "unresolved external named sequence";
2062
2063 // Map operands to block arguments.
2065 detail::prepareValueMappings(mappings, getOperands(), state);
2066 auto scope = state.make_region_scope(callee.getBody());
2067 for (auto &&[arg, map] :
2068 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2069 if (failed(state.mapBlockArgument(arg, map)))
2071 }
2072
2074 callee.getBody().front(), getFailurePropagationMode(), state, results);
2075
2076 if (!result.succeeded())
2077 return result;
2078
// Reuse `mappings` for the values yielded by the callee's terminator.
2079 mappings.clear();
2080 detail::prepareValueMappings(
2081 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2082 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2083 results.setMappedValues(result, mapping);
2084 return result;
2085}
2086
// Forward declaration; the definition appears later in this file.
// NOTE(review): the line with the storage class and return type (listing line
// 2087) is not visible here.
2088verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2089
// Declares the effects of this op. The payload is always reported as modified
// and results as produced. Operand effects come from the callee's argument
// attributes when the callee can be resolved; otherwise every operand is
// reported as read-only.
// NOTE(review): the line declaring `effects` (listing line 2091) and the
// start of the lookup that defines `callee` (listing line 2115) are not
// visible here.
2090void transform::IncludeOp::getEffects(
2092 // Always mark as modifying the payload.
2093 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2094 // only read, the payload may still be modified, so we currently stay on the
2095 // conservative side and always indicate modification. This may prevent some
2096 // code reordering.
2097 modifiesPayload(effects);
2098
2099 // Results are always produced.
2100 producesHandle(getOperation()->getOpResults(), effects);
2101
2102 // Adds default effects to operands and results. This will be added if
2103 // preconditions fail so the trait verifier doesn't complain about missing
2104 // effects and the real precondition failure is reported later on.
2105 auto defaultEffects = [&] {
2106 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2107 };
2108
2109 // Bail if the callee is unknown. This may run as part of the verification
2110 // process before we verified the validity of the callee or of this op.
2111 auto target =
2112 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2113 if (!target)
2114 return defaultEffects();
2116 getOperation(), getTarget());
2117 if (!callee)
2118 return defaultEffects();
2119
// Operands whose callee argument carries neither attribute get no effect
// declared in this loop.
2120 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2121 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2122 consumesHandle(getOperation()->getOpOperand(i), effects);
2123 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2124 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2125 }
2126}
2127
// Verifies the symbol reference of this op: the `target` attribute must exist
// and resolve to a `transform::NamedSequenceOp`; operand count and types must
// equal the callee's inputs; result count must match and each result type
// must implement the same transform interface as the callee's result. The
// final check runs a further verifier on the callee itself.
// NOTE(review): the line naming the function invoked in the final return
// (listing line 2166) is not visible here.
2128LogicalResult
2129transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2130 // Access through indirection and do additional checking because this may be
2131 // running before the main op verifier.
2132 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2133 if (!targetAttr)
2134 return emitOpError() << "expects a 'target' symbol reference attribute";
2135
2136 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2137 *this, targetAttr);
2138 if (!target)
2139 return emitOpError() << "does not reference a named transform sequence";
2140
2141 FunctionType fnType = target.getFunctionType();
// NOTE(review): this and the result-count check below use emitError while the
// other diagnostics in this function use emitOpError - confirm intentional.
2142 if (fnType.getNumInputs() != getNumOperands())
2143 return emitError("incorrect number of operands for callee");
2144
// Operand types must be exactly equal to the callee's input types.
2145 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2146 if (getOperand(i).getType() != fnType.getInput(i)) {
2147 return emitOpError("operand type mismatch: expected operand type ")
2148 << fnType.getInput(i) << ", but provided "
2149 << getOperand(i).getType() << " for operand number " << i;
2150 }
2151 }
2152
2153 if (fnType.getNumResults() != getNumResults())
2154 return emitError("incorrect number of results for callee");
2155
// Result types are checked via implementSameTransformInterface rather than
// for equality.
2156 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2157 Type resultType = getResult(i).getType();
2158 Type funcType = fnType.getResult(i);
2159 if (!implementSameTransformInterface(resultType, funcType)) {
2160 return emitOpError() << "type of result #" << i
2161 << " must implement the same transform dialect "
2162 "interface as the corresponding callee result";
2163 }
2164 }
2165
2167 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2168 /*alsoVerifyInternal=*/true)
2169 .checkAndReport();
2170}
2171
2172//===----------------------------------------------------------------------===//
2173// MatchOperationEmptyOp
2174//===----------------------------------------------------------------------===//
2175
// Matcher that accepts only the absence of an operation: logs success when
// `maybeCurrent` holds no value and otherwise emits a silenceable error.
// NOTE(review): the remaining parameters and the statement following the
// success log (listing lines 2178 and 2181) are not visible here.
2176DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2177 ::std::optional<::mlir::Operation *> maybeCurrent,
2179 if (!maybeCurrent.has_value()) {
2180 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2182 }
2183 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2184 return emitSilenceableError() << "operation is not empty";
2185}
2186
2187//===----------------------------------------------------------------------===//
2188// MatchOperationNameOp
2189//===----------------------------------------------------------------------===//
2190
// Compares the name of `current` against each string in `getOpNames()`; when
// none is equal, emits a silenceable "wrong operation name" error.
// NOTE(review): the last parameter and the statement executed on a name match
// (listing lines 2193 and 2197) are not visible here.
2191DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2192 Operation *current, transform::TransformResults &results,
2194 StringRef currentOpName = current->getName().getStringRef();
2195 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2196 if (acceptedAttr.getValue() == currentOpName)
2198 }
2199 return emitSilenceableError() << "wrong operation name";
2200}
2201
2202//===----------------------------------------------------------------------===//
2203// MatchParamCmpIOp
2204//===----------------------------------------------------------------------===//
2205
// Compares the integer attributes associated with `param` element-wise with
// those of `reference` using the op's predicate. Differing list lengths or a
// failed comparison give a silenceable error; non-integer attributes or
// mismatching integer attribute types give a definite failure.
// NOTE(review): the return type, remaining parameters, the declaration of
// `diag` inside `reportError` and the final return (listing lines 2206,
// 2208-2209, 2243, 2280) are not visible here.
2207transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
// Renders an APInt as a signed decimal string for diagnostics.
2210 auto signedAPIntAsString = [&](const APInt &value) {
2211 std::string str;
2212 llvm::raw_string_ostream os(str);
2213 value.print(os, /*isSigned=*/true);
2214 return str;
2215 };
2216
2217 ArrayRef<Attribute> params = state.getParams(getParam());
2218 ArrayRef<Attribute> references = state.getParams(getReference());
2219
2220 if (params.size() != references.size()) {
2221 return emitSilenceableError()
2222 << "parameters have different payload lengths (" << params.size()
2223 << " vs " << references.size() << ")";
2224 }
2225
2226 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2227 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2228 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2229 if (!intAttr || !refAttr) {
2230 return emitDefiniteFailure()
2231 << "non-integer parameter value not expected";
2232 }
2233 if (intAttr.getType() != refAttr.getType()) {
2234 return emitDefiniteFailure()
2235 << "mismatching integer attribute types in parameter #" << i;
2236 }
2237 APInt value = intAttr.getValue();
2238 APInt refValue = refAttr.getValue();
2239
2240 // TODO: this copy will not be necessary in C++20.
2241 int64_t position = i;
// Builds the silenceable error for a failed comparison; `direction` is the
// textual form of the expected relation.
2242 auto reportError = [&](StringRef direction) {
2244 emitSilenceableError() << "expected parameter to be " << direction
2245 << " " << signedAPIntAsString(refValue)
2246 << ", got " << signedAPIntAsString(value);
2247 diag.attachNote(getParam().getLoc())
2248 << "value # " << position
2249 << " associated with the parameter defined here";
2250 return diag;
2251 };
2252
// Ordering predicates use the signed APInt comparisons; each case breaks out
// of the switch when the relation holds and reports otherwise.
2253 switch (getPredicate()) {
2254 case MatchCmpIPredicate::eq:
2255 if (value.eq(refValue))
2256 break;
2257 return reportError("equal to");
2258 case MatchCmpIPredicate::ne:
2259 if (value.ne(refValue))
2260 break;
2261 return reportError("not equal to");
2262 case MatchCmpIPredicate::lt:
2263 if (value.slt(refValue))
2264 break;
2265 return reportError("less than");
2266 case MatchCmpIPredicate::le:
2267 if (value.sle(refValue))
2268 break;
2269 return reportError("less than or equal to");
2270 case MatchCmpIPredicate::gt:
2271 if (value.sgt(refValue))
2272 break;
2273 return reportError("greater than");
2274 case MatchCmpIPredicate::ge:
2275 if (value.sge(refValue))
2276 break;
2277 return reportError("greater than or equal to");
2278 }
2279 }
2281}
2282
// Declares both the `param` and the `reference` operands as read-only.
// NOTE(review): the line declaring `effects` (listing line 2284) is not
// visible here.
2283void transform::MatchParamCmpIOp::getEffects(
2285 onlyReadsHandle(getParamMutable(), effects);
2286 onlyReadsHandle(getReferenceMutable(), effects);
2287}
2288
2289//===----------------------------------------------------------------------===//
2290// ParamConstantOp
2291//===----------------------------------------------------------------------===//
2292
// Maps the result parameter to a single-element list holding this op's
// `value` attribute.
// NOTE(review): the return type, remaining parameters and the return statement
// (listing lines 2293, 2295-2296, 2298) are not visible here.
2294transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2297 results.setParams(cast<OpResult>(getParam()), {getValue()});
2299}
2300
2301//===----------------------------------------------------------------------===//
2302// MergeHandlesOp
2303//===----------------------------------------------------------------------===//
2304
// Concatenates the payloads of all operand handles into the single result.
// The kind of payload (operations, parameters or values) is chosen from the
// type of the first handle. With `getDeduplicate()` set, repeated entries are
// removed through a SetVector.
// NOTE(review): the return type, remaining parameters, the declaration of
// `attrs` and every return statement (listing lines 2305, 2307-2308, 2316,
// 2321, 2325, 2330, 2335, 2346, 2351) are not visible here.
2306transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2309 ValueRange handles = getHandles();
// Case 1: operation handles.
2310 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2311 SmallVector<Operation *> operations;
2312 for (Value operand : handles)
2313 llvm::append_range(operations, state.getPayloadOps(operand));
2314 if (!getDeduplicate()) {
2315 results.set(llvm::cast<OpResult>(getResult()), operations);
2317 }
2318
2319 SetVector<Operation *> uniqued(llvm::from_range, operations);
2320 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2322 }
2323
// Case 2: parameters.
2324 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2326 for (Value attribute : handles)
2327 llvm::append_range(attrs, state.getParams(attribute));
2328 if (!getDeduplicate()) {
2329 results.setParams(cast<OpResult>(getResult()), attrs);
2331 }
2332
2333 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2334 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2336 }
2337
// Case 3: the only remaining possibility is a value handle.
2338 assert(
2339 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2340 "expected value handle type");
2341 SmallVector<Value> payloadValues;
2342 for (Value value : handles)
2343 llvm::append_range(payloadValues, state.getPayloadValues(value));
2344 if (!getDeduplicate()) {
2345 results.setValues(cast<OpResult>(getResult()), payloadValues);
2347 }
2348
2349 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2350 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2352}
2353
2354bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2355 // Handles may be the same if deduplicating is enabled.
2356 return getDeduplicate();
2357}
2358
// Declares the handle operands as read-only and the results as produced
// handles; no payload effect is declared.
// NOTE(review): the line declaring `effects` (listing line 2360) is not
// visible here.
2359void transform::MergeHandlesOp::getEffects(
2361 onlyReadsHandle(getHandlesMutable(), effects);
2362 producesHandle(getOperation()->getOpResults(), effects);
2363
2364 // There are no effects on the Payload IR as this is only a handle
2365 // manipulation.
2366}
2367
2368OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2369 if (getDeduplicate() || getHandles().size() != 1)
2370 return {};
2371
2372 // If deduplication is not required and there is only one operand, it can be
2373 // used directly instead of merging.
2374 return getHandles().front();
2375}
2376
2377//===----------------------------------------------------------------------===//
2378// NamedSequenceOp
2379//===----------------------------------------------------------------------===//
2380
// Applies the body of this named sequence directly: rejects external
// (body-less) sequences, maps the entry block arguments, then runs the body
// block in `Propagate` failure mode.
// NOTE(review): the return type, remaining parameters and the statement taken
// when mapping the block arguments fails (listing lines 2381, 2383-2384, 2395)
// are not visible here.
2382transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2385 if (isExternal())
2386 return emitDefiniteFailure() << "unresolved external named sequence";
2387
2388 // Map the entry block argument to the list of operations.
2389 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2390 // without attaching the interface / trait since that is tailored to a
2391 // dangling top-level op that does not get "called".
2392 auto scope = state.make_region_scope(getBody());
2393 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2394 state, this->getOperation(), getBody())))
2396
2397 return applySequenceBlock(getBody().front(),
2398 FailurePropagationMode::Propagate, state, results);
2399}
2400
// Effects hook of NamedSequenceOp.
// NOTE(review): only the opening line of this definition is visible in this
// listing (listing line 2402 is missing), so its parameters and body cannot
// be described from here.
2401void transform::NamedSequenceOp::getEffects(
2403
// Custom assembly parser. The visible arguments show it forwarding to a
// function-like parsing helper with variadic arguments disabled, a callback
// that builds the function type from the parsed inputs and results, and the
// op's function-type / arg-attrs / res-attrs attribute names.
// NOTE(review): the second parameter, the name of the helper being called and
// part of the callback signature (listing lines 2405-2406, 2410) are not
// visible here.
2404ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2407 parser, result, /*allowVariadic=*/false,
2408 getFunctionTypeAttrName(result.name),
2409 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2411 std::string &) { return builder.getFunctionType(inputs, results); },
2412 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2413}
2414
// Custom assembly printer. The visible arguments show it forwarding this op,
// viewed as a FunctionOpInterface and marked non-variadic, together with its
// function-type / arg-attrs / res-attrs attribute names, to a printing helper.
// NOTE(review): the name of the helper being called (listing line 2416) is not
// visible here.
2415void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2417 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2418 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2419 getResAttrsAttrName());
2420}
2421
2422/// Verifies that a symbol function-like transform dialect operation has the
2423/// signature and the terminator that have conforming types, i.e., types
2424/// implementing the same transform dialect type interface. If `allowExternal`
2425/// is set, allow external symbols (declarations) and don't check the terminator
2426/// as it may not exist.
// NOTE(review): as written below, terminator operand types are compared to
// the result types with `==` (exact equality); the wording above about
// "implementing the same interface" is looser than what the code checks.
// NOTE(review): the return type, the declarations of `diag` and the return
// statements on the success paths (listing lines 2427, 2430-2431, 2439, 2449,
// 2473) are not visible here.
2428verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
// The op must not be nested inside another transform op.
2429 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2432 << "cannot be defined inside another transform op";
2433 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2434 return diag;
2435 }
2436
2437 if (op.isExternal() || op.getFunctionBody().empty()) {
2438 if (allowExternal)
2440
2441 return emitSilenceableFailure(op) << "cannot be external";
2442 }
2443
2444 if (op.getFunctionBody().front().empty())
2445 return emitSilenceableFailure(op) << "expected a non-empty body block";
2446
// The last op of the first block must be a transform::YieldOp.
2447 Operation *terminator = &op.getFunctionBody().front().back();
2448 if (!isa<transform::YieldOp>(terminator)) {
2450 << "expected '"
2451 << transform::YieldOp::getOperationName()
2452 << "' as terminator";
2453 diag.attachNote(terminator->getLoc()) << "terminator";
2454 return diag;
2455 }
2456
2457 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2458 return emitSilenceableFailure(terminator)
2459 << "expected terminator to have as many operands as the parent op "
2460 "has results";
2461 }
// Counts are known to be equal at this point, as zip_equal requires.
2462 for (auto [i, operandType, resultType] : llvm::zip_equal(
2463 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2464 terminator->getOperands().getType(), op.getResultTypes())) {
2465 if (operandType == resultType)
2466 continue;
2467 return emitSilenceableFailure(terminator)
2468 << "the type of the terminator operand #" << i
2469 << " must match the type of the corresponding parent op result ("
2470 << operandType << " vs " << resultType << ")";
2471 }
2472
2474}
2475
2476/// Verification of a NamedSequenceOp. This does not report the error
2477/// immediately, so it can be used to check for op's well-formedness before the
2478/// verifier runs, e.g., during trait verification.
// Checks, in order: the enclosing symbol table carries the named-sequence
// attribute; the op is not nested in another transform op; for external or
// body-less ops only the consume annotations are verified; otherwise the body
// must be non-empty, terminated by a YieldOp whose operand count and types
// equal the function results, followed by the annotation check and
// verifyYieldingSingleBlockOp.
// NOTE(review): the return type and the declarations of the `diag` variables
// (listing lines 2479, 2484-2485, 2495-2496, 2511, 2537) are not visible here.
2480verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2481 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2482 if (!parent->getAttr(
2483 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2486 << "expects the parent symbol table to have the '"
2487 << transform::TransformDialect::kWithNamedSequenceAttrName
2488 << "' attribute";
2489 diag.attachNote(parent->getLoc()) << "symbol table operation";
2490 return diag;
2491 }
2492 }
2493
2494 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2497 << "cannot be defined inside another transform op";
2498 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2499 return diag;
2500 }
2501
// Declarations have no body or terminator to inspect.
2502 if (op.isExternal() || op.getBody().empty())
2503 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2504 emitWarnings);
2505
2506 if (op.getBody().front().empty())
2507 return emitSilenceableFailure(op) << "expected a non-empty body block";
2508
2509 Operation *terminator = &op.getBody().front().back();
2510 if (!isa<transform::YieldOp>(terminator)) {
2512 << "expected '"
2513 << transform::YieldOp::getOperationName()
2514 << "' as terminator";
2515 diag.attachNote(terminator->getLoc()) << "terminator";
2516 return diag;
2517 }
2518
2519 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2520 return emitSilenceableFailure(terminator)
2521 << "expected terminator to have as many operands as the parent op "
2522 "has results";
2523 }
// Counts are known to be equal at this point, as zip_equal requires.
2524 for (auto [i, operandType, resultType] :
2525 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2526 terminator->getOperands().getType(),
2527 op.getFunctionType().getResults())) {
2528 if (operandType == resultType)
2529 continue;
2530 return emitSilenceableFailure(terminator)
2531 << "the type of the terminator operand #" << i
2532 << " must match the type of the corresponding parent op result ("
2533 << operandType << " vs " << resultType << ")";
2534 }
2535
2536 auto funcOp = cast<FunctionOpInterface>(*op);
2538 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2539 if (!diag.succeeded())
2540 return diag;
2541
2542 return verifyYieldingSingleBlockOp(funcOp,
2543 /*allowExternal=*/true);
2544}
2545
2546LogicalResult transform::NamedSequenceOp::verify() {
2547 // Actual verification happens in a separate function for reusability.
2548 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2549}
2550
2551template <typename FnTy>
2552static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2553 Type bbArgType, TypeRange extraBindingTypes,
2554 FnTy bodyBuilder) {
2555 SmallVector<Type> types;
2556 types.reserve(1 + extraBindingTypes.size());
2557 types.push_back(bbArgType);
2558 llvm::append_range(types, extraBindingTypes);
2559
2560 OpBuilder::InsertionGuard guard(builder);
2561 Region *region = state.regions.back().get();
2562 Block *bodyBlock =
2563 builder.createBlock(region, region->begin(), types,
2564 SmallVector<Location>(types.size(), state.location));
2565
2566 // Populate body.
2567 builder.setInsertionPointToStart(bodyBlock);
2568 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2569 bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2570 } else {
2571 bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2572 bodyBlock->getArguments().drop_front());
2573 }
2574}
2575
// Builder: records the symbol name and a function type `(rootType) ->
// resultTypes` as attributes, appends the extra `attrs`, adds one region and
// fills its entry block through `bodyBuilder` with a single block argument of
// type `rootType`.
// NOTE(review): the parameter declaring `attrs` and the start of the statement
// that adds the symbol-name attribute (listing lines 2580 and 2582) are not
// visible here. `argAttrs` is not used in the visible part of the body.
2576void transform::NamedSequenceOp::build(OpBuilder &builder,
2577 OperationState &state, StringRef symName,
2578 Type rootType, TypeRange resultTypes,
2579 SequenceBodyBuilderFn bodyBuilder,
2581 ArrayRef<DictionaryAttr> argAttrs) {
2583 builder.getStringAttr(symName));
2584 state.addAttribute(getFunctionTypeAttrName(state.name),
2585 TypeAttr::get(FunctionType::get(builder.getContext(),
2586 rootType, resultTypes)));
2587 state.attributes.append(attrs.begin(), attrs.end());
2588 state.addRegion();
2589
2590 buildSequenceBody(builder, state, rootType,
2591 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2592}
2593
2594//===----------------------------------------------------------------------===//
2595// NumAssociationsOp
2596//===----------------------------------------------------------------------===//
2597
// Counts the payload entries (ops, values or parameters, depending on the
// kind of the handle type) associated with the operand and maps the result
// parameter to that count as an i64 integer attribute.
// NOTE(review): the return type, remaining parameters, the head of the type
// switch and the return statement (listing lines 2598, 2600-2601, 2603, 2616)
// are not visible here.
2599transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2602 size_t numAssociations =
2604 .Case([&](TransformHandleTypeInterface opHandle) {
2605 return llvm::range_size(state.getPayloadOps(getHandle()));
2606 })
2607 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2608 return llvm::range_size(state.getPayloadValues(getHandle()));
2609 })
2610 .Case([&](TransformParamTypeInterface param) {
2611 return llvm::range_size(state.getParams(getHandle()));
2612 })
2613 .DefaultUnreachable("unknown kind of transform dialect type");
2614 results.setParams(cast<OpResult>(getNum()),
2615 rewriter.getI64IntegerAttr(numAssociations));
2617}
2618
2619LogicalResult transform::NumAssociationsOp::verify() {
2620 // Verify that the result type accepts an i64 attribute as payload.
2621 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2622 return resultType
2623 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2624 .checkAndReport();
2625}
2626
2627//===----------------------------------------------------------------------===//
2628// SelectOp
2629//===----------------------------------------------------------------------===//
2630
// Filters the payload ops of the target handle, keeping those whose operation
// name equals `getOpName()`, and maps the result handle to the kept ops in
// their original order.
// NOTE(review): the return type, remaining parameters, the declaration of
// `result` and the return statement (listing lines 2631, 2633-2635, 2642) are
// not visible here.
2632transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2636 auto payloadOps = state.getPayloadOps(getTarget());
2637 for (Operation *op : payloadOps) {
2638 if (op->getName().getStringRef() == getOpName())
2639 result.push_back(op);
2640 }
2641 results.set(cast<OpResult>(getResult()), result);
2643}
2644
2645//===----------------------------------------------------------------------===//
2646// SplitHandleOp
2647//===----------------------------------------------------------------------===//
2648
2649void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2650 Value target, int64_t numResultHandles) {
2651 result.addOperands(target);
2652 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2653}
2654
// Distributes the payload entries of the operand handle over the results:
// entry #i goes to result #i, and entries beyond the number of results go to
// the overflow result. Size mismatches produce a silenceable error subject to
// the overflow / fail-on-too-small / pass-through-empty options below.
// NOTE(review): the return type, remaining parameters, the head of the type
// switch and the final return (listing lines 2655, 2657-2658, 2660, 2723) are
// not visible here.
2656transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
// Number of payload entries, whatever the kind of the handle.
2659 int64_t numPayloads =
2661 .Case([&](TransformHandleTypeInterface x) {
2662 return llvm::range_size(state.getPayloadOps(getHandle()));
2663 })
2664 .Case([&](TransformValueHandleTypeInterface x) {
2665 return llvm::range_size(state.getPayloadValues(getHandle()));
2666 })
2667 .Case([&](TransformParamTypeInterface x) {
2668 return llvm::range_size(state.getParams(getHandle()));
2669 })
2670 .DefaultUnreachable("unknown transform dialect type interface");
2671
2672 auto produceNumOpsError = [&]() {
2673 return emitSilenceableError()
2674 << getHandle() << " expected to contain " << this->getNumResults()
2675 << " payloads but it contains " << numPayloads << " payloads";
2676 };
2677
2678 // Fail if there are more payload ops than results and no overflow result was
2679 // specified.
2680 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2681 return produceNumOpsError();
2682
2683 // Fail if there are more results than payload ops. Unless:
2684 // - "fail_on_payload_too_small" is set to "false", or
2685 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2686 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2687 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2688 return produceNumOpsError();
2689
2690 // Distribute payloads.
2691 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
// NOTE(review): this line can be reached with numPayloads < getNumResults()
// (the check above lets that through when fail-on-too-small is off or the
// handle is empty with pass-through enabled); the difference would then be
// negative before being passed to reserve - confirm this is handled.
2692 if (getOverflowResult())
2693 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2694
// Uniform view of the payload as a vector of MappedValue.
2695 auto container = [&]() {
2696 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2697 return llvm::map_to_vector(
2698 state.getPayloadOps(getHandle()),
2699 [](Operation *op) -> MappedValue { return op; });
2700 }
2701 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2702 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2703 [](Value v) -> MappedValue { return v; });
2704 }
2705 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2706 "unsupported kind of transform dialect type");
2707 return llvm::map_to_vector(state.getParams(getHandle()),
2708 [](Attribute a) -> MappedValue { return a; });
2709 }();
2710
2711 for (auto &&en : llvm::enumerate(container)) {
2712 int64_t resultNum = en.index();
// Entries past the last result are redirected to the overflow result; the
// earlier size check ensures one was specified whenever this can happen.
2713 if (resultNum >= getNumResults())
2714 resultNum = *getOverflowResult();
2715 resultHandles[resultNum].push_back(en.value());
2716 }
2717
2718 // Set transform op results.
2719 for (auto &&it : llvm::enumerate(resultHandles))
2720 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2721 it.value());
2722
2724}
2725
// Populates `effects` for SplitHandleOp by calling onlyReadsHandle on the
// operand handle and producesHandle on all results of the op.
2726void transform::SplitHandleOp::getEffects(
2728 onlyReadsHandle(getHandleMutable(), effects);
2729 producesHandle(getOperation()->getOpResults(), effects);
2730 // There are no effects on the Payload IR as this is only a handle
2731 // manipulation.
2732}
2733
2734LogicalResult transform::SplitHandleOp::verify() {
2735 if (getOverflowResult().has_value() &&
2736 !(*getOverflowResult() < getNumResults()))
2737 return emitOpError("overflow_result is not a valid result index");
2738
2739 for (Type resultType : getResultTypes()) {
2740 if (implementSameTransformInterface(getHandle().getType(), resultType))
2741 continue;
2742
2743 return emitOpError("expects result types to implement the same transform "
2744 "interface as the operand type");
2745 }
2746
2747 return success();
2748}
2749
2750//===----------------------------------------------------------------------===//
2751// ReplicateOp
2752//===----------------------------------------------------------------------===//
2753
// Applies ReplicateOp: for each entry of `handles`, its payload list is
// appended once per payload op associated with `pattern`, and the resulting
// list is bound to the `replicated` result at the same position.
2755transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
     // The repetition count is the number of payload ops mapped to `pattern`.
2758 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2759 for (const auto &en : llvm::enumerate(getHandles())) {
2760 Value handle = en.value();
2761 if (isa<TransformHandleTypeInterface>(handle.getType())) {
     // Operation handle: snapshot the current payload ops, then append the
     // snapshot `numRepetitions` times.
2762 SmallVector<Operation *> current =
2763 llvm::to_vector(state.getPayloadOps(handle));
2765 payload.reserve(numRepetitions * current.size());
2766 for (unsigned i = 0; i < numRepetitions; ++i)
2767 llvm::append_range(payload, current);
2768 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2769 } else {
     // Anything that is not an operation handle is treated as a parameter;
     // the assertion below is the only check of that in this function.
2770 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2771 "expected param type");
2772 ArrayRef<Attribute> current = state.getParams(handle);
2774 params.reserve(numRepetitions * current.size());
2775 for (unsigned i = 0; i < numRepetitions; ++i)
2776 llvm::append_range(params, current);
2777 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2778 params);
2779 }
2780 }
2782}
2783
// Populates `effects` for ReplicateOp by calling onlyReadsHandle on both the
// `pattern` and `handles` operands and producesHandle on all op results.
2784void transform::ReplicateOp::getEffects(
2786 onlyReadsHandle(getPatternMutable(), effects);
2787 onlyReadsHandle(getHandlesMutable(), effects);
2788 producesHandle(getOperation()->getOpResults(), effects);
2789}
2790
2791//===----------------------------------------------------------------------===//
2792// SequenceOp
2793//===----------------------------------------------------------------------===//
2794
// Applies SequenceOp: maps the body block arguments in `state`, then applies
// the body block with this op's failure propagation mode.
2796transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2799 // Map the entry block argument to the list of operations.
     // `scope` is created for the region that contains the body block and
     // stays alive until this function returns.
2800 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2801 if (failed(mapBlockArguments(state)))
2803
     // Delegate to applySequenceBlock, which writes into `results`.
2804 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2805 results);
2806}
2807
// Parses the optional operand list of a sequence op, of the form
//   root (`,` extra-bindings)? `:` `(`? root-type (`,` extra-types)? `)`?
// When no root operand is present, `root` is reset to std::nullopt and parsing
// succeeds without consuming anything else. Otherwise fills `root`,
// `rootType`, `extraBindings` and `extraBindingTypes`, and fails if the number
// of extra types differs from the number of extra operands.
2808static ParseResult parseSequenceOpOperands(
2809 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2810 Type &rootType,
2812 SmallVectorImpl<Type> &extraBindingTypes) {
2814 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
     // No value means no operand was present at all, which is not an error.
2815 if (!hasRoot.has_value()) {
2816 root = std::nullopt;
2817 return success();
2818 }
     // An operand was present but could not be parsed.
2819 if (failed(hasRoot.value()))
2820 return failure();
2821 root = rootOperand;
2822
     // Extra bindings are only parsed after a comma following the root.
2823 if (succeeded(parser.parseOptionalComma())) {
2824 if (failed(parser.parseOperandList(extraBindings)))
2825 return failure();
2826 }
2827 if (failed(parser.parseColon()))
2828 return failure();
2829
2830 // The paren is truly optional.
2831 (void)parser.parseOptionalLParen();
2832
2833 if (failed(parser.parseType(rootType))) {
2834 return failure();
2835 }
2836
     // Types of the extra bindings are required only if extra bindings exist.
2837 if (!extraBindings.empty()) {
2838 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2839 return failure();
2840 }
2841
2842 if (extraBindingTypes.size() != extraBindings.size()) {
2843 return parser.emitError(parser.getNameLoc(),
2844 "expected types to be provided for all operands");
2845 }
2846
2847 // The paren is truly optional.
     // NOTE(review): the opening and closing parens are parsed independently
     // and their results are discarded, so an unbalanced paren is accepted.
2848 (void)parser.parseOptionalRParen();
2849 return success();
2850}
2851
2853 Value root, Type rootType,
2854 ValueRange extraBindings,
2855 TypeRange extraBindingTypes) {
     // Prints `root (, extras)? : types`; nothing is printed without a root.
2856 if (!root)
2857 return;
2858
2859 printer << root;
2860 bool hasExtras = !extraBindings.empty();
2861 if (hasExtras) {
2862 printer << ", ";
2863 printer.printOperands(extraBindings);
2864 }
2865
     // The type list is parenthesized only when extra bindings are present.
2866 printer << " : ";
2867 if (hasExtras)
2868 printer << "(";
2869
2870 printer << rootType;
2871 if (hasExtras)
2872 printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2873}
2874
2875/// Returns `true` if the given op operand may be consuming the handle value in
2876/// the Transform IR. That is, if it may have a Free effect on it.
2878 // Conservatively assume the effect being present in absence of the interface.
2879 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2880 if (!iface)
2881 return true;
2882
     // The owner implements the interface: defer to isHandleConsumed for the
     // used value.
2883 return isHandleConsumed(use.get(), iface);
2884}
2885
// Walks the uses of `value` and allows at most one use that is not skipped by
// the filter at the top of the loop. On the second such use, emits the
// diagnostic obtained from `reportError`, attaches one note per offending use
// and returns failure; otherwise returns success.
2886static LogicalResult
2888 function_ref<InFlightDiagnostic()> reportError) {
     // First use retained so far; null until one is found.
2889 OpOperand *potentialConsumer = nullptr;
2890 for (OpOperand &use : value.getUses()) {
2892 continue;
2893
2894 if (!potentialConsumer) {
2895 potentialConsumer = &use;
2896 continue;
2897 }
2898
     // A second retained use was found: report both locations.
2899 InFlightDiagnostic diag = reportError()
2900 << " has more than one potential consumer";
2901 diag.attachNote(potentialConsumer->getOwner()->getLoc())
2902 << "used here as operand #" << potentialConsumer->getOperandNumber();
2903 diag.attachNote(use.getOwner()->getLoc())
2904 << "used here as operand #" << use.getOperandNumber();
2905 return diag;
2906 }
2907
2908 return success();
2909}
2910
// Verifies SequenceOp:
//  - extra bindings are rejected when there is no root operand;
//  - no body block argument and no result of a nested op may fail
//    checkDoubleConsume;
//  - every nested op other than the last one in the block must implement
//    TransformOpInterface;
//  - the body must have a terminator whose operand types equal the result
//    types of this op.
2911LogicalResult transform::SequenceOp::verify() {
     // NOTE(review): the condition requires at least one argument, while the
     // message says "more than 1".
2912 assert(getBodyBlock()->getNumArguments() >= 1 &&
2913 "the number of arguments must have been verified to be more than 1 by "
2914 "PossibleTopLevelTransformOpTrait");
2915
2916 if (!getRoot() && !getExtraBindings().empty()) {
2917 return emitOpError()
2918 << "does not expect extra operands when used as top-level";
2919 }
2920
2921 // Check if a block argument has more than one consuming use.
2922 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2923 if (failed(checkDoubleConsume(arg, [this, arg]() {
2924 return (emitOpError() << "block argument #" << arg.getArgNumber());
2925 }))) {
2926 return failure();
2927 }
2928 }
2929
2930 // Check properties of the nested operations they cannot check themselves.
2931 for (Operation &child : *getBodyBlock()) {
     // The last op of the block is exempt from the interface requirement.
2932 if (!isa<TransformOpInterface>(child) &&
2933 &child != &getBodyBlock()->back()) {
2935 emitOpError()
2936 << "expected children ops to implement TransformOpInterface";
2937 diag.attachNote(child.getLoc()) << "op without interface";
2938 return diag;
2939 }
2940
     // Double-consume errors for results are reported on the child op itself.
2941 for (OpResult result : child.getResults()) {
2942 auto report = [&]() {
2943 return (child.emitError() << "result #" << result.getResultNumber());
2944 };
2945 if (failed(checkDoubleConsume(result, report)))
2946 return failure();
2947 }
2948 }
2949
2950 if (!getBodyBlock()->mightHaveTerminator())
2951 return emitOpError() << "expects to have a terminator in the body";
2952
2953 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2954 getOperation()->getResultTypes()) {
2956 << "expects the types of the terminator operands "
2957 "to match the types of the result";
2958 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2959 return diag;
2960 }
2961 return success();
2962}
2963
2964void transform::SequenceOp::getEffects(
2967}
2968
// Returns all operands of this op for the body region successor, or an empty
// operand range when the op has no operands. Asserts that `successor` is the
// body region.
2970transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2971 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2972 if (getOperation()->getNumOperands() > 0)
2973 return getOperation()->getOperands();
     // Empty range: both iterators are the end of the operand list.
2974 return OperandRange(getOperation()->operand_end(),
2975 getOperation()->operand_end());
2976}
2977
// Appends the successors of `point` to `regions`: the body region when
// branching from the parent op, and the parent when branching from the body.
2978void transform::SequenceOp::getSuccessorRegions(
2980 if (point.isParent()) {
2981 Region *bodyRegion = &getBody();
2982 regions.emplace_back(bodyRegion);
2983 return;
2984 }
2985
     // Any other branch point must be a terminator located in the body region.
2986 assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
2987 &getBody() &&
2988 "unexpected region index");
2989 regions.push_back(RegionSuccessor::parent());
2990}
2991
// Returns the values that receive the inputs of `successor`: nothing when the
// op has no operands, the op results when the successor is the parent, and the
// body region arguments otherwise.
2993transform::SequenceOp::getSuccessorInputs(RegionSuccessor successor) {
2994 if (getNumOperands() == 0)
2995 return ValueRange();
2996 if (successor.isParent())
2997 return getResults();
2998 return getBody().getArguments();
2999}
3000
// Appends a single (1, 1) entry to `bounds`, independently of `operands`.
// NOTE(review): presumably lower and upper invocation bounds of the body
// region, i.e. it runs exactly once - confirm against the interface.
3001void transform::SequenceOp::getRegionInvocationBounds(
3003 (void)operands;
3004 bounds.emplace_back(1, 1);
3005}
3006
3007void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3008 TypeRange resultTypes,
3009 FailurePropagationMode failurePropagationMode,
3010 Value root,
3011 SequenceBodyBuilderFn bodyBuilder) {
3012 build(builder, state, resultTypes, failurePropagationMode, root,
3013 /*extra_bindings=*/ValueRange());
3014 Type bbArgType = root.getType();
3015 buildSequenceBody(builder, state, bbArgType,
3016 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3017}
3018
3019void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3020 TypeRange resultTypes,
3021 FailurePropagationMode failurePropagationMode,
3022 Value root, ValueRange extraBindings,
3023 SequenceBodyBuilderArgsFn bodyBuilder) {
3024 build(builder, state, resultTypes, failurePropagationMode, root,
3025 extraBindings);
3026 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3027 bodyBuilder);
3028}
3029
3030void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3031 TypeRange resultTypes,
3032 FailurePropagationMode failurePropagationMode,
3033 Type bbArgType,
3034 SequenceBodyBuilderFn bodyBuilder) {
3035 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3036 /*extra_bindings=*/ValueRange());
3037 buildSequenceBody(builder, state, bbArgType,
3038 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3039}
3040
3041void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3042 TypeRange resultTypes,
3043 FailurePropagationMode failurePropagationMode,
3044 Type bbArgType, TypeRange extraBindingTypes,
3045 SequenceBodyBuilderArgsFn bodyBuilder) {
3046 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3047 /*extra_bindings=*/ValueRange());
3048 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3049}
3050
3051//===----------------------------------------------------------------------===//
3052// PrintOp
3053//===----------------------------------------------------------------------===//
3054
3055void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3056 StringRef name) {
3057 if (!name.empty())
3058 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3059}
3060
3061void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3062 Value target, StringRef name) {
3063 result.addOperands({target});
3064 build(builder, result, name);
3065}
3066
// Applies PrintOp: writes a "[[[ IR printer: <name> ... ]]]" banner to
// llvm::outs(), followed either by the top-level op of `state` (when there is
// no target) or by every payload op associated with the target handle. The
// stream is flushed after printing.
3068transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3071 llvm::outs() << "[[[ IR printer: ";
3072 if (getName().has_value())
3073 llvm::outs() << *getName() << " ";
3074
     // Translate the optional attributes of the op into printing flags; an
     // absent attribute counts as false.
3075 OpPrintingFlags printFlags;
3076 if (getAssumeVerified().value_or(false))
3077 printFlags.assumeVerified();
3078 if (getUseLocalScope().value_or(false))
3079 printFlags.useLocalScope();
3080 if (getSkipRegions().value_or(false))
3081 printFlags.skipRegions();
3082
     // Without a target, print the top-level op known to the state.
3083 if (!getTarget()) {
3084 llvm::outs() << "top-level ]]]\n";
3085 state.getTopLevel()->print(llvm::outs(), printFlags);
3086 llvm::outs() << "\n";
3087 llvm::outs().flush();
3089 }
3090
     // With a target, print each payload op on its own line.
3091 llvm::outs() << "]]]\n";
3092 for (Operation *target : state.getPayloadOps(getTarget())) {
3093 target->print(llvm::outs(), printFlags);
3094 llvm::outs() << "\n";
3095 }
3096
3097 llvm::outs().flush();
3099}
3100
// Populates `effects` for PrintOp: calls onlyReadsHandle on the target operand
// when there is one, calls onlyReadsPayload, and appends a write effect on the
// default resource to model the printing.
3101void transform::PrintOp::getEffects(
3103 // We don't really care about mutability here, but `getTarget` now
3104 // unconditionally casts to a specific type before verification could run
3105 // here.
3106 if (!getTargetMutable().empty())
3107 onlyReadsHandle(getTargetMutable()[0], effects);
3108 onlyReadsPayload(effects);
3109
3110 // There is no resource for the output stream this op prints to
3111 // (llvm::outs() in `apply`), so just declare that print writes into the
     // default resource.
3112 effects.emplace_back(MemoryEffects::Write::get());
3113}
3114
3115//===----------------------------------------------------------------------===//
3116// VerifyOp
3117//===----------------------------------------------------------------------===//
3118
// Applies VerifyOp to a single payload op. Only the failure path is visible
// here: the diagnostic "failed to verify payload op" gets a note at the
// location of the payload op and is returned.
3120transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3126 << "failed to verify payload op";
     // Point the note at the payload op that was being verified.
3127 diag.attachNote(target->getLoc()) << "payload op";
3128 return diag;
3129 }
3131}
3132
// Populates `effects` for VerifyOp by calling onlyReadsHandle on the target
// operand; no other effect is registered here.
3133void transform::VerifyOp::getEffects(
3135 transform::onlyReadsHandle(getTargetMutable(), effects);
3136}
3137
3138//===----------------------------------------------------------------------===//
3139// YieldOp
3140//===----------------------------------------------------------------------===//
3141
// Populates `effects` for YieldOp by calling onlyReadsHandle on its operands.
3142void transform::YieldOp::getEffects(
3144 onlyReadsHandle(getOperandsMutable(), effects);
3145}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
Definition DLTI.cpp:38
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue > > blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getOpResults()
Definition Operation.h:420
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation, co-indexed with the results of the transform op.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range of payload IR values.
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list of parameters.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range of mapped values.
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value corresponds to.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payload IR.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's terminator.
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given state.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operation.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in the block.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.