MLIR 22.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
//===- StructuralTypeConversions.cpp - Type conversion of CF ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements structural type conversion patterns (and the matching
// legality rules) for the branch ops of the ControlFlow (cf) dialect:
// `cf.br`, `cf.cond_br`, and `cf.switch`.
11//
12//===----------------------------------------------------------------------===//
13
15
18#include "mlir/Pass/Pass.h"
20
21using namespace mlir;
22
23namespace {
24
25/// Helper function for converting branch ops. This function converts the
26/// signature of the given block. If the new block signature is different from
27/// `expectedTypes`, returns "failure".
28static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
29 const TypeConverter *converter,
30 Operation *branchOp, Block *block,
31 TypeRange expectedTypes) {
32 assert(converter && "expected non-null type converter");
33 assert(!block->isEntryBlock() && "entry blocks have no predecessors");
34
35 // There is nothing to do if the types already match.
36 if (block->getArgumentTypes() == expectedTypes)
37 return block;
38
39 // Compute the new block argument types and convert the block.
40 std::optional<TypeConverter::SignatureConversion> conversion =
41 converter->convertBlockSignature(block);
42 if (!conversion)
43 return rewriter.notifyMatchFailure(branchOp,
44 "could not compute block signature");
45 if (expectedTypes != conversion->getConvertedTypes())
46 return rewriter.notifyMatchFailure(
47 branchOp,
48 "mismatch between adaptor operand types and computed block signature");
49 return rewriter.applySignatureConversion(block, *conversion, converter);
50}
51
52/// Flatten the given value ranges into a single vector of values.
55 for (const ValueRange &vals : values)
56 llvm::append_range(result, vals);
57 return result;
58}
59
60/// Convert the destination block signature (if necessary) and change the
61/// operands of the branch op.
62struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
63 using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
64
65 LogicalResult
66 matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
67 ConversionPatternRewriter &rewriter) const override {
68 SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
69 FailureOr<Block *> convertedBlock =
70 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
71 TypeRange(ValueRange(flattenedAdaptor)));
72 if (failed(convertedBlock))
73 return failure();
74 rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
75 *convertedBlock);
76 return success();
77 }
78};
79
80/// Convert the destination block signatures (if necessary) and change the
81/// operands of the branch op.
82struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
83 using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
84
85 LogicalResult
86 matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter) const override {
88 SmallVector<Value> flattenedAdaptorTrue =
89 flattenValues(adaptor.getTrueDestOperands());
90 SmallVector<Value> flattenedAdaptorFalse =
91 flattenValues(adaptor.getFalseDestOperands());
92 if (!llvm::hasSingleElement(adaptor.getCondition()))
93 return rewriter.notifyMatchFailure(op,
94 "expected single element condition");
95 FailureOr<Block *> convertedTrueBlock =
96 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
97 TypeRange(ValueRange(flattenedAdaptorTrue)));
98 if (failed(convertedTrueBlock))
99 return failure();
100 FailureOr<Block *> convertedFalseBlock =
101 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
102 TypeRange(ValueRange(flattenedAdaptorFalse)));
103 if (failed(convertedFalseBlock))
104 return failure();
105 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
106 op, llvm::getSingleElement(adaptor.getCondition()),
107 flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
108 *convertedTrueBlock, *convertedFalseBlock);
109 return success();
110 }
111};
112
113/// Convert the destination block signatures (if necessary) and change the
114/// operands of the switch op.
115struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
116 using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
117
118 LogicalResult
119 matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter) const override {
121 // Get or convert default block.
122 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
123 rewriter, getTypeConverter(), op, op.getDefaultDestination(),
124 TypeRange(adaptor.getDefaultOperands()));
125 if (failed(convertedDefaultBlock))
126 return failure();
127
128 // Get or convert all case blocks.
129 SmallVector<Block *> caseDestinations;
130 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
131 for (auto it : llvm::enumerate(op.getCaseDestinations())) {
132 Block *b = it.value();
133 FailureOr<Block *> convertedBlock =
134 getConvertedBlock(rewriter, getTypeConverter(), op, b,
135 TypeRange(caseOperands[it.index()]));
136 if (failed(convertedBlock))
137 return failure();
138 caseDestinations.push_back(*convertedBlock);
139 }
140
141 rewriter.replaceOpWithNewOp<cf::SwitchOp>(
142 op, adaptor.getFlag(), *convertedDefaultBlock,
143 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
144 caseDestinations, caseOperands);
145 return success();
146 }
147};
148
149} // namespace
150
152 const TypeConverter &typeConverter, RewritePatternSet &patterns,
153 PatternBenefit benefit) {
154 patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
155 typeConverter, patterns.getContext(), benefit);
156}
157
159 const TypeConverter &typeConverter, ConversionTarget &target) {
160 target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
161 [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
162}
163
return success()
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
void populateCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of CF operations based on the provided type conver...
void populateCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Similar to populateCFStructuralTypeConversionsAndLegality but does not populate the conversion target...
void populateCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for CF structural type conversions and sets up the provided ConversionTarget with ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns