MLIR 22.0.0git
TransformOps.cpp
Go to the documentation of this file.
1//===- TransformOps.cpp - Transform dialect operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/Dominance.h"
24#include "mlir/IR/Verifier.h"
30#include "mlir/Transforms/CSE.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/SmallPtrSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/DebugLog.h"
41#include "llvm/Support/ErrorHandling.h"
42#include "llvm/Support/InterleavedRange.h"
43#include <optional>
44
45#define DEBUG_TYPE "transform-dialect"
46#define DEBUG_TYPE_MATCHER "transform-matcher"
47
48using namespace mlir;
49
50static ParseResult parseApplyRegisteredPassOptions(
51 OpAsmParser &parser, DictionaryAttr &options,
54 Operation *op,
55 DictionaryAttr options,
56 ValueRange dynamicOptions);
57static ParseResult parseSequenceOpOperands(
58 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59 Type &rootType,
61 SmallVectorImpl<Type> &extraBindingTypes);
62static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
63 Value root, Type rootType,
64 ValueRange extraBindings,
65 TypeRange extraBindingTypes);
66static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
68static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
70 ArrayAttr &actions);
71
72/// Helper function to check if the given transform op is contained in (or
73/// equal to) the given payload target op. In that case, an error is returned.
74/// Transforming transform IR that is currently executing is generally unsafe.
76ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
77 Operation *payload) {
78 Operation *transformAncestor = transform.getOperation();
79 while (transformAncestor) {
80 if (transformAncestor == payload) {
82 transform.emitDefiniteFailure()
83 << "cannot apply transform to itself (or one of its ancestors)";
84 diag.attachNote(payload->getLoc()) << "target payload op";
85 return diag;
86 }
87 transformAncestor = transformAncestor->getParentOp();
88 }
90}
91
92#define GET_OP_CLASSES
93#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
94
95//===----------------------------------------------------------------------===//
96// AlternativesOp
97//===----------------------------------------------------------------------===//
98
99OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
100 RegionSuccessor successor) {
101 if (!successor.isParent() && getOperation()->getNumOperands() == 1)
102 return getOperation()->getOperands();
103 return OperandRange(getOperation()->operand_end(),
104 getOperation()->operand_end());
105}
106
107void transform::AlternativesOp::getSuccessorRegions(
109 for (Region &alternative : llvm::drop_begin(
110 getAlternatives(), point.isParent()
111 ? 0
114 ->getRegionNumber() +
115 1)) {
116 regions.emplace_back(&alternative, !getOperands().empty()
117 ? alternative.getArguments()
119 }
120 if (!point.isParent())
121 regions.emplace_back(getOperation(), getOperation()->getResults());
122}
123
124void transform::AlternativesOp::getRegionInvocationBounds(
126 (void)operands;
127 // The region corresponding to the first alternative is always executed, the
128 // remaining may or may not be executed.
129 bounds.reserve(getNumRegions());
130 bounds.emplace_back(1, 1);
131 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
132}
133
136 for (const auto &res : block->getParentOp()->getOpResults())
137 results.set(res, {});
138}
139
141transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
144 SmallVector<Operation *> originals;
145 if (Value scopeHandle = getScope())
146 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
147 else
148 originals.push_back(state.getTopLevel());
149
150 for (Operation *original : originals) {
151 if (original->isAncestor(getOperation())) {
153 << "scope must not contain the transforms being applied";
154 diag.attachNote(original->getLoc()) << "scope";
155 return diag;
156 }
157 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
159 << "only isolated-from-above ops can be alternative scopes";
160 diag.attachNote(original->getLoc()) << "scope";
161 return diag;
162 }
163 }
164
165 for (Region &reg : getAlternatives()) {
166 // Clone the scope operations and make the transforms in this alternative
167 // region apply to them by virtue of mapping the block argument (the only
168 // visible handle) to the cloned scope operations. This effectively prevents
169 // the transformation from accessing any IR outside the scope.
170 auto scope = state.make_region_scope(reg);
171 auto clones = llvm::to_vector(
172 llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
173 auto deleteClones = llvm::make_scope_exit([&] {
174 for (Operation *clone : clones)
175 clone->erase();
176 });
177 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
179
180 bool failed = false;
181 for (Operation &transform : reg.front().without_terminator()) {
183 state.applyTransform(cast<TransformOpInterface>(transform));
184 if (result.isSilenceableFailure()) {
185 LDBG() << "alternative failed: " << result.getMessage();
186 failed = true;
187 break;
188 }
189
190 if (::mlir::failed(result.silence()))
192 }
193
194 // If all operations in the given alternative succeeded, no need to consider
195 // the rest. Replace the original scoping operation with the clone on which
196 // the transformations were performed.
197 if (!failed) {
198 // We will be using the clones, so cancel their scheduled deletion.
199 deleteClones.release();
200 TrackingListener listener(state, *this);
201 IRRewriter rewriter(getContext(), &listener);
202 for (const auto &kvp : llvm::zip(originals, clones)) {
203 Operation *original = std::get<0>(kvp);
204 Operation *clone = std::get<1>(kvp);
205 original->getBlock()->getOperations().insert(original->getIterator(),
206 clone);
207 rewriter.replaceOp(original, clone->getResults());
208 }
209 detail::forwardTerminatorOperands(&reg.front(), state, results);
211 }
212 }
213 return emitSilenceableError() << "all alternatives failed";
214}
215
216void transform::AlternativesOp::getEffects(
218 consumesHandle(getOperation()->getOpOperands(), effects);
219 producesHandle(getOperation()->getOpResults(), effects);
220 for (Region *region : getRegions()) {
221 if (!region->empty())
222 producesHandle(region->front().getArguments(), effects);
223 }
224 modifiesPayload(effects);
225}
226
227LogicalResult transform::AlternativesOp::verify() {
228 for (Region &alternative : getAlternatives()) {
229 Block &block = alternative.front();
230 Operation *terminator = block.getTerminator();
231 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
233 << "expects terminator operands to have the "
234 "same type as results of the operation";
235 diag.attachNote(terminator->getLoc()) << "terminator";
236 return diag;
237 }
238 }
239
240 return success();
241}
242
243//===----------------------------------------------------------------------===//
244// AnnotateOp
245//===----------------------------------------------------------------------===//
246
248transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
252 llvm::to_vector(state.getPayloadOps(getTarget()));
253
254 Attribute attr = UnitAttr::get(getContext());
255 if (auto paramH = getParam()) {
256 ArrayRef<Attribute> params = state.getParams(paramH);
257 if (params.size() != 1) {
258 if (targets.size() != params.size()) {
259 return emitSilenceableError()
260 << "parameter and target have different payload lengths ("
261 << params.size() << " vs " << targets.size() << ")";
262 }
263 for (auto &&[target, attr] : llvm::zip_equal(targets, params))
264 target->setAttr(getName(), attr);
266 }
267 attr = params[0];
268 }
269 for (auto *target : targets)
270 target->setAttr(getName(), attr);
272}
273
274void transform::AnnotateOp::getEffects(
276 onlyReadsHandle(getTargetMutable(), effects);
277 onlyReadsHandle(getParamMutable(), effects);
278 modifiesPayload(effects);
279}
280
281//===----------------------------------------------------------------------===//
282// ApplyCommonSubexpressionEliminationOp
283//===----------------------------------------------------------------------===//
284
286transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
288 ApplyToEachResultList &results, transform::TransformState &state) {
289 // Make sure that this transform is not applied to itself. Modifying the
290 // transform IR while it is being interpreted is generally dangerous.
291 DiagnosedSilenceableFailure payloadCheck =
293 if (!payloadCheck.succeeded())
294 return payloadCheck;
295
296 DominanceInfo domInfo;
299}
300
301void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
303 transform::onlyReadsHandle(getTargetMutable(), effects);
305}
306
307//===----------------------------------------------------------------------===//
308// ApplyDeadCodeEliminationOp
309//===----------------------------------------------------------------------===//
310
311DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
313 ApplyToEachResultList &results, transform::TransformState &state) {
314 // Make sure that this transform is not applied to itself. Modifying the
315 // transform IR while it is being interpreted is generally dangerous.
316 DiagnosedSilenceableFailure payloadCheck =
318 if (!payloadCheck.succeeded())
319 return payloadCheck;
320
321 // Maintain a worklist of potentially dead ops.
322 SetVector<Operation *> worklist;
323
324 // Helper function that adds all defining ops of used values (operands and
325 // operands of nested ops).
326 auto addDefiningOpsToWorklist = [&](Operation *op) {
327 op->walk([&](Operation *op) {
328 for (Value v : op->getOperands())
329 if (Operation *defOp = v.getDefiningOp())
330 if (target->isProperAncestor(defOp))
331 worklist.insert(defOp);
332 });
333 };
334
335 // Helper function that erases an op.
336 auto eraseOp = [&](Operation *op) {
337 // Remove op and nested ops from the worklist.
338 op->walk([&](Operation *op) {
339 const auto *it = llvm::find(worklist, op);
340 if (it != worklist.end())
341 worklist.erase(it);
342 });
343 rewriter.eraseOp(op);
344 };
345
346 // Initial walk over the IR.
347 target->walk<WalkOrder::PostOrder>([&](Operation *op) {
348 if (op != target && isOpTriviallyDead(op)) {
349 addDefiningOpsToWorklist(op);
350 eraseOp(op);
351 }
352 });
353
354 // Erase all ops that have become dead.
355 while (!worklist.empty()) {
356 Operation *op = worklist.pop_back_val();
357 if (!isOpTriviallyDead(op))
358 continue;
359 addDefiningOpsToWorklist(op);
360 eraseOp(op);
361 }
362
364}
365
366void transform::ApplyDeadCodeEliminationOp::getEffects(
368 transform::onlyReadsHandle(getTargetMutable(), effects);
370}
371
372//===----------------------------------------------------------------------===//
373// ApplyPatternsOp
374//===----------------------------------------------------------------------===//
375
376DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
378 ApplyToEachResultList &results, transform::TransformState &state) {
379 // Make sure that this transform is not applied to itself. Modifying the
380 // transform IR while it is being interpreted is generally dangerous. Even
381 // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
382 // performs many additional simplifications such as dead code elimination.
383 DiagnosedSilenceableFailure payloadCheck =
385 if (!payloadCheck.succeeded())
386 return payloadCheck;
387
388 // Gather all specified patterns.
389 MLIRContext *ctx = target->getContext();
391 if (!getRegion().empty()) {
392 for (Operation &op : getRegion().front()) {
393 cast<transform::PatternDescriptorOpInterface>(&op)
394 .populatePatternsWithState(patterns, state);
395 }
396 }
397
398 // Configure the GreedyPatternRewriteDriver.
400 config.setListener(
401 static_cast<RewriterBase::Listener *>(rewriter.getListener()));
402 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
403
404 config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
406 : getMaxIterations());
407 config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
409 : getMaxNumRewrites());
410
411 // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
412 // was requested, apply the greedy pattern rewrite only once. (The greedy
413 // pattern rewrite driver already iterates to a fixpoint internally.)
414 bool cseChanged = false;
415 // One or two iterations should be sufficient. Stop iterating after a certain
416 // threshold to make debugging easier.
417 static const int64_t kNumMaxIterations = 50;
418 int64_t iteration = 0;
419 do {
420 LogicalResult result = failure();
421 if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
422 // Op is isolated from above. Apply patterns and also perform region
423 // simplification.
424 result = applyPatternsGreedily(target, frozenPatterns, config);
425 } else {
426 // Manually gather list of ops because the other
427 // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
428 // from above. This way, patterns can be applied to ops that are not
429 // isolated from above. Regions are not being simplified. Furthermore,
430 // only a single greedy rewrite iteration is performed.
432 target->walk([&](Operation *nestedOp) {
433 if (target != nestedOp)
434 ops.push_back(nestedOp);
435 });
436 result = applyOpPatternsGreedily(ops, frozenPatterns, config);
437 }
438
439 // A failure typically indicates that the pattern application did not
440 // converge.
441 if (failed(result)) {
443 << "greedy pattern application failed";
444 }
445
446 if (getApplyCse()) {
447 DominanceInfo domInfo;
449 &cseChanged);
450 }
451 } while (cseChanged && ++iteration < kNumMaxIterations);
452
453 if (iteration == kNumMaxIterations)
454 return emitDefiniteFailure() << "fixpoint iteration did not converge";
455
457}
458
459LogicalResult transform::ApplyPatternsOp::verify() {
460 if (!getRegion().empty()) {
461 for (Operation &op : getRegion().front()) {
462 if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
464 << "expected children ops to implement "
465 "PatternDescriptorOpInterface";
466 diag.attachNote(op.getLoc()) << "op without interface";
467 return diag;
468 }
469 }
470 }
471 return success();
472}
473
474void transform::ApplyPatternsOp::getEffects(
476 transform::onlyReadsHandle(getTargetMutable(), effects);
478}
479
480void transform::ApplyPatternsOp::build(
482 function_ref<void(OpBuilder &, Location)> bodyBuilder) {
483 result.addOperands(target);
484
485 OpBuilder::InsertionGuard g(builder);
486 Region *region = result.addRegion();
487 builder.createBlock(region);
488 if (bodyBuilder)
489 bodyBuilder(builder, result.location);
490}
491
492//===----------------------------------------------------------------------===//
493// ApplyCanonicalizationPatternsOp
494//===----------------------------------------------------------------------===//
495
496void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
498 MLIRContext *ctx = patterns.getContext();
499 for (Dialect *dialect : ctx->getLoadedDialects())
500 dialect->getCanonicalizationPatterns(patterns);
502 op.getCanonicalizationPatterns(patterns, ctx);
503}
504
505//===----------------------------------------------------------------------===//
506// ApplyConversionPatternsOp
507//===----------------------------------------------------------------------===//
508
509DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
512 MLIRContext *ctx = getContext();
513
514 // Instantiate the default type converter if a type converter builder is
515 // specified.
516 std::unique_ptr<TypeConverter> defaultTypeConverter;
517 transform::TypeConverterBuilderOpInterface typeConverterBuilder =
518 getDefaultTypeConverter();
519 if (typeConverterBuilder)
520 defaultTypeConverter = typeConverterBuilder.getTypeConverter();
521
522 // Configure conversion target.
523 ConversionTarget conversionTarget(*getContext());
524 if (getLegalOps())
525 for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
526 conversionTarget.addLegalOp(
527 OperationName(cast<StringAttr>(attr).getValue(), ctx));
528 if (getIllegalOps())
529 for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
530 conversionTarget.addIllegalOp(
531 OperationName(cast<StringAttr>(attr).getValue(), ctx));
532 if (getLegalDialects())
533 for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
534 conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
535 if (getIllegalDialects())
536 for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
537 conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
538
539 // Gather all specified patterns.
541 // Need to keep the converters alive until after pattern application because
542 // the patterns take a reference to an object that would otherwise get out of
543 // scope.
545 if (!getPatterns().empty()) {
546 for (Operation &op : getPatterns().front()) {
547 auto descriptor =
548 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
549
550 // Check if this pattern set specifies a type converter.
551 std::unique_ptr<TypeConverter> typeConverter =
552 descriptor.getTypeConverter();
553 TypeConverter *converter = nullptr;
554 if (typeConverter) {
555 keepAliveConverters.emplace_back(std::move(typeConverter));
556 converter = keepAliveConverters.back().get();
557 } else {
558 // No type converter specified: Use the default type converter.
559 if (!defaultTypeConverter) {
561 << "pattern descriptor does not specify type "
562 "converter and apply_conversion_patterns op has "
563 "no default type converter";
564 diag.attachNote(op.getLoc()) << "pattern descriptor op";
565 return diag;
566 }
567 converter = defaultTypeConverter.get();
568 }
569
570 // Add descriptor-specific updates to the conversion target, which may
571 // depend on the final type converter. In structural converters, the
572 // legality of types dictates the dynamic legality of an operation.
573 descriptor.populateConversionTargetRules(*converter, conversionTarget);
574
575 descriptor.populatePatterns(*converter, patterns);
576 }
577 }
578
579 // Attach a tracking listener if handles should be preserved. We configure the
580 // listener to allow op replacements with different names, as conversion
581 // patterns typically replace ops with replacement ops that have a different
582 // name.
583 TrackingListenerConfig trackingConfig;
584 trackingConfig.requireMatchingReplacementOpName = false;
585 ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
586 ConversionConfig conversionConfig;
587 if (getPreserveHandles())
588 conversionConfig.listener = &trackingListener;
589
590 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
591 for (Operation *target : state.getPayloadOps(getTarget())) {
592 // Make sure that this transform is not applied to itself. Modifying the
593 // transform IR while it is being interpreted is generally dangerous.
594 DiagnosedSilenceableFailure payloadCheck =
596 if (!payloadCheck.succeeded())
597 return payloadCheck;
598
599 LogicalResult status = failure();
600 if (getPartialConversion()) {
601 status = applyPartialConversion(target, conversionTarget, frozenPatterns,
602 conversionConfig);
603 } else {
604 status = applyFullConversion(target, conversionTarget, frozenPatterns,
605 conversionConfig);
606 }
607
608 // Check dialect conversion state.
610 if (failed(status)) {
611 diag = emitSilenceableError() << "dialect conversion failed";
612 diag.attachNote(target->getLoc()) << "target op";
613 }
614
615 // Check tracking listener error state.
616 DiagnosedSilenceableFailure trackingFailure =
617 trackingListener.checkAndResetError();
618 if (!trackingFailure.succeeded()) {
619 if (diag.succeeded()) {
620 // Tracking failure is the only failure.
621 return trackingFailure;
622 }
623 diag.attachNote() << "tracking listener also failed: "
624 << trackingFailure.getMessage();
625 (void)trackingFailure.silence();
626 }
627
628 if (!diag.succeeded())
629 return diag;
630 }
631
633}
634
635LogicalResult transform::ApplyConversionPatternsOp::verify() {
636 if (getNumRegions() != 1 && getNumRegions() != 2)
637 return emitOpError() << "expected 1 or 2 regions";
638 if (!getPatterns().empty()) {
639 for (Operation &op : getPatterns().front()) {
640 if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
642 emitOpError() << "expected pattern children ops to implement "
643 "ConversionPatternDescriptorOpInterface";
644 diag.attachNote(op.getLoc()) << "op without interface";
645 return diag;
646 }
647 }
648 }
649 if (getNumRegions() == 2) {
650 Region &typeConverterRegion = getRegion(1);
651 if (!llvm::hasSingleElement(typeConverterRegion.front()))
652 return emitOpError()
653 << "expected exactly one op in default type converter region";
654 Operation *maybeTypeConverter = &typeConverterRegion.front().front();
655 auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
656 maybeTypeConverter);
657 if (!typeConverterOp) {
659 << "expected default converter child op to "
660 "implement TypeConverterBuilderOpInterface";
661 diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
662 return diag;
663 }
664 // Check default type converter type.
665 if (!getPatterns().empty()) {
666 for (Operation &op : getPatterns().front()) {
667 auto descriptor =
668 cast<transform::ConversionPatternDescriptorOpInterface>(&op);
669 if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
670 return failure();
671 }
672 }
673 }
674 return success();
675}
676
677void transform::ApplyConversionPatternsOp::getEffects(
679 if (!getPreserveHandles()) {
680 transform::consumesHandle(getTargetMutable(), effects);
681 } else {
682 transform::onlyReadsHandle(getTargetMutable(), effects);
683 }
685}
686
687void transform::ApplyConversionPatternsOp::build(
689 function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
690 function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
691 result.addOperands(target);
692
693 {
694 OpBuilder::InsertionGuard g(builder);
695 Region *region1 = result.addRegion();
696 builder.createBlock(region1);
697 if (patternsBodyBuilder)
698 patternsBodyBuilder(builder, result.location);
699 }
700 {
701 OpBuilder::InsertionGuard g(builder);
702 Region *region2 = result.addRegion();
703 builder.createBlock(region2);
704 if (typeConverterBodyBuilder)
705 typeConverterBodyBuilder(builder, result.location);
706 }
707}
708
709//===----------------------------------------------------------------------===//
710// ApplyToLLVMConversionPatternsOp
711//===----------------------------------------------------------------------===//
712
713void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
714 TypeConverter &typeConverter, RewritePatternSet &patterns) {
715 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
716 assert(dialect && "expected that dialect is loaded");
717 auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
718 // ConversionTarget is currently ignored because the enclosing
719 // apply_conversion_patterns op sets up its own ConversionTarget.
721 iface->populateConvertToLLVMConversionPatterns(
722 target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
723}
724
725LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
726 transform::TypeConverterBuilderOpInterface builder) {
727 if (builder.getTypeConverterType() != "LLVMTypeConverter")
728 return emitOpError("expected LLVMTypeConverter");
729 return success();
730}
731
732LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
733 Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
734 if (!dialect)
735 return emitOpError("unknown dialect or dialect not loaded: ")
736 << getDialectName();
737 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
738 if (!iface)
739 return emitOpError(
740 "dialect does not implement ConvertToLLVMPatternInterface or "
741 "extension was not loaded: ")
742 << getDialectName();
743 return success();
744}
745
746//===----------------------------------------------------------------------===//
747// ApplyLoopInvariantCodeMotionOp
748//===----------------------------------------------------------------------===//
749
751transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
752 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
755 // Currently, LICM does not remove operations, so we don't need tracking.
756 // If this ever changes, add a LICM entry point that takes a rewriter.
759}
760
761void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
763 transform::onlyReadsHandle(getTargetMutable(), effects);
765}
766
767//===----------------------------------------------------------------------===//
768// ApplyRegisteredPassOp
769//===----------------------------------------------------------------------===//
770
771void transform::ApplyRegisteredPassOp::getEffects(
773 consumesHandle(getTargetMutable(), effects);
774 onlyReadsHandle(getDynamicOptionsMutable(), effects);
775 producesHandle(getOperation()->getOpResults(), effects);
776 modifiesPayload(effects);
777}
778
780transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
783 // Obtain a single options-string to pass to the pass(-pipeline) from options
784 // passed in as a dictionary of keys mapping to values which are either
785 // attributes or param-operands pointing to attributes.
786 OperandRange dynamicOptions = getDynamicOptions();
787
788 std::string options;
789 llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
790
791 // A helper to convert an option's attribute value into a corresponding
792 // string representation, with the ability to obtain the attr(s) from a param.
793 std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
794 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
795 // The corresponding value attribute(s) is/are passed in via a param.
796 // Obtain the param-operand via its specified index.
797 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
798 assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
799 "the number of ParamOperandAttrs in the options DictionaryAttr"
800 "should be the same as the number of options passed as params");
801 ArrayRef<Attribute> attrsAssociatedToParam =
802 state.getParams(dynamicOptions[dynamicOptionIdx]);
803 // Recursive so as to append all attrs associated to the param.
804 llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
805 ",");
806 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
807 // Recursive so as to append all nested attrs of the array.
808 llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
809 } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
810 // Convert to unquoted string.
811 optionsStream << strAttr.getValue().str();
812 } else {
813 // For all other attributes, ask the attr to print itself (without type).
814 valueAttr.print(optionsStream, /*elideType=*/true);
815 }
816 };
817
818 // Convert the options DictionaryAttr into a single string.
819 llvm::interleave(
820 getOptions(), optionsStream,
821 [&](auto namedAttribute) {
822 optionsStream << namedAttribute.getName().str(); // Append the key.
823 optionsStream << "="; // And the key-value separator.
824 appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
825 },
826 " ");
827 optionsStream.flush();
828
829 // Get pass or pass pipeline from registry.
830 const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
831 if (!info)
832 info = PassInfo::lookup(getPassName());
833 if (!info)
834 return emitDefiniteFailure()
835 << "unknown pass or pass pipeline: " << getPassName();
836
837 // Create pass manager and add the pass or pass pipeline.
839 if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
840 emitError(msg);
841 return failure();
842 }))) {
843 return emitDefiniteFailure()
844 << "failed to add pass or pass pipeline to pipeline: "
845 << getPassName();
846 }
847
848 auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
849 for (Operation *target : targets) {
850 // Make sure that this transform is not applied to itself. Modifying the
851 // transform IR while it is being interpreted is generally dangerous. Even
852 // more so when applying passes because they may perform a wide range of IR
853 // modifications.
854 DiagnosedSilenceableFailure payloadCheck =
856 if (!payloadCheck.succeeded())
857 return payloadCheck;
858
859 // Run the pass or pass pipeline on the current target operation.
860 if (failed(pm.run(target))) {
861 auto diag = emitSilenceableError() << "pass pipeline failed";
862 diag.attachNote(target->getLoc()) << "target op";
863 return diag;
864 }
865 }
866
867 // The applied pass will have directly modified the payload IR(s).
868 results.set(llvm::cast<OpResult>(getResult()), targets);
870}
871
873 OpAsmParser &parser, DictionaryAttr &options,
875 // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
876 SmallVector<NamedAttribute> keyValuePairs;
877 size_t dynamicOptionsIdx = 0;
878
879 // Helper for allowing parsing of option values which can be of the form:
880 // - a normal attribute
881 // - an operand (which would be converted to an attr referring to the operand)
882 // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
883 std::function<ParseResult(Attribute &)> parseValue =
884 [&](Attribute &valueAttr) -> ParseResult {
885 // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
886 if (succeeded(parser.parseOptionalLSquare())) {
888
889 // Recursively parse the array's elements, which might be operands.
890 if (parser.parseCommaSeparatedList(
892 [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
893 " in options dictionary") ||
894 parser.parseRSquare())
895 return failure(); // NB: Attempted parse should've output error message.
896
897 valueAttr = ArrayAttr::get(parser.getContext(), attrs);
898
899 return success();
900 }
901
902 // Parse the value, which can be either an attribute or an operand.
903 OptionalParseResult parsedValueAttr =
904 parser.parseOptionalAttribute(valueAttr);
905 if (!parsedValueAttr.has_value()) {
907 ParseResult parsedOperand = parser.parseOperand(operand);
908 if (failed(parsedOperand))
909 return failure(); // NB: Attempted parse should've output error message.
910 // To make use of the operand, we need to store it in the options dict.
911 // As SSA-values cannot occur in attributes, what we do instead is store
912 // an attribute in its place that contains the index of the param-operand,
913 // so that an attr-value associated to the param can be resolved later on.
914 dynamicOptions.push_back(operand);
915 auto wrappedIndex = IntegerAttr::get(
916 IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
917 valueAttr =
918 transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
919 } else if (failed(parsedValueAttr.value())) {
920 return failure(); // NB: Attempted parse should have output error message.
921 } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
922 return parser.emitError(parser.getCurrentLocation())
923 << "the param_operand attribute is a marker reserved for "
924 << "indicating a value will be passed via params and is only used "
925 << "in the generic print format";
926 }
927
928 return success();
929 };
930
931 // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
932 // string and `value` looks like either an attribute or an operand-in-an-attr.
933 std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
934 std::string key;
935 Attribute valueAttr;
936
937 if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
938 return parser.emitError(parser.getCurrentLocation())
939 << "expected key to either be an identifier or a string";
940
941 if (failed(parser.parseEqual()))
942 return parser.emitError(parser.getCurrentLocation())
943 << "expected '=' after key in key-value pair";
944
945 if (failed(parseValue(valueAttr)))
946 return parser.emitError(parser.getCurrentLocation())
947 << "expected a valid attribute or operand as value associated "
948 << "to key '" << key << "'";
949
950 keyValuePairs.push_back(NamedAttribute(key, valueAttr));
951
952 return success();
953 };
954
957 " in options dictionary"))
958 return failure(); // NB: Attempted parse should have output error message.
959
960 if (DictionaryAttr::findDuplicate(
961 keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
962 .has_value())
963 return parser.emitError(parser.getCurrentLocation())
964 << "duplicate keys found in options dictionary";
965
966 options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
967
968 return success();
969}
970
972 Operation *op,
973 DictionaryAttr options,
974 ValueRange dynamicOptions) {
975 if (options.empty())
976 return;
977
978 std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
979 if (auto paramOperandAttr =
980 dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
981 // Resolve index of param-operand to its actual SSA-value and print that.
982 printer.printOperand(
983 dynamicOptions[paramOperandAttr.getIndex().getInt()]);
984 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
985 // This case is so that ArrayAttr-contained operands are pretty-printed.
986 printer << "[";
987 llvm::interleaveComma(arrayAttr, printer, printOptionValue);
988 printer << "]";
989 } else {
990 printer.printAttribute(valueAttr);
991 }
992 };
993
994 printer << "{";
995 llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
996 printer << namedAttribute.getName();
997 printer << " = ";
998 printOptionValue(namedAttribute.getValue());
999 });
1000 printer << "}";
1001}
1002
1003LogicalResult transform::ApplyRegisteredPassOp::verify() {
1004 // Check that there is a one-to-one correspondence between param operands
1005 // and references to dynamic options in the options dictionary.
1006
1007 auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1008
1009 // Helper for option values to mark seen operands as having been seen (once).
1010 std::function<LogicalResult(Attribute)> checkOptionValue =
1011 [&](Attribute valueAttr) -> LogicalResult {
1012 if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1013 int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1014 if (dynamicOptionIdx < 0 ||
1015 dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1016 return emitOpError()
1017 << "dynamic option index " << dynamicOptionIdx
1018 << " is out of bounds for the number of dynamic options: "
1019 << dynamicOptions.size();
1020 if (dynamicOptions[dynamicOptionIdx] == nullptr)
1021 return emitOpError() << "dynamic option index " << dynamicOptionIdx
1022 << " is already used in options";
1023 dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1024 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1025 // Recurse into ArrayAttrs as they may contain references to operands.
1026 for (auto eltAttr : arrayAttr)
1027 if (failed(checkOptionValue(eltAttr)))
1028 return failure();
1029 }
1030 return success();
1031 };
1032
1033 for (NamedAttribute namedAttr : getOptions())
1034 if (failed(checkOptionValue(namedAttr.getValue())))
1035 return failure();
1036
1037 // All dynamicOptions-params seen in the dict will have been set to null.
1038 for (Value dynamicOption : dynamicOptions)
1039 if (dynamicOption)
1040 return emitOpError() << "a param operand does not have a corresponding "
1041 << "param_operand attr in the options dict";
1042
1043 return success();
1044}
1045
1046//===----------------------------------------------------------------------===//
1047// CastOp
1048//===----------------------------------------------------------------------===//
1049
1051transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1052 Operation *target, ApplyToEachResultList &results,
1054 results.push_back(target);
1056}
1057
1058void transform::CastOp::getEffects(
1060 onlyReadsPayload(effects);
1061 onlyReadsHandle(getInputMutable(), effects);
1062 producesHandle(getOperation()->getOpResults(), effects);
1063}
1064
1065bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1066 assert(inputs.size() == 1 && "expected one input");
1067 assert(outputs.size() == 1 && "expected one output");
1068 return llvm::all_of(
1069 std::initializer_list<Type>{inputs.front(), outputs.front()},
1070 llvm::IsaPred<transform::TransformHandleTypeInterface>);
1071}
1072
1073//===----------------------------------------------------------------------===//
1074// CollectMatchingOp
1075//===----------------------------------------------------------------------===//
1076
1077/// Applies matcher operations from the given `block` using
1078/// `blockArgumentMapping` to initialize block arguments. Updates `state`
/// accordingly. If any of the matchers produces a silenceable failure, discards
1080/// it (printing the content to the debug output stream) and returns failure. If
1081/// any of the matchers produces a definite failure, reports it and returns
1082/// failure. If all matchers in the block succeed, populates `mappings` with the
1083/// payload entities associated with the block terminator operands. Note that
1084/// `mappings` will be cleared before that.
1087 ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1090 assert(block.getParent() && "cannot match using a detached block");
1091 auto matchScope = state.make_region_scope(*block.getParent());
1092 if (failed(
1093 state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1095
1096 for (Operation &match : block.without_terminator()) {
1097 if (!isa<transform::MatchOpInterface>(match)) {
1098 return emitDefiniteFailure(match.getLoc())
1099 << "expected operations in the match part to "
1100 "implement MatchOpInterface";
1101 }
1103 state.applyTransform(cast<transform::TransformOpInterface>(match));
1104 if (diag.succeeded())
1105 continue;
1106
1107 return diag;
1108 }
1109
1110 // Remember the values mapped to the terminator operands so we can
1111 // forward them to the action.
1112 ValueRange yieldedValues = block.getTerminator()->getOperands();
1113 // Our contract with the caller is that the mappings will contain only the
1114 // newly mapped values, clear the rest.
1115 mappings.clear();
1116 transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1118}
1119
1120/// Returns `true` if both types implement one of the interfaces provided as
1121/// template parameters.
1122template <typename... Tys>
1123static bool implementSameInterface(Type t1, Type t2) {
1124 return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1125}
1126
1127/// Returns `true` if both types implement one of the transform dialect
1128/// interfaces.
1130 return implementSameInterface<transform::TransformHandleTypeInterface,
1131 transform::TransformParamTypeInterface,
1132 transform::TransformValueHandleTypeInterface>(
1133 t1, t2);
1134}
1135
1136//===----------------------------------------------------------------------===//
1137// CollectMatchingOp
1138//===----------------------------------------------------------------------===//
1139
1141transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1145 getOperation(), getMatcher());
1146 if (matcher.isExternal()) {
1147 return emitDefiniteFailure()
1148 << "unresolved external symbol " << getMatcher();
1149 }
1150
1152 rawResults.resize(getOperation()->getNumResults());
1153 std::optional<DiagnosedSilenceableFailure> maybeFailure;
1154 for (Operation *root : state.getPayloadOps(getRoot())) {
1155 WalkResult walkResult = root->walk([&](Operation *op) {
1156 LDBG(DEBUG_TYPE_MATCHER, 1)
1157 << "matching "
1158 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1159 << " @" << op;
1160
1161 // Try matching.
1163 SmallVector<transform::MappedValue> inputMapping({op});
1165 matcher.getFunctionBody().front(),
1166 ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1167 mappings);
1168 if (diag.isDefiniteFailure())
1169 return WalkResult::interrupt();
1170 if (diag.isSilenceableFailure()) {
1171 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1172 << " failed: " << diag.getMessage();
1173 return WalkResult::advance();
1174 }
1175
1176 // If succeeded, collect results.
1177 for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1178 if (mapping.size() != 1) {
1179 maybeFailure.emplace(emitSilenceableError()
1180 << "result #" << i << ", associated with "
1181 << mapping.size()
1182 << " payload objects, expected 1");
1183 return WalkResult::interrupt();
1184 }
1185 rawResults[i].push_back(mapping[0]);
1186 }
1187 return WalkResult::advance();
1188 });
1189 if (walkResult.wasInterrupted())
1190 return std::move(*maybeFailure);
1191 assert(!maybeFailure && "failure set but the walk was not interrupted");
1192
1193 for (auto &&[opResult, rawResult] :
1194 llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1195 results.setMappedValues(opResult, rawResult);
1196 }
1197 }
1199}
1200
1201void transform::CollectMatchingOp::getEffects(
1203 onlyReadsHandle(getRootMutable(), effects);
1204 producesHandle(getOperation()->getOpResults(), effects);
1205 onlyReadsPayload(effects);
1206}
1207
1208LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1209 SymbolTableCollection &symbolTable) {
1210 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1211 symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1212 if (!matcherSymbol ||
1213 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1214 return emitError() << "unresolved matcher symbol " << getMatcher();
1215
1216 ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1217 if (argumentTypes.size() != 1 ||
1218 !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1219 return emitError()
1220 << "expected the matcher to take one operation handle argument";
1221 }
1222 if (!matcherSymbol.getArgAttr(
1223 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1224 return emitError() << "expected the matcher argument to be marked readonly";
1225 }
1226
1227 ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1228 if (resultTypes.size() != getOperation()->getNumResults()) {
1229 return emitError()
1230 << "expected the matcher to yield as many values as op has results ("
1231 << getOperation()->getNumResults() << "), got "
1232 << resultTypes.size();
1233 }
1234
1235 for (auto &&[i, matcherType, resultType] :
1236 llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1237 if (implementSameTransformInterface(matcherType, resultType))
1238 continue;
1239
1240 return emitError()
1241 << "mismatching type interfaces for matcher result and op result #"
1242 << i;
1243 }
1244
1245 return success();
1246}
1247
1248//===----------------------------------------------------------------------===//
1249// ForeachMatchOp
1250//===----------------------------------------------------------------------===//
1251
1252// This is fine because nothing is actually consumed by this op.
1253bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1254
1256transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1260 matchActionPairs;
1261 matchActionPairs.reserve(getMatchers().size());
1262 SymbolTableCollection symbolTable;
1263 for (auto &&[matcher, action] :
1264 llvm::zip_equal(getMatchers(), getActions())) {
1265 auto matcherSymbol =
1266 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1267 getOperation(), cast<SymbolRefAttr>(matcher));
1268 auto actionSymbol =
1269 symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1270 getOperation(), cast<SymbolRefAttr>(action));
1271 assert(matcherSymbol && actionSymbol &&
1272 "unresolved symbols not caught by the verifier");
1273
1274 if (matcherSymbol.isExternal())
1275 return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1276 if (actionSymbol.isExternal())
1277 return emitDefiniteFailure() << "unresolved external symbol " << action;
1278
1279 matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1280 }
1281
1282 DiagnosedSilenceableFailure overallDiag =
1284
1285 SmallVector<SmallVector<MappedValue>> matchInputMapping;
1286 SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1287 SmallVector<SmallVector<MappedValue>> actionResultMapping;
1288 // Explicitly add the mapping for the first block argument (the op being
1289 // matched).
1290 matchInputMapping.emplace_back();
1292 getForwardedInputs(), state);
1293 SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1294 actionResultMapping.resize(getForwardedOutputs().size());
1295
1296 for (Operation *root : state.getPayloadOps(getRoot())) {
1297 WalkResult walkResult = root->walk([&](Operation *op) {
1298 // If getRestrictRoot is not present, skip over the root op itself so we
1299 // don't invalidate it.
1300 if (!getRestrictRoot() && op == root)
1301 return WalkResult::advance();
1302
1303 LDBG(DEBUG_TYPE_MATCHER, 1)
1304 << "matching "
1305 << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1306 << " @" << op;
1307
1308 firstMatchArgument.clear();
1309 firstMatchArgument.push_back(op);
1310
1311 // Try all the match/action pairs until the first successful match.
1312 for (auto [matcher, action] : matchActionPairs) {
1314 matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1315 state, matchOutputMapping);
1316 if (diag.isDefiniteFailure())
1317 return WalkResult::interrupt();
1318 if (diag.isSilenceableFailure()) {
1319 LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1320 << " failed: " << diag.getMessage();
1321 continue;
1322 }
1323
1324 auto scope = state.make_region_scope(action.getFunctionBody());
1325 if (failed(state.mapBlockArguments(
1326 action.getFunctionBody().front().getArguments(),
1327 matchOutputMapping))) {
1328 return WalkResult::interrupt();
1329 }
1330
1331 for (Operation &transform :
1332 action.getFunctionBody().front().without_terminator()) {
1334 state.applyTransform(cast<TransformOpInterface>(transform));
1335 if (result.isDefiniteFailure())
1336 return WalkResult::interrupt();
1337 if (result.isSilenceableFailure()) {
1338 if (overallDiag.succeeded()) {
1339 overallDiag = emitSilenceableError() << "actions failed";
1340 }
1341 overallDiag.attachNote(action->getLoc())
1342 << "failed action: " << result.getMessage();
1343 overallDiag.attachNote(op->getLoc())
1344 << "when applied to this matching payload";
1345 (void)result.silence();
1346 continue;
1347 }
1348 }
1349 if (failed(detail::appendValueMappings(
1350 MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1351 action.getFunctionBody().front().getTerminator()->getOperands(),
1352 state, getFlattenResults()))) {
1354 << "action @" << action.getName()
1355 << " has results associated with multiple payload entities, "
1356 "but flattening was not requested";
1357 return WalkResult::interrupt();
1358 }
1359 break;
1360 }
1361 return WalkResult::advance();
1362 });
1363 if (walkResult.wasInterrupted())
1365 }
1366
1367 // The root operation should not have been affected, so we can just reassign
1368 // the payload to the result. Note that we need to consume the root handle to
1369 // make sure any handles to operations inside, that could have been affected
1370 // by actions, are invalidated.
1371 results.set(llvm::cast<OpResult>(getUpdated()),
1372 state.getPayloadOps(getRoot()));
1373 for (auto &&[result, mapping] :
1374 llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1375 results.setMappedValues(result, mapping);
1376 }
1377 return overallDiag;
1378}
1379
1380void transform::ForeachMatchOp::getAsmResultNames(
1381 OpAsmSetValueNameFn setNameFn) {
1382 setNameFn(getUpdated(), "updated_root");
1383 for (Value v : getForwardedOutputs()) {
1384 setNameFn(v, "yielded");
1385 }
1386}
1387
1388void transform::ForeachMatchOp::getEffects(
1390 // Bail if invalid.
1391 if (getOperation()->getNumOperands() < 1 ||
1392 getOperation()->getNumResults() < 1) {
1393 return modifiesPayload(effects);
1394 }
1395
1396 consumesHandle(getRootMutable(), effects);
1397 onlyReadsHandle(getForwardedInputsMutable(), effects);
1398 producesHandle(getOperation()->getOpResults(), effects);
1399 modifiesPayload(effects);
1400}
1401
1402/// Parses the comma-separated list of symbol reference pairs of the format
1403/// `@matcher -> @action`.
1404static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1406 ArrayAttr &actions) {
1407 StringAttr matcher;
1408 StringAttr action;
1409 SmallVector<Attribute> matcherList;
1410 SmallVector<Attribute> actionList;
1411 do {
1412 if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1413 parser.parseSymbolName(action)) {
1414 return failure();
1415 }
1416 matcherList.push_back(SymbolRefAttr::get(matcher));
1417 actionList.push_back(SymbolRefAttr::get(action));
1418 } while (parser.parseOptionalComma().succeeded());
1419
1420 matchers = parser.getBuilder().getArrayAttr(matcherList);
1421 actions = parser.getBuilder().getArrayAttr(actionList);
1422 return success();
1423}
1424
1425/// Prints the comma-separated list of symbol reference pairs of the format
1426/// `@matcher -> @action`.
1428 ArrayAttr matchers, ArrayAttr actions) {
1429 printer.increaseIndent();
1430 printer.increaseIndent();
1431 for (auto &&[matcher, action, idx] : llvm::zip_equal(
1432 matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1433 printer.printNewline();
1434 printer << cast<SymbolRefAttr>(matcher) << " -> "
1435 << cast<SymbolRefAttr>(action);
1436 if (idx != matchers.size() - 1)
1437 printer << ", ";
1438 }
1439 printer.decreaseIndent();
1440 printer.decreaseIndent();
1441}
1442
1443LogicalResult transform::ForeachMatchOp::verify() {
1444 if (getMatchers().size() != getActions().size())
1445 return emitOpError() << "expected the same number of matchers and actions";
1446 if (getMatchers().empty())
1447 return emitOpError() << "expected at least one match/action pair";
1448
1450 for (Attribute name : getMatchers()) {
1451 if (matcherNames.insert(name).second)
1452 continue;
1453 emitWarning() << "matcher " << name
1454 << " is used more than once, only the first match will apply";
1455 }
1456
1457 return success();
1458}
1459
1460/// Checks that the attributes of the function-like operation have correct
1461/// consumption effect annotations. If `alsoVerifyInternal`, checks for
1462/// annotations being present even if they can be inferred from the body.
1464verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1465 bool alsoVerifyInternal = false) {
1466 auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1467 llvm::SmallDenseSet<unsigned> consumedArguments;
1468 if (!op.isExternal()) {
1469 transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1470 consumedArguments);
1471 }
1472 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1473 bool isConsumed =
1474 op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1475 nullptr;
1476 bool isReadOnly =
1477 op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1478 nullptr;
1479 if (isConsumed && isReadOnly) {
1480 return transformOp.emitSilenceableError()
1481 << "argument #" << i << " cannot be both readonly and consumed";
1482 }
1483 if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1484 return transformOp.emitSilenceableError()
1485 << "must provide consumed/readonly status for arguments of "
1486 "external or called ops";
1487 }
1488 if (op.isExternal())
1489 continue;
1490
1491 if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1492 return transformOp.emitSilenceableError()
1493 << "argument #" << i
1494 << " is consumed in the body but is not marked as such";
1495 }
1496 if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1497 // Cannot use op.emitWarning() here as it would attempt to verify the op
1498 // before printing, resulting in infinite recursion.
1499 emitWarning(op->getLoc())
1500 << "op argument #" << i
1501 << " is not consumed in the body but is marked as consumed";
1502 }
1503 }
1505}
1506
1507LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1508 SymbolTableCollection &symbolTable) {
1509 assert(getMatchers().size() == getActions().size());
1510 auto consumedAttr =
1511 StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1512 for (auto &&[matcher, action] :
1513 llvm::zip_equal(getMatchers(), getActions())) {
1514 // Presence and typing.
1515 auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1516 symbolTable.lookupNearestSymbolFrom(getOperation(),
1517 cast<SymbolRefAttr>(matcher)));
1518 auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1519 symbolTable.lookupNearestSymbolFrom(getOperation(),
1520 cast<SymbolRefAttr>(action)));
1521 if (!matcherSymbol ||
1522 !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1523 return emitError() << "unresolved matcher symbol " << matcher;
1524 if (!actionSymbol ||
1525 !isa<TransformOpInterface>(actionSymbol.getOperation()))
1526 return emitError() << "unresolved action symbol " << action;
1527
1529 /*emitWarnings=*/false,
1530 /*alsoVerifyInternal=*/true)
1531 .checkAndReport())) {
1532 return failure();
1533 }
1535 /*emitWarnings=*/false,
1536 /*alsoVerifyInternal=*/true)
1537 .checkAndReport())) {
1538 return failure();
1539 }
1540
1541 // Input -> matcher forwarding.
1542 TypeRange operandTypes = getOperandTypes();
1543 TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1544 if (operandTypes.size() != matcherArguments.size()) {
1546 emitError() << "the number of operands (" << operandTypes.size()
1547 << ") doesn't match the number of matcher arguments ("
1548 << matcherArguments.size() << ") for " << matcher;
1549 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1550 return diag;
1551 }
1552 for (auto &&[i, operand, argument] :
1553 llvm::enumerate(operandTypes, matcherArguments)) {
1554 if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1556 emitOpError()
1557 << "does not expect matcher symbol to consume its operand #" << i;
1558 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1559 return diag;
1560 }
1561
1562 if (implementSameTransformInterface(operand, argument))
1563 continue;
1564
1566 emitError()
1567 << "mismatching type interfaces for operand and matcher argument #"
1568 << i << " of matcher " << matcher;
1569 diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1570 return diag;
1571 }
1572
1573 // Matcher -> action forwarding.
1574 TypeRange matcherResults = matcherSymbol.getResultTypes();
1575 TypeRange actionArguments = actionSymbol.getArgumentTypes();
1576 if (matcherResults.size() != actionArguments.size()) {
1577 return emitError() << "mismatching number of matcher results and "
1578 "action arguments between "
1579 << matcher << " (" << matcherResults.size() << ") and "
1580 << action << " (" << actionArguments.size() << ")";
1581 }
1582 for (auto &&[i, matcherType, actionType] :
1583 llvm::enumerate(matcherResults, actionArguments)) {
1584 if (implementSameTransformInterface(matcherType, actionType))
1585 continue;
1586
1587 return emitError() << "mismatching type interfaces for matcher result "
1588 "and action argument #"
1589 << i << "of matcher " << matcher << " and action "
1590 << action;
1591 }
1592
1593 // Action -> result forwarding.
1594 TypeRange actionResults = actionSymbol.getResultTypes();
1595 auto resultTypes = TypeRange(getResultTypes()).drop_front();
1596 if (actionResults.size() != resultTypes.size()) {
1598 emitError() << "the number of action results ("
1599 << actionResults.size() << ") for " << action
1600 << " doesn't match the number of extra op results ("
1601 << resultTypes.size() << ")";
1602 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1603 return diag;
1604 }
1605 for (auto &&[i, resultType, actionType] :
1606 llvm::enumerate(resultTypes, actionResults)) {
1607 if (implementSameTransformInterface(resultType, actionType))
1608 continue;
1609
1611 emitError() << "mismatching type interfaces for action result #" << i
1612 << " of action " << action << " and op result";
1613 diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1614 return diag;
1615 }
1616 }
1617 return success();
1618}
1619
1620//===----------------------------------------------------------------------===//
1621// ForeachOp
1622//===----------------------------------------------------------------------===//
1623
1625transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1628 // We store the payloads before executing the body as ops may be removed from
1629 // the mapping by the TrackingRewriter while iteration is in progress.
1631 detail::prepareValueMappings(payloads, getTargets(), state);
1632 size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1633 bool withZipShortest = getWithZipShortest();
1634
1635 // In case of `zip_shortest`, set the number of iterations to the
1636 // smallest payload in the targets.
1637 if (withZipShortest) {
1638 numIterations =
1639 llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1640 const SmallVector<MappedValue> &b) {
1641 return a.size() < b.size();
1642 })->size();
1643
1644 for (auto &payload : payloads)
1645 payload.resize(numIterations);
1646 }
1647
1648 // As we will be "zipping" over them, check all payloads have the same size.
1649 // `zip_shortest` adjusts all payloads to the same size, so skip this check
1650 // when true.
1651 for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1652 argIdx++) {
1653 if (payloads[argIdx].size() != numIterations) {
1654 return emitSilenceableError()
1655 << "prior targets' payload size (" << numIterations
1656 << ") differs from payload size (" << payloads[argIdx].size()
1657 << ") of target " << getTargets()[argIdx];
1658 }
1659 }
1660
1661 // Start iterating, indexing into payloads to obtain the right arguments to
1662 // call the body with - each slice of payloads at the same argument index
1663 // corresponding to a tuple to use as the body's block arguments.
1664 ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1665 SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1666 for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1667 auto scope = state.make_region_scope(getBody());
1668 // Set up arguments to the region's block.
1669 for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1670 MappedValue argument = payloads[argIdx][iterIdx];
1671 // Note that each blockArg's handle gets associated with just a single
1672 // element from the corresponding target's payload.
1673 if (failed(state.mapBlockArgument(blockArg, {argument})))
1675 }
1676
1677 // Execute loop body.
1678 for (Operation &transform : getBody().front().without_terminator()) {
1680 llvm::cast<transform::TransformOpInterface>(transform));
1681 if (!result.succeeded())
1682 return result;
1683 }
1684
1685 // Append yielded payloads to corresponding results from prior iterations.
1686 OperandRange yieldOperands = getYieldOp().getOperands();
1687 for (auto &&[result, yieldOperand, resTuple] :
1688 llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1689 // NB: each iteration we add any number of ops/vals/params to a result.
1690 if (isa<TransformHandleTypeInterface>(result.getType()))
1691 llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1692 else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1693 llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1694 else if (isa<TransformParamTypeInterface>(result.getType()))
1695 llvm::append_range(resTuple, state.getParams(yieldOperand));
1696 else
1697 assert(false && "unhandled handle type");
1698 }
1699
1700 // Associate the accumulated result payloads to the op's actual results.
1701 for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1702 results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1703
1705}
1706
1707void transform::ForeachOp::getEffects(
1709 // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1710 // arity errors, this method might get called before/in absence of `verify()`.
1711 for (auto &&[target, blockArg] :
1712 llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1713 BlockArgument blockArgument = blockArg;
1714 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1715 return isHandleConsumed(blockArgument,
1716 cast<TransformOpInterface>(&op));
1717 })) {
1718 consumesHandle(target, effects);
1719 } else {
1720 onlyReadsHandle(target, effects);
1721 }
1722 }
1723
1724 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1725 return doesModifyPayload(cast<TransformOpInterface>(&op));
1726 })) {
1727 modifiesPayload(effects);
1728 } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1729 return doesReadPayload(cast<TransformOpInterface>(&op));
1730 })) {
1731 onlyReadsPayload(effects);
1732 }
1733
1734 producesHandle(getOperation()->getOpResults(), effects);
1735}
1736
1737void transform::ForeachOp::getSuccessorRegions(
1739 Region *bodyRegion = &getBody();
1740 if (point.isParent()) {
1741 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1742 return;
1743 }
1744
1745 // Branch back to the region or the parent.
1747 &getBody() &&
1748 "unexpected region index");
1749 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1750 regions.emplace_back(getOperation(), getOperation()->getResults());
1751}
1752
1754transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
1755 // Each block argument handle is mapped to a subset (one op to be precise)
1756 // of the payload of the corresponding `targets` operand of ForeachOp.
1757 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
1758 return getOperation()->getOperands();
1759}
1760
1761transform::YieldOp transform::ForeachOp::getYieldOp() {
1762 return cast<transform::YieldOp>(getBody().front().getTerminator());
1763}
1764
1765LogicalResult transform::ForeachOp::verify() {
1766 for (auto [targetOpt, bodyArgOpt] :
1767 llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1768 if (!targetOpt || !bodyArgOpt)
1769 return emitOpError() << "expects the same number of targets as the body "
1770 "has block arguments";
1771 if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1772 return emitOpError(
1773 "expects co-indexed targets and the body's "
1774 "block arguments to have the same op/value/param type");
1775 }
1776
1777 for (auto [resultOpt, yieldOperandOpt] :
1778 llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1779 if (!resultOpt || !yieldOperandOpt)
1780 return emitOpError() << "expects the same number of results as the "
1781 "yield terminator has operands";
1782 if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1783 return emitOpError("expects co-indexed results and yield "
1784 "operands to have the same op/value/param type");
1785 }
1786
1787 return success();
1788}
1789
1790//===----------------------------------------------------------------------===//
1791// GetParentOp
1792//===----------------------------------------------------------------------===//
1793
/// For each payload op of the target handle, climbs getNthParent() times up the
/// ancestor chain. At each step it skips ancestors that fail the optional
/// getIsolatedFromAbove() / getOpName() filters. If the chain runs out, then:
///  - when getAllowEmptyResults() is set, the result is set to the parents
///    collected so far;
///  - otherwise a silenceable error is emitted, with a note on the target op.
/// When getDeduplicate() is set, each parent op appears at most once in the
/// result.
1795transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1799 DenseSet<Operation *> resultSet;
1800 for (Operation *target : state.getPayloadOps(getTarget())) {
1801 Operation *parent = target;
1802 for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1803 parent = parent->getParentOp();
 // Keep climbing until an ancestor satisfies every requested filter.
1804 while (parent) {
1805 bool checkIsolatedFromAbove =
1806 !getIsolatedFromAbove() ||
1808 bool checkOpName = !getOpName().has_value() ||
1809 parent->getName().getStringRef() == *getOpName();
1810 if (checkIsolatedFromAbove && checkOpName)
1811 break;
1812 parent = parent->getParentOp();
1813 }
1814 if (!parent) {
1815 if (getAllowEmptyResults()) {
1816 results.set(llvm::cast<OpResult>(getResult()), parents);
1818 }
1820 emitSilenceableError()
1821 << "could not find a parent op that matches all requirements";
1822 diag.attachNote(target->getLoc()) << "target op";
1823 return diag;
1824 }
1825 }
 // resultSet tracks parents already emitted so duplicates are dropped.
1826 if (getDeduplicate()) {
1827 if (resultSet.insert(parent).second)
1828 parents.push_back(parent);
1829 } else {
1830 parents.push_back(parent);
1831 }
1832 }
1833 results.set(llvm::cast<OpResult>(getResult()), parents);
1835}
1836
1837//===----------------------------------------------------------------------===//
1838// GetConsumersOfResult
1839//===----------------------------------------------------------------------===//
1840
/// Maps the result handle to every user of result #getResultNumber() of the
/// single payload op associated with the target handle.
///  - An empty target handle yields an empty result.
///  - More than one payload op, or an out-of-range result number, is a definite
///    failure.
1842transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1845 int64_t resultNumber = getResultNumber();
1846 auto payloadOps = state.getPayloadOps(getTarget());
1847 if (std::empty(payloadOps)) {
1848 results.set(cast<OpResult>(getResult()), {});
1850 }
1851 if (!llvm::hasSingleElement(payloadOps))
1852 return emitDefiniteFailure()
1853 << "handle must be mapped to exactly one payload op";
1854
1855 Operation *target = *payloadOps.begin();
 // NOTE(review): only the upper bound is checked; a negative resultNumber
 // would reach getResult(). Confirm the attribute is constrained to be >= 0.
1856 if (target->getNumResults() <= resultNumber)
1857 return emitDefiniteFailure() << "result number overflow";
1858 results.set(llvm::cast<OpResult>(getResult()),
1859 llvm::to_vector(target->getResult(resultNumber).getUsers()));
1861}
1862
1863//===----------------------------------------------------------------------===//
1864// GetDefiningOp
1865//===----------------------------------------------------------------------===//
1866
/// Maps the result handle to the defining op of each payload value of the
/// target handle, in order. A block-argument value has no defining op, so it
/// yields a silenceable error with a note at the value's location.
1868transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1871 SmallVector<Operation *> definingOps;
1872 for (Value v : state.getPayloadValues(getTarget())) {
1873 if (llvm::isa<BlockArgument>(v)) {
1875 emitSilenceableError() << "cannot get defining op of block argument";
1876 diag.attachNote(v.getLoc()) << "target value";
1877 return diag;
1878 }
1879 definingOps.push_back(v.getDefiningOp());
1880 }
1881 results.set(llvm::cast<OpResult>(getResult()), definingOps);
1883}
1884
1885//===----------------------------------------------------------------------===//
1886// GetProducerOfOperand
1887//===----------------------------------------------------------------------===//
1888
/// For each payload op of the target handle, collects the op defining operand
/// #getOperandNumber(). A silenceable error is emitted if the operand index is
/// out of range, or if the operand has no defining op (e.g. it is a block
/// argument).
1890transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1893 int64_t operandNumber = getOperandNumber();
1894 SmallVector<Operation *> producers;
1895 for (Operation *target : state.getPayloadOps(getTarget())) {
 // An out-of-range index is folded into the "no producer" case below.
1896 Operation *producer =
1897 target->getNumOperands() <= operandNumber
1898 ? nullptr
1899 : target->getOperand(operandNumber).getDefiningOp();
1900 if (!producer) {
1902 emitSilenceableError()
1903 << "could not find a producer for operand number: " << operandNumber
1904 << " of " << *target;
1905 diag.attachNote(target->getLoc()) << "target op";
1906 return diag;
1907 }
1908 producers.push_back(producer);
1909 }
1910 results.set(llvm::cast<OpResult>(getResult()), producers);
1912}
1913
1914//===----------------------------------------------------------------------===//
1915// GetOperandOp
1916//===----------------------------------------------------------------------===//
1917
/// For each payload op, expands the position specification (is_all /
/// is_inverted / raw position list) against the op's operand count. The
/// selected operands are appended to the result value handle, in order. A
/// silenceable expansion failure gets a note pointing at the payload op.
1919transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1922 SmallVector<Value> operands;
1923 for (Operation *target : state.getPayloadOps(getTarget())) {
1924 SmallVector<int64_t> operandPositions;
1926 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1927 target->getNumOperands(), operandPositions);
1928 if (diag.isSilenceableFailure()) {
1929 diag.attachNote(target->getLoc())
1930 << "while considering positions of this payload operation";
1931 return diag;
1932 }
1933 llvm::append_range(operands,
1934 llvm::map_range(operandPositions, [&](int64_t pos) {
1935 return target->getOperand(pos);
1936 }));
1937 }
1938 results.setValues(cast<OpResult>(getResult()), operands);
1940}
1941
1942LogicalResult transform::GetOperandOp::verify() {
1943 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1944 getIsInverted(), getIsAll());
1945}
1946
1947//===----------------------------------------------------------------------===//
1948// GetResultOp
1949//===----------------------------------------------------------------------===//
1950
/// For each payload op, expands the position specification (is_all /
/// is_inverted / raw position list) against the op's result count. The
/// selected results are appended to the result value handle, in order. A
/// silenceable expansion failure gets a note pointing at the payload op.
1952transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1955 SmallVector<Value> opResults;
1956 for (Operation *target : state.getPayloadOps(getTarget())) {
1957 SmallVector<int64_t> resultPositions;
1959 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1960 target->getNumResults(), resultPositions);
1961 if (diag.isSilenceableFailure()) {
1962 diag.attachNote(target->getLoc())
1963 << "while considering positions of this payload operation";
1964 return diag;
1965 }
1966 llvm::append_range(opResults,
1967 llvm::map_range(resultPositions, [&](int64_t pos) {
1968 return target->getResult(pos);
1969 }));
1970 }
1971 results.setValues(cast<OpResult>(getResult()), opResults);
1973}
1974
1975LogicalResult transform::GetResultOp::verify() {
1976 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1977 getIsInverted(), getIsAll());
1978}
1979
1980//===----------------------------------------------------------------------===//
1981// GetTypeOp
1982//===----------------------------------------------------------------------===//
1983
/// Reads the value handle, produces the result parameter handle, and only
/// reads (never modifies) the payload.
1984void transform::GetTypeOp::getEffects(
1986 onlyReadsHandle(getValueMutable(), effects);
1987 producesHandle(getOperation()->getOpResults(), effects);
1988 onlyReadsPayload(effects);
1989}
1990
/// Produces one TypeAttr parameter per payload value: the value's type, or its
/// element type when getElemental() is set and the type is a ShapedType.
/// Non-shaped types are returned unchanged even in elemental mode.
1992transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1996 for (Value value : state.getPayloadValues(getValue())) {
1997 Type type = value.getType();
1998 if (getElemental()) {
1999 if (auto shaped = dyn_cast<ShapedType>(type)) {
2000 type = shaped.getElementType();
2001 }
2002 }
2003 params.push_back(TypeAttr::get(type));
2004 }
2005 results.setParams(cast<OpResult>(getResult()), params);
2007}
2008
2009//===----------------------------------------------------------------------===//
2010// IncludeOp
2011//===----------------------------------------------------------------------===//
2012
2013/// Applies the transform ops contained in `block`. Maps `results` to the same
2014/// values as the operands of the block terminator.
/// Failure handling:
///  - A definite failure is returned immediately.
///  - A silenceable failure under FailurePropagationMode::Propagate first maps
///    `results` to empty payloads, then is returned.
///  - Under any other mode, a silenceable failure is silenced and execution
///    continues with the next op.
2016applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2018 transform::TransformResults &results) {
2019 // Apply the sequenced ops one by one.
2020 for (Operation &transform : block.without_terminator()) {
2022 state.applyTransform(cast<transform::TransformOpInterface>(transform));
2023 if (result.isDefiniteFailure())
2024 return result;
2025
2026 if (result.isSilenceableFailure()) {
2027 if (mode == transform::FailurePropagationMode::Propagate) {
2028 // Propagate empty results in case of early exit.
2029 forwardEmptyOperands(&block, state, results);
2030 return result;
2031 }
2032 (void)result.silence();
2033 }
2034 }
2035
2036 // Forward the operation mapping for values yielded from the sequence to the
2037 // values produced by the sequence op.
2038 transform::detail::forwardTerminatorOperands(&block, state, results);
2040}
2041
/// Executes the referenced named sequence:
///  1. Maps this op's operands onto the callee's entry block arguments inside
///     a fresh region scope.
///  2. Runs the callee body with this op's failure propagation mode.
///  3. Forwards the callee terminator's mappings to this op's results.
/// An external (declaration-only) callee is a definite failure.
2043transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2047 getOperation(), getTarget());
2048 assert(callee && "unverified reference to unknown symbol");
2049
2050 if (callee.isExternal())
2051 return emitDefiniteFailure() << "unresolved external named sequence";
2052
2053 // Map operands to block arguments.
2055 detail::prepareValueMappings(mappings, getOperands(), state);
2056 auto scope = state.make_region_scope(callee.getBody());
2057 for (auto &&[arg, map] :
2058 llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2059 if (failed(state.mapBlockArgument(arg, map)))
2061 }
2062
2064 callee.getBody().front(), getFailurePropagationMode(), state, results);
 // Reuse `mappings` to collect what the callee's terminator yields. These are
 // set on this op's results before returning the sequence's status,
 // regardless of what that status is.
2065 mappings.clear();
2066 detail::prepareValueMappings(
2067 mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2068 for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2069 results.setMappedValues(result, mapping);
2070 return result;
2071}
2072
/// Forward declaration: IncludeOp verification reuses the NamedSequenceOp
/// well-formedness check, which is defined further down in this file.
2074verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2075
/// Reports memory effects for an include:
///  - The payload is always treated as modified.
///  - All results are always produced.
///  - Operand handle effects mirror the callee's per-argument consumed /
///    read-only annotations.
///  - If the callee cannot be resolved yet, every operand is conservatively
///    marked as read.
2076void transform::IncludeOp::getEffects(
2078 // Always mark as modifying the payload.
2079 // TODO: a mechanism to annotate effects on payload. Even when all handles are
2080 // only read, the payload may still be modified, so we currently stay on the
2081 // conservative side and always indicate modification. This may prevent some
2082 // code reordering.
2083 modifiesPayload(effects);
2084
2085 // Results are always produced.
2086 producesHandle(getOperation()->getOpResults(), effects);
2087
2088 // Adds default effects to operands and results. This will be added if
2089 // preconditions fail so the trait verifier doesn't complain about missing
2090 // effects and the real precondition failure is reported later on.
 // NOTE(review): the lambda below only touches operands; result effects are
 // already added unconditionally above.
2091 auto defaultEffects = [&] {
2092 onlyReadsHandle(getOperation()->getOpOperands(), effects);
2093 };
2094
2095 // Bail if the callee is unknown. This may run as part of the verification
2096 // process before we verified the validity of the callee or of this op.
2097 auto target =
2098 getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2099 if (!target)
2100 return defaultEffects();
2102 getOperation(), getTarget());
2103 if (!callee)
2104 return defaultEffects();
2105
 // An operand whose callee argument has neither the consumed nor the
 // read-only attribute gets no handle effect from this loop.
2106 for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2107 if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2108 consumesHandle(getOperation()->getOpOperand(i), effects);
2109 else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2110 onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2111 }
2112}
2113
/// Verifies that `target` names a NamedSequenceOp whose signature is
/// compatible with this op:
///  - operand types must equal the callee's input types exactly;
///  - result types only need to implement the same transform dialect
///    interface as the callee's results.
/// The callee itself is then verified as well.
2114LogicalResult
2115transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2116 // Access through indirection and do additional checking because this may be
2117 // running before the main op verifier.
2118 auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2119 if (!targetAttr)
2120 return emitOpError() << "expects a 'target' symbol reference attribute";
2121
2122 auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2123 *this, targetAttr);
2124 if (!target)
2125 return emitOpError() << "does not reference a named transform sequence";
2126
2127 FunctionType fnType = target.getFunctionType();
 // NOTE(review): the two count checks use emitError while the other checks use
 // emitOpError, so their messages lack the op-name prefix. Consider unifying.
2128 if (fnType.getNumInputs() != getNumOperands())
2129 return emitError("incorrect number of operands for callee");
2130
2131 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2132 if (getOperand(i).getType() != fnType.getInput(i)) {
2133 return emitOpError("operand type mismatch: expected operand type ")
2134 << fnType.getInput(i) << ", but provided "
2135 << getOperand(i).getType() << " for operand number " << i;
2136 }
2137 }
2138
2139 if (fnType.getNumResults() != getNumResults())
2140 return emitError("incorrect number of results for callee");
2141
2142 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2143 Type resultType = getResult(i).getType();
2144 Type funcType = fnType.getResult(i);
2145 if (!implementSameTransformInterface(resultType, funcType)) {
2146 return emitOpError() << "type of result #" << i
2147 << " must implement the same transform dialect "
2148 "interface as the corresponding callee result";
2149 }
2150 }
2151
2153 cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2154 /*alsoVerifyInternal=*/true)
2155 .checkAndReport();
2156}
2157
2158//===----------------------------------------------------------------------===//
2159// MatchOperationEmptyOp
2160//===----------------------------------------------------------------------===//
2161
/// Succeeds only when no operation is present (`maybeCurrent` holds no value).
/// Otherwise it emits the silenceable error "operation is not empty". Both
/// outcomes are logged under the matcher debug type.
2162DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2163 ::std::optional<::mlir::Operation *> maybeCurrent,
2165 if (!maybeCurrent.has_value()) {
2166 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
2168 }
2169 LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
2170 return emitSilenceableError() << "operation is not empty";
2171}
2172
2173//===----------------------------------------------------------------------===//
2174// MatchOperationNameOp
2175//===----------------------------------------------------------------------===//
2176
/// Matches if the current op's full name equals any of the names listed in
/// getOpNames(). Otherwise it emits the silenceable error "wrong operation
/// name".
2177DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2178 Operation *current, transform::TransformResults &results,
2180 StringRef currentOpName = current->getName().getStringRef();
2181 for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2182 if (acceptedAttr.getValue() == currentOpName)
2184 }
2185 return emitSilenceableError() << "wrong operation name";
2186}
2187
2188//===----------------------------------------------------------------------===//
2189// MatchParamCmpIOp
2190//===----------------------------------------------------------------------===//
2191
/// Compares each parameter value against the co-indexed reference value using
/// getPredicate(). Ordering predicates use signed APInt comparison.
/// Failure kinds:
///  - Mismatched payload lengths, or a failed comparison: silenceable error.
///  - A non-integer value, or differing integer attribute types within a pair:
///    definite failure.
2193transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
 // Formats an APInt as a signed decimal string for diagnostics.
2196 auto signedAPIntAsString = [&](const APInt &value) {
2197 std::string str;
2198 llvm::raw_string_ostream os(str);
2199 value.print(os, /*isSigned=*/true);
2200 return str;
2201 };
2202
2203 ArrayRef<Attribute> params = state.getParams(getParam());
2204 ArrayRef<Attribute> references = state.getParams(getReference());
2205
2206 if (params.size() != references.size()) {
2207 return emitSilenceableError()
2208 << "parameters have different payload lengths (" << params.size()
2209 << " vs " << references.size() << ")";
2210 }
2211
2212 for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2213 auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2214 auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2215 if (!intAttr || !refAttr) {
2216 return emitDefiniteFailure()
2217 << "non-integer parameter value not expected";
2218 }
2219 if (intAttr.getType() != refAttr.getType()) {
2220 return emitDefiniteFailure()
2221 << "mismatching integer attribute types in parameter #" << i;
2222 }
2223 APInt value = intAttr.getValue();
2224 APInt refValue = refAttr.getValue();
2225
2226 // TODO: this copy will not be necessary in C++20.
 // (Pre-C++20, structured bindings such as `i` cannot be captured by a
 // lambda, hence the named copy used by `reportError`.)
2227 int64_t position = i;
2228 auto reportError = [&](StringRef direction) {
2230 emitSilenceableError() << "expected parameter to be " << direction
2231 << " " << signedAPIntAsString(refValue)
2232 << ", got " << signedAPIntAsString(value);
2233 diag.attachNote(getParam().getLoc())
2234 << "value # " << position
2235 << " associated with the parameter defined here";
2236 return diag;
2237 };
2238
 // Each case breaks out on success and returns a diagnostic on failure.
2239 switch (getPredicate()) {
2240 case MatchCmpIPredicate::eq:
2241 if (value.eq(refValue))
2242 break;
2243 return reportError("equal to");
2244 case MatchCmpIPredicate::ne:
2245 if (value.ne(refValue))
2246 break;
2247 return reportError("not equal to");
2248 case MatchCmpIPredicate::lt:
2249 if (value.slt(refValue))
2250 break;
2251 return reportError("less than");
2252 case MatchCmpIPredicate::le:
2253 if (value.sle(refValue))
2254 break;
2255 return reportError("less than or equal to");
2256 case MatchCmpIPredicate::gt:
2257 if (value.sgt(refValue))
2258 break;
2259 return reportError("greater than");
2260 case MatchCmpIPredicate::ge:
2261 if (value.sge(refValue))
2262 break;
2263 return reportError("greater than or equal to");
2264 }
2265 }
2267}
2268
/// Only reads the compared and reference parameter handles. No payload effect
/// is declared here.
2269void transform::MatchParamCmpIOp::getEffects(
2271 onlyReadsHandle(getParamMutable(), effects);
2272 onlyReadsHandle(getReferenceMutable(), effects);
2273}
2274
2275//===----------------------------------------------------------------------===//
2276// ParamConstantOp
2277//===----------------------------------------------------------------------===//
2278
/// Associates the result parameter handle with the single attribute getValue().
2280transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2283 results.setParams(cast<OpResult>(getParam()), {getValue()});
2285}
2286
2287//===----------------------------------------------------------------------===//
2288// MergeHandlesOp
2289//===----------------------------------------------------------------------===//
2290
/// Concatenates the payloads of all operand handles, in operand order. The
/// payload kind (ops, params, or values) is chosen from the type of the first
/// handle. With getDeduplicate(), a SetVector keeps only the first occurrence
/// of each payload while preserving order.
/// NOTE(review): handles.front() assumes at least one operand; confirm the op
/// definition guarantees a non-empty `handles` list.
2292transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2295 ValueRange handles = getHandles();
2296 if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2297 SmallVector<Operation *> operations;
2298 for (Value operand : handles)
2299 llvm::append_range(operations, state.getPayloadOps(operand));
2300 if (!getDeduplicate()) {
2301 results.set(llvm::cast<OpResult>(getResult()), operations);
2303 }
2304
2305 SetVector<Operation *> uniqued(llvm::from_range, operations);
2306 results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2308 }
2309
2310 if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2312 for (Value attribute : handles)
2313 llvm::append_range(attrs, state.getParams(attribute));
2314 if (!getDeduplicate()) {
2315 results.setParams(cast<OpResult>(getResult()), attrs);
2317 }
2318
2319 SetVector<Attribute> uniqued(llvm::from_range, attrs);
2320 results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2322 }
2323
 // Remaining case: value handles.
2324 assert(
2325 llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2326 "expected value handle type");
2327 SmallVector<Value> payloadValues;
2328 for (Value value : handles)
2329 llvm::append_range(payloadValues, state.getPayloadValues(value));
2330 if (!getDeduplicate()) {
2331 results.setValues(cast<OpResult>(getResult()), payloadValues);
2333 }
2334
2335 SetVector<Value> uniqued(llvm::from_range, payloadValues);
2336 results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2338}
2339
2340bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2341 // Handles may be the same if deduplicating is enabled.
2342 return getDeduplicate();
2343}
2344
/// Reads every operand handle and produces the merged result handle.
2345void transform::MergeHandlesOp::getEffects(
2347 onlyReadsHandle(getHandlesMutable(), effects);
2348 producesHandle(getOperation()->getOpResults(), effects);
2349
2350 // There are no effects on the Payload IR as this is only a handle
2351 // manipulation.
2352}
2353
2354OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2355 if (getDeduplicate() || getHandles().size() != 1)
2356 return {};
2357
2358 // If deduplication is not required and there is only one operand, it can be
2359 // used directly instead of merging.
2360 return getHandles().front();
2361}
2362
2363//===----------------------------------------------------------------------===//
2364// NamedSequenceOp
2365//===----------------------------------------------------------------------===//
2366
/// Runs the named sequence as a top-level entry point: maps the entry block
/// arguments, then applies the body with FailurePropagationMode::Propagate. The
/// first silenceable failure in the body is therefore returned to the caller.
/// An external (declaration-only) sequence is a definite failure.
2368transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2371 if (isExternal())
2372 return emitDefiniteFailure() << "unresolved external named sequence";
2373
2374 // Map the entry block argument to the list of operations.
2375 // Note: this is the same implementation as PossibleTopLevelTransformOp but
2376 // without attaching the interface / trait since that is tailored to a
2377 // dangling top-level op that does not get "called".
2378 auto scope = state.make_region_scope(getBody());
2379 if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
2380 state, this->getOperation(), getBody())))
2382
2383 return applySequenceBlock(getBody().front(),
2384 FailurePropagationMode::Propagate, state, results);
2385}
2386
/// Memory-effect reporting for NamedSequenceOp.
/// NOTE(review): the body is not visible here; confirm which effects (if any)
/// are declared.
2387void transform::NamedSequenceOp::getEffects(
2389
/// Parses the op using the shared function-op syntax, with variadic arguments
/// disallowed. The function type is built from the parsed argument and result
/// types, and argument/result attributes are stored under the op's attribute
/// names.
2390ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
2393 parser, result, /*allowVariadic=*/false,
2394 getFunctionTypeAttrName(result.name),
2395 [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2397 std::string &) { return builder.getFunctionType(inputs, results); },
2398 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2399}
2400
/// Prints the op using the shared function-op syntax (non-variadic),
/// mirroring NamedSequenceOp::parse.
2401void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
2403 printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2404 getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2405 getResAttrsAttrName());
2406}
2407
2408/// Verifies that a symbol function-like transform dialect operation has the
2409/// signature and the terminator that have conforming types, i.e., types
2410/// implementing the same transform dialect type interface. If `allowExternal`
2411/// is set, allow external symbols (declarations) and don't check the terminator
2412/// as it may not exist.
/// Checks, in order:
///  1. No enclosing transform op.
///  2. External ops are accepted only when `allowExternal` is set.
///  3. A non-empty first block.
///  4. A transform.yield terminator.
///  5. Terminator operands match the op's result types.
/// All failures are silenceable, so the caller decides how to report them.
/// NOTE(review): step 5 compares types with `==` (exact equality), which is
/// stricter than the "same type interface" wording above.
2414verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2415 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2418 << "cannot be defined inside another transform op";
2419 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2420 return diag;
2421 }
2422
2423 if (op.isExternal() || op.getFunctionBody().empty()) {
2424 if (allowExternal)
2426
2427 return emitSilenceableFailure(op) << "cannot be external";
2428 }
2429
2430 if (op.getFunctionBody().front().empty())
2431 return emitSilenceableFailure(op) << "expected a non-empty body block";
2432
2433 Operation *terminator = &op.getFunctionBody().front().back();
2434 if (!isa<transform::YieldOp>(terminator)) {
2436 << "expected '"
2437 << transform::YieldOp::getOperationName()
2438 << "' as terminator";
2439 diag.attachNote(terminator->getLoc()) << "terminator";
2440 return diag;
2441 }
2442
2443 if (terminator->getNumOperands() != op.getResultTypes().size()) {
2444 return emitSilenceableFailure(terminator)
2445 << "expected terminator to have as many operands as the parent op "
2446 "has results";
2447 }
2448 for (auto [i, operandType, resultType] : llvm::zip_equal(
2449 llvm::seq<unsigned>(0, terminator->getNumOperands()),
2450 terminator->getOperands().getType(), op.getResultTypes())) {
2451 if (operandType == resultType)
2452 continue;
2453 return emitSilenceableFailure(terminator)
2454 << "the type of the terminator operand #" << i
2455 << " must match the type of the corresponding parent op result ("
2456 << operandType << " vs " << resultType << ")";
2457 }
2458
2460}
2461
2462/// Verification of a NamedSequenceOp. This does not report the error
2463/// immediately, so it can be used to check for op's well-formedness before the
2464/// verifier runs, e.g., during trait verification.
/// Checks, in order:
///  1. The enclosing symbol table carries the with-named-sequence attribute.
///  2. There is no enclosing transform op.
///  3. Consume annotations are valid (the only check for external ops).
///  4. A non-empty body ending in transform.yield.
///  5. The yield operand types exactly equal the function result types.
///  6. Consume annotations are valid.
///  7. verifyYieldingSingleBlockOp passes.
/// NOTE(review): step 7 re-checks the ancestor and terminator conditions
/// already covered by steps 2, 4 and 5.
2466verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2467 if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2468 if (!parent->getAttr(
2469 transform::TransformDialect::kWithNamedSequenceAttrName)) {
2472 << "expects the parent symbol table to have the '"
2473 << transform::TransformDialect::kWithNamedSequenceAttrName
2474 << "' attribute";
2475 diag.attachNote(parent->getLoc()) << "symbol table operation";
2476 return diag;
2477 }
2478 }
2479
2480 if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2483 << "cannot be defined inside another transform op";
2484 diag.attachNote(parent.getLoc()) << "ancestor transform op";
2485 return diag;
2486 }
2487
 // Declarations have no body to inspect; only their annotations are checked.
2488 if (op.isExternal() || op.getBody().empty())
2489 return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2490 emitWarnings);
2491
2492 if (op.getBody().front().empty())
2493 return emitSilenceableFailure(op) << "expected a non-empty body block";
2494
2495 Operation *terminator = &op.getBody().front().back();
2496 if (!isa<transform::YieldOp>(terminator)) {
2498 << "expected '"
2499 << transform::YieldOp::getOperationName()
2500 << "' as terminator";
2501 diag.attachNote(terminator->getLoc()) << "terminator";
2502 return diag;
2503 }
2504
2505 if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2506 return emitSilenceableFailure(terminator)
2507 << "expected terminator to have as many operands as the parent op "
2508 "has results";
2509 }
2510 for (auto [i, operandType, resultType] :
2511 llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2512 terminator->getOperands().getType(),
2513 op.getFunctionType().getResults())) {
2514 if (operandType == resultType)
2515 continue;
2516 return emitSilenceableFailure(terminator)
2517 << "the type of the terminator operand #" << i
2518 << " must match the type of the corresponding parent op result ("
2519 << operandType << " vs " << resultType << ")";
2520 }
2521
2522 auto funcOp = cast<FunctionOpInterface>(*op);
2524 verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2525 if (!diag.succeeded())
2526 return diag;
2527
2528 return verifyYieldingSingleBlockOp(funcOp,
2529 /*allowExternal=*/true);
2530}
2531
2532LogicalResult transform::NamedSequenceOp::verify() {
2533 // Actual verification happens in a separate function for reusability.
2534 return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2535}
2536
2537template <typename FnTy>
2538static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2539 Type bbArgType, TypeRange extraBindingTypes,
2540 FnTy bodyBuilder) {
2541 SmallVector<Type> types;
2542 types.reserve(1 + extraBindingTypes.size());
2543 types.push_back(bbArgType);
2544 llvm::append_range(types, extraBindingTypes);
2545
2546 OpBuilder::InsertionGuard guard(builder);
2547 Region *region = state.regions.back().get();
2548 Block *bodyBlock =
2549 builder.createBlock(region, region->begin(), types,
2550 SmallVector<Location>(types.size(), state.location));
2551
2552 // Populate body.
2553 builder.setInsertionPointToStart(bodyBlock);
2554 if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2555 bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2556 } else {
2557 bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2558 bodyBlock->getArguments().drop_front());
2559 }
2560}
2561
/// Builds a named sequence `symName` with signature (rootType) -> resultTypes.
/// Extra `attrs` are appended, and the body is populated via
/// buildSequenceBody with `rootType` as the sole entry block argument.
/// NOTE(review): `argAttrs` is not used in the visible body; confirm whether
/// argument attributes are meant to be attached here.
2562void transform::NamedSequenceOp::build(OpBuilder &builder,
2563 OperationState &state, StringRef symName,
2564 Type rootType, TypeRange resultTypes,
2565 SequenceBodyBuilderFn bodyBuilder,
2567 ArrayRef<DictionaryAttr> argAttrs) {
2569 builder.getStringAttr(symName));
2570 state.addAttribute(getFunctionTypeAttrName(state.name),
2571 TypeAttr::get(FunctionType::get(builder.getContext(),
2572 rootType, resultTypes)));
2573 state.attributes.append(attrs.begin(), attrs.end());
2574 state.addRegion();
2575
2576 buildSequenceBody(builder, state, rootType,
2577 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2578}
2579
2580//===----------------------------------------------------------------------===//
2581// NumAssociationsOp
2582//===----------------------------------------------------------------------===//
2583
/// Counts the payload entries associated with the handle. Depending on the
/// handle's type these are ops, values, or params. The count is returned as a
/// single i64 integer parameter.
2585transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2588 size_t numAssociations =
2590 .Case([&](TransformHandleTypeInterface opHandle) {
2591 return llvm::range_size(state.getPayloadOps(getHandle()));
2592 })
2593 .Case([&](TransformValueHandleTypeInterface valueHandle) {
2594 return llvm::range_size(state.getPayloadValues(getHandle()));
2595 })
2596 .Case([&](TransformParamTypeInterface param) {
2597 return llvm::range_size(state.getParams(getHandle()));
2598 })
2599 .DefaultUnreachable("unknown kind of transform dialect type");
2600 results.setParams(cast<OpResult>(getNum()),
2601 rewriter.getI64IntegerAttr(numAssociations));
2603}
2604
2605LogicalResult transform::NumAssociationsOp::verify() {
2606 // Verify that the result type accepts an i64 attribute as payload.
2607 auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2608 return resultType
2609 .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2610 .checkAndReport();
2611}
2612
2613//===----------------------------------------------------------------------===//
2614// SelectOp
2615//===----------------------------------------------------------------------===//
2616
/// Filters the target's payload ops, keeping (in order) those whose full
/// operation name equals getOpName().
2618transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2622 auto payloadOps = state.getPayloadOps(getTarget());
2623 for (Operation *op : payloadOps) {
2624 if (op->getName().getStringRef() == getOpName())
2625 result.push_back(op);
2626 }
2627 results.set(cast<OpResult>(getResult()), result);
2629}
2630
2631//===----------------------------------------------------------------------===//
2632// SplitHandleOp
2633//===----------------------------------------------------------------------===//
2634
2635void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2636 Value target, int64_t numResultHandles) {
2637 result.addOperands(target);
2638 result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2639}
2640
/// Distributes the handle's payloads (ops, values, or params) one per result,
/// in order. Any payloads beyond the result count go to the result at index
/// getOverflowResult().
/// Size mismatches are silenceable errors, subject to the
/// fail_on_payload_too_small and pass_through_empty_handle flags.
2642transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2645 int64_t numPayloads =
2647 .Case<TransformHandleTypeInterface>([&](auto x) {
2648 return llvm::range_size(state.getPayloadOps(getHandle()));
2649 })
2650 .Case<TransformValueHandleTypeInterface>([&](auto x) {
2651 return llvm::range_size(state.getPayloadValues(getHandle()));
2652 })
2653 .Case<TransformParamTypeInterface>([&](auto x) {
2654 return llvm::range_size(state.getParams(getHandle()));
2655 })
2656 .DefaultUnreachable("unknown transform dialect type interface");
2657
2658 auto produceNumOpsError = [&]() {
2659 return emitSilenceableError()
2660 << getHandle() << " expected to contain " << this->getNumResults()
2661 << " payloads but it contains " << numPayloads << " payloads";
2662 };
2663
2664 // Fail if there are more payload ops than results and no overflow result was
2665 // specified.
2666 if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2667 return produceNumOpsError();
2668
2669 // Fail if there are more results than payload ops. Unless:
2670 // - "fail_on_payload_too_small" is set to "false", or
2671 // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2672 if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2673 (numPayloads != 0 || !getPassThroughEmptyHandle()))
2674 return produceNumOpsError();
2675
2676 // Distribute payloads.
2677 SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
 // NOTE(review): numPayloads < getNumResults() is reachable here, either with
 // fail_on_payload_too_small=false or via pass_through_empty_handle with 0
 // payloads. In that case `numPayloads - getNumResults()` is negative, and
 // reserve() receives it as a huge unsigned size, which would abort. Guard it,
 // e.g. reserve only when numPayloads > getNumResults().
2678 if (getOverflowResult())
2679 resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2680
 // Materialize the payloads uniformly as MappedValue, whatever the handle kind.
2681 auto container = [&]() {
2682 if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2683 return llvm::map_to_vector(
2684 state.getPayloadOps(getHandle()),
2685 [](Operation *op) -> MappedValue { return op; });
2686 }
2687 if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2688 return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2689 [](Value v) -> MappedValue { return v; });
2690 }
2691 assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2692 "unsupported kind of transform dialect type");
2693 return llvm::map_to_vector(state.getParams(getHandle()),
2694 [](Attribute a) -> MappedValue { return a; });
2695 }();
2696
 // Payload i goes to result i. Past the last result, it goes to the overflow
 // result, whose presence was ensured by the first size check above.
2697 for (auto &&en : llvm::enumerate(container)) {
2698 int64_t resultNum = en.index();
2699 if (resultNum >= getNumResults())
2700 resultNum = *getOverflowResult();
2701 resultHandles[resultNum].push_back(en.value());
2702 }
2703
2704 // Set transform op results.
2705 for (auto &&it : llvm::enumerate(resultHandles))
2706 results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2707 it.value());
2708
2710}
2711
/// Declares the side effects of SplitHandleOp: the input handle is only read,
/// and every result of the op is produced as a fresh handle.
/// NOTE(review): source line 2713 (the parameter list declaring `effects`) is
/// missing from this listing.
2712void transform::SplitHandleOp::getEffects(
2714 onlyReadsHandle(getHandleMutable(), effects);
2715 producesHandle(getOperation()->getOpResults(), effects);
2716 // There are no effects on the Payload IR as this is only a handle
2717 // manipulation.
2718}
2719
2720LogicalResult transform::SplitHandleOp::verify() {
2721 if (getOverflowResult().has_value() &&
2722 !(*getOverflowResult() < getNumResults()))
2723 return emitOpError("overflow_result is not a valid result index");
2724
2725 for (Type resultType : getResultTypes()) {
2726 if (implementSameTransformInterface(getHandle().getType(), resultType))
2727 continue;
2728
2729 return emitOpError("expects result types to implement the same transform "
2730 "interface as the operand type");
2731 }
2732
2733 return success();
2734}
2735
2736//===----------------------------------------------------------------------===//
2737// ReplicateOp
2738//===----------------------------------------------------------------------===//
2739
/// Applies ReplicateOp: for every handle in `handles`, the corresponding
/// `replicated` result is set to that handle's payload (operations or params)
/// repeated once per payload op associated with `pattern`.
/// NOTE(review): source lines 2740, 2742-2743, 2750 (declaration of `payload`),
/// 2759 (declaration of `params`) and 2767 are missing from this listing.
2741transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
// The repetition count is the number of payload ops in the `pattern` handle.
2744 unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2745 for (const auto &en : llvm::enumerate(getHandles())) {
2746 Value handle = en.value();
2747 if (isa<TransformHandleTypeInterface>(handle.getType())) {
// Operation handle: concatenate `numRepetitions` copies of its payload ops.
2748 SmallVector<Operation *> current =
2749 llvm::to_vector(state.getPayloadOps(handle));
2751 payload.reserve(numRepetitions * current.size());
2752 for (unsigned i = 0; i < numRepetitions; ++i)
2753 llvm::append_range(payload, current);
2754 results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2755 } else {
// Anything else must be a param handle: concatenate copies of its attributes.
2756 assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2757 "expected param type");
2758 ArrayRef<Attribute> current = state.getParams(handle);
2760 params.reserve(numRepetitions * current.size());
2761 for (unsigned i = 0; i < numRepetitions; ++i)
2762 llvm::append_range(params, current);
2763 results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2764 params);
2765 }
2766 }
2768}
2769
/// Declares the side effects of ReplicateOp: the `pattern` handle and all
/// `handles` are only read, and every result is produced as a new handle.
/// NOTE(review): source line 2771 (the parameter list declaring `effects`) is
/// missing from this listing.
2770void transform::ReplicateOp::getEffects(
2772 onlyReadsHandle(getPatternMutable(), effects);
2773 onlyReadsHandle(getHandlesMutable(), effects);
2774 producesHandle(getOperation()->getOpResults(), effects);
2775}
2776
2777//===----------------------------------------------------------------------===//
2778// SequenceOp
2779//===----------------------------------------------------------------------===//
2780
/// Applies SequenceOp: maps the body block arguments, then applies the ops of
/// the body block using this op's failure propagation mode.
/// NOTE(review): source lines 2781, 2783-2784 and 2788 (the statement executed
/// when mapping the block arguments fails) are missing from this listing.
2782transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2785 // Map the entry block argument to the list of operations.
// `scope` lives until apply() returns; presumably it bounds the lifetime of
// mappings created for the body region — TODO confirm against TransformState.
2786 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2787 if (failed(mapBlockArguments(state)))
2789
2790 return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2791 results);
2792}
2793
/// Parses the optional operand list of a SequenceOp:
///
///   [ %root [, %extra...] : [(] root-type [, extra-type...] [)] ]
///
/// If no root operand is present, nothing else is parsed and `root` is set to
/// std::nullopt. Otherwise the number of parsed extra binding types must equal
/// the number of extra binding operands, or an error is emitted.
/// NOTE(review): the opening and closing parentheses are parsed independently,
/// so an unbalanced form appears to be accepted as well.
/// NOTE(review): source lines 2797 (the `extraBindings` parameter) and 2799
/// (declaration of `rootOperand`) are missing from this listing.
2794static ParseResult parseSequenceOpOperands(
2795 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2796 Type &rootType,
2798 SmallVectorImpl<Type> &extraBindingTypes) {
2800 OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
// No operand at all: this is the root-less form; nothing more to parse.
2801 if (!hasRoot.has_value()) {
2802 root = std::nullopt;
2803 return success();
2804 }
2805 if (failed(hasRoot.value()))
2806 return failure();
2807 root = rootOperand;
2808
// Extra bindings follow the root, separated by a comma.
2809 if (succeeded(parser.parseOptionalComma())) {
2810 if (failed(parser.parseOperandList(extraBindings)))
2811 return failure();
2812 }
2813 if (failed(parser.parseColon()))
2814 return failure();
2815
2816 // The paren is truly optional.
2817 (void)parser.parseOptionalLParen();
2818
2819 if (failed(parser.parseType(rootType))) {
2820 return failure();
2821 }
2822
// Types for extra bindings are only expected when extra operands were parsed.
2823 if (!extraBindings.empty()) {
2824 if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2825 return failure();
2826 }
2827
2828 if (extraBindingTypes.size() != extraBindings.size()) {
2829 return parser.emitError(parser.getNameLoc(),
2830 "expected types to be provided for all operands");
2831 }
2832
2833 // The paren is truly optional.
2834 (void)parser.parseOptionalRParen();
2835 return success();
2836}
2837
/// Prints the operand list of a SequenceOp in the form accepted by
/// parseSequenceOpOperands: nothing when there is no root, otherwise
/// `%root[, %extra...] : root-type`, with the type list wrapped in parentheses
/// only when extra bindings are present.
/// NOTE(review): source line 2838 (the start of this function's declaration)
/// is missing from this listing.
2839 Value root, Type rootType,
2840 ValueRange extraBindings,
2841 TypeRange extraBindingTypes) {
// Root-less (top-level) form prints no operands.
2842 if (!root)
2843 return;
2844
2845 printer << root;
2846 bool hasExtras = !extraBindings.empty();
2847 if (hasExtras) {
2848 printer << ", ";
2849 printer.printOperands(extraBindings);
2850 }
2851
2852 printer << " : ";
2853 if (hasExtras)
2854 printer << "(";
2855
2856 printer << rootType;
2857 if (hasExtras)
2858 printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2859}
2860
2861/// Returns `true` if the given op operand may be consuming the handle value in
2862/// the Transform IR. That is, if it may have a Free effect on it.
// NOTE(review): source line 2863 (the function signature) is missing from this
// listing.
2864 // Conservatively assume the effect being present in absence of the interface.
2865 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2866 if (!iface)
2867 return true;
2868
// Transform ops describe their own effects; ask whether this operand's value
// is consumed by the owner.
2869 return isHandleConsumed(use.get(), iface);
2870}
2871
/// Checks that `value` has at most one use that may consume it.
///
/// Iterates over the uses of `value`, remembering the first potential consumer.
/// On finding a second one, it emits a diagnostic built by `reportError`,
/// suffixed with " has more than one potential consumer". Notes pointing at
/// both uses are attached, and the diagnostic is returned as the (failed)
/// result. Returns success otherwise.
/// NOTE(review): source lines 2873 (the `value` parameter) and 2877 (the
/// condition guarding the first `continue`) are missing from this listing;
/// presumably uses that are not potential consumers are skipped there.
2872LogicalResult
2874 function_ref<InFlightDiagnostic()> reportError) {
2875 OpOperand *potentialConsumer = nullptr;
2876 for (OpOperand &use : value.getUses()) {
2878 continue;
2879
// First potential consumer found: remember it and keep scanning.
2880 if (!potentialConsumer) {
2881 potentialConsumer = &use;
2882 continue;
2883 }
2884
// Second potential consumer: report both uses.
2885 InFlightDiagnostic diag = reportError()
2886 << " has more than one potential consumer";
2887 diag.attachNote(potentialConsumer->getOwner()->getLoc())
2888 << "used here as operand #" << potentialConsumer->getOperandNumber();
2889 diag.attachNote(use.getOwner()->getLoc())
2890 << "used here as operand #" << use.getOperandNumber();
2891 return diag;
2892 }
2893
2894 return success();
2895}
2896
/// Verifies SequenceOp:
///  * a root-less (top-level) sequence must not have extra bindings;
///  * no body block argument and no result of a nested op may have more than
///    one potential consumer;
///  * every nested op except the last one must implement TransformOpInterface;
///  * the body must have a terminator whose operand types equal the op's
///    result types.
/// NOTE(review): source lines 2920 and 2941 (the declarations of the two
/// `diag` diagnostics) are missing from this listing.
2897LogicalResult transform::SequenceOp::verify() {
// NOTE(review): the assertion checks for at least one argument, but its
// message says "more than 1".
2898 assert(getBodyBlock()->getNumArguments() >= 1 &&
2899 "the number of arguments must have been verified to be more than 1 by "
2900 "PossibleTopLevelTransformOpTrait");
2901
// Without a root operand, the sequence is used at the top level.
2902 if (!getRoot() && !getExtraBindings().empty()) {
2903 return emitOpError()
2904 << "does not expect extra operands when used as top-level";
2905 }
2906
2907 // Check if a block argument has more than one consuming use.
2908 for (BlockArgument arg : getBodyBlock()->getArguments()) {
2909 if (failed(checkDoubleConsume(arg, [this, arg]() {
2910 return (emitOpError() << "block argument #" << arg.getArgNumber());
2911 }))) {
2912 return failure();
2913 }
2914 }
2915
2916 // Check properties of the nested operations they cannot check themselves.
2917 for (Operation &child : *getBodyBlock()) {
// The last op of the block (the terminator) is exempt from the interface
// requirement.
2918 if (!isa<TransformOpInterface>(child) &&
2919 &child != &getBodyBlock()->back()) {
2921 emitOpError()
2922 << "expected children ops to implement TransformOpInterface";
2923 diag.attachNote(child.getLoc()) << "op without interface";
2924 return diag;
2925 }
2926
// Each result of a nested op may also be consumed at most once.
2927 for (OpResult result : child.getResults()) {
2928 auto report = [&]() {
2929 return (child.emitError() << "result #" << result.getResultNumber());
2930 };
2931 if (failed(checkDoubleConsume(result, report)))
2932 return failure();
2933 }
2934 }
2935
2936 if (!getBodyBlock()->mightHaveTerminator())
2937 return emitOpError() << "expects to have a terminator in the body";
2938
// The terminator's operands become the op's results, so the types must agree.
2939 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2940 getOperation()->getResultTypes()) {
2942 << "expects the types of the terminator operands "
2943 "to match the types of the result";
2944 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2945 return diag;
2946 }
2947 return success();
2948}
2949
/// Reports the side effects of SequenceOp.
/// NOTE(review): source lines 2951-2952 (the parameter list and the body) are
/// missing from this listing, so the reported effects cannot be described
/// here.
2950void transform::SequenceOp::getEffects(
2953}
2954
/// Returns the operands forwarded when control enters the body region.
///
/// This is all of the op's operands, or an empty range positioned at the end
/// of the operand list when the op has no operands.
/// NOTE(review): source line 2955 (the return type) is missing from this
/// listing.
2956transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2957 assert(successor.getSuccessor() == &getBody() && "unexpected region index");
2958 if (getOperation()->getNumOperands() > 0)
2959 return getOperation()->getOperands();
// Empty range anchored at operand_end() for the operand-less case.
2960 return OperandRange(getOperation()->operand_end(),
2961 getOperation()->operand_end());
2962}
2963
/// Describes control flow for RegionBranchOpInterface.
///
/// From the parent op, control enters the body region. The body's block
/// arguments are the successor inputs when the op has operands. From the body
/// region, control returns to the parent op, whose results are the successor
/// inputs.
/// NOTE(review): source lines 2965 (the parameter list), 2970 (the
/// operand-less alternative of the conditional) and 2974 (the start of the
/// assertion) are missing from this listing.
2964void transform::SequenceOp::getSuccessorRegions(
2966 if (point.isParent()) {
2967 Region *bodyRegion = &getBody();
2968 regions.emplace_back(bodyRegion, getNumOperands() != 0
2969 ? bodyRegion->getArguments()
2971 return;
2972 }
2973
2975 &getBody() &&
2976 "unexpected region index");
// Leaving the body transfers control back to the op, producing its results.
2977 regions.emplace_back(getOperation(), getOperation()->getResults());
2978}
2979
/// Reports that the body region is invoked exactly once, regardless of the
/// constant operand values.
/// NOTE(review): source line 2981 (the parameter list) is missing from this
/// listing.
2980void transform::SequenceOp::getRegionInvocationBounds(
// Operand values do not affect the bounds.
2982 (void)operands;
2983 bounds.emplace_back(1, 1);
2984}
2985
2986void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2987 TypeRange resultTypes,
2988 FailurePropagationMode failurePropagationMode,
2989 Value root,
2990 SequenceBodyBuilderFn bodyBuilder) {
2991 build(builder, state, resultTypes, failurePropagationMode, root,
2992 /*extra_bindings=*/ValueRange());
2993 Type bbArgType = root.getType();
2994 buildSequenceBody(builder, state, bbArgType,
2995 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2996}
2997
2998void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2999 TypeRange resultTypes,
3000 FailurePropagationMode failurePropagationMode,
3001 Value root, ValueRange extraBindings,
3002 SequenceBodyBuilderArgsFn bodyBuilder) {
3003 build(builder, state, resultTypes, failurePropagationMode, root,
3004 extraBindings);
3005 buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3006 bodyBuilder);
3007}
3008
3009void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3010 TypeRange resultTypes,
3011 FailurePropagationMode failurePropagationMode,
3012 Type bbArgType,
3013 SequenceBodyBuilderFn bodyBuilder) {
3014 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3015 /*extra_bindings=*/ValueRange());
3016 buildSequenceBody(builder, state, bbArgType,
3017 /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3018}
3019
3020void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3021 TypeRange resultTypes,
3022 FailurePropagationMode failurePropagationMode,
3023 Type bbArgType, TypeRange extraBindingTypes,
3024 SequenceBodyBuilderArgsFn bodyBuilder) {
3025 build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3026 /*extra_bindings=*/ValueRange());
3027 buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3028}
3029
3030//===----------------------------------------------------------------------===//
3031// PrintOp
3032//===----------------------------------------------------------------------===//
3033
3034void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3035 StringRef name) {
3036 if (!name.empty())
3037 result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3038}
3039
3040void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3041 Value target, StringRef name) {
3042 result.addOperands({target});
3043 build(builder, result, name);
3044}
3045
/// Applies PrintOp: writes a "[[[ IR printer: <name> " header to llvm::outs().
/// It then prints either the top-level payload op (when no target handle is
/// given) or every payload op associated with the target handle. The
/// optional `assume_verified`, `use_local_scope` and `skip_regions`
/// attributes default to false and map onto OpPrintingFlags.
/// NOTE(review): source lines 3046, 3048-3049, 3067 (the early return of the
/// top-level branch) and 3077 (the final return) are missing from this
/// listing.
3047transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3050 llvm::outs() << "[[[ IR printer: ";
3051 if (getName().has_value())
3052 llvm::outs() << *getName() << " ";
3053
3054 OpPrintingFlags printFlags;
3055 if (getAssumeVerified().value_or(false))
3056 printFlags.assumeVerified();
3057 if (getUseLocalScope().value_or(false))
3058 printFlags.useLocalScope();
3059 if (getSkipRegions().value_or(false))
3060 printFlags.skipRegions();
3061
// No target handle: print the whole top-level payload op.
3062 if (!getTarget()) {
3063 llvm::outs() << "top-level ]]]\n";
3064 state.getTopLevel()->print(llvm::outs(), printFlags);
3065 llvm::outs() << "\n";
3066 llvm::outs().flush();
3068 }
3069
// Otherwise print each payload op associated with the target handle.
3070 llvm::outs() << "]]]\n";
3071 for (Operation *target : state.getPayloadOps(getTarget())) {
3072 target->print(llvm::outs(), printFlags);
3073 llvm::outs() << "\n";
3074 }
3075
3076 llvm::outs().flush();
3078}
3079
/// Declares the side effects of PrintOp: the optional target handle is only
/// read, the payload is only read, and a generic write effect models the
/// printing itself.
/// NOTE(review): source line 3081 (the parameter list declaring `effects`) is
/// missing from this listing.
3080void transform::PrintOp::getEffects(
3082 // We don't really care about mutability here, but `getTarget` now
3083 // unconditionally casts to a specific type before verification could run
3084 // here.
3085 if (!getTargetMutable().empty())
3086 onlyReadsHandle(getTargetMutable()[0], effects);
3087 onlyReadsPayload(effects);
3088
3089 // There is no resource for the stdout file descriptor (PrintOp::apply writes
// to llvm::outs()), so just declare that print writes into the default
3090 // resource.
3091 effects.emplace_back(MemoryEffects::Write::get());
3092}
3093
3094//===----------------------------------------------------------------------===//
3095// VerifyOp
3096//===----------------------------------------------------------------------===//
3097
/// Applies VerifyOp to a single payload op `target`. On failure, it reports
/// "failed to verify payload op" with a note pointing at the payload op and
/// returns that diagnostic.
/// NOTE(review): source lines 3098, 3100-3104 and 3109 are missing from this
/// listing, so the verification call and the success path are not visible
/// here.
3099transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3105 << "failed to verify payload op";
3106 diag.attachNote(target->getLoc()) << "payload op";
3107 return diag;
3108 }
3110}
3111
/// Declares the side effects of VerifyOp: the target handle is only read.
/// NOTE(review): source line 3113 (the parameter list declaring `effects`) is
/// missing from this listing.
3112void transform::VerifyOp::getEffects(
3114 transform::onlyReadsHandle(getTargetMutable(), effects);
3115}
3116
3117//===----------------------------------------------------------------------===//
3118// YieldOp
3119//===----------------------------------------------------------------------===//
3120
/// Declares the side effects of YieldOp: all yielded handles are only read.
/// NOTE(review): source line 3122 (the parameter list declaring `effects`) is
/// missing from this listing.
3121void transform::YieldOp::getEffects(
3123 onlyReadsHandle(getOperandsMutable(), effects);
3124}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a dlti.dl_entry attribute.
Definition DLTI.cpp:38
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue > > blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
BlockArgument getArgument(unsigned i)
Definition Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
unsigned getNumOperands()
Definition Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getOpResults()
Definition Operation.h:420
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
iterator begin()
Definition Region.h:55
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.