MLIR  22.0.0git
StructuralTypeConversions.cpp
Go to the documentation of this file.
1 //===- StructuralTypeConversions.cpp - Type conversion of CF ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the patterns and conversion-target legality rules for
10 // structural type conversion of cf.br, cf.cond_br and cf.switch operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Pass/Pass.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 /// Helper function for converting branch ops. This function converts the
26 /// signature of the given block. If the new block signature is different from
27 /// `expectedTypes`, returns "failure".
28 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
29  const TypeConverter *converter,
30  Operation *branchOp, Block *block,
31  TypeRange expectedTypes) {
32  assert(converter && "expected non-null type converter");
33  assert(!block->isEntryBlock() && "entry blocks have no predecessors");
34 
35  // There is nothing to do if the types already match.
36  if (block->getArgumentTypes() == expectedTypes)
37  return block;
38 
39  // Compute the new block argument types and convert the block.
40  std::optional<TypeConverter::SignatureConversion> conversion =
41  converter->convertBlockSignature(block);
42  if (!conversion)
43  return rewriter.notifyMatchFailure(branchOp,
44  "could not compute block signature");
45  if (expectedTypes != conversion->getConvertedTypes())
46  return rewriter.notifyMatchFailure(
47  branchOp,
48  "mismatch between adaptor operand types and computed block signature");
49  return rewriter.applySignatureConversion(block, *conversion, converter);
50 }
51 
52 /// Flatten the given value ranges into a single vector of values.
54  SmallVector<Value> result;
55  for (const ValueRange &vals : values)
56  llvm::append_range(result, vals);
57  return result;
58 }
59 
60 /// Convert the destination block signature (if necessary) and change the
61 /// operands of the branch op.
62 struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
64 
65  LogicalResult
66  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
67  ConversionPatternRewriter &rewriter) const override {
68  SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
69  FailureOr<Block *> convertedBlock =
70  getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
71  TypeRange(ValueRange(flattenedAdaptor)));
72  if (failed(convertedBlock))
73  return failure();
74  rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
75  *convertedBlock);
76  return success();
77  }
78 };
79 
80 /// Convert the destination block signatures (if necessary) and change the
81 /// operands of the branch op.
82 struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
84 
85  LogicalResult
86  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
87  ConversionPatternRewriter &rewriter) const override {
88  SmallVector<Value> flattenedAdaptorTrue =
89  flattenValues(adaptor.getTrueDestOperands());
90  SmallVector<Value> flattenedAdaptorFalse =
91  flattenValues(adaptor.getFalseDestOperands());
92  if (!llvm::hasSingleElement(adaptor.getCondition()))
93  return rewriter.notifyMatchFailure(op,
94  "expected single element condition");
95  FailureOr<Block *> convertedTrueBlock =
96  getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
97  TypeRange(ValueRange(flattenedAdaptorTrue)));
98  if (failed(convertedTrueBlock))
99  return failure();
100  FailureOr<Block *> convertedFalseBlock =
101  getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
102  TypeRange(ValueRange(flattenedAdaptorFalse)));
103  if (failed(convertedFalseBlock))
104  return failure();
105  rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
106  op, llvm::getSingleElement(adaptor.getCondition()),
107  flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
108  *convertedTrueBlock, *convertedFalseBlock);
109  return success();
110  }
111 };
112 
113 /// Convert the destination block signatures (if necessary) and change the
114 /// operands of the switch op.
115 struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
117 
118  LogicalResult
119  matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override {
121  // Get or convert default block.
122  FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
123  rewriter, getTypeConverter(), op, op.getDefaultDestination(),
124  TypeRange(adaptor.getDefaultOperands()));
125  if (failed(convertedDefaultBlock))
126  return failure();
127 
128  // Get or convert all case blocks.
129  SmallVector<Block *> caseDestinations;
130  SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
131  for (auto it : llvm::enumerate(op.getCaseDestinations())) {
132  Block *b = it.value();
133  FailureOr<Block *> convertedBlock =
134  getConvertedBlock(rewriter, getTypeConverter(), op, b,
135  TypeRange(caseOperands[it.index()]));
136  if (failed(convertedBlock))
137  return failure();
138  caseDestinations.push_back(*convertedBlock);
139  }
140 
141  rewriter.replaceOpWithNewOp<cf::SwitchOp>(
142  op, adaptor.getFlag(), *convertedDefaultBlock,
143  adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
144  caseDestinations, caseOperands);
145  return success();
146  }
147 };
148 
149 } // namespace
150 
152  const TypeConverter &typeConverter, RewritePatternSet &patterns,
153  PatternBenefit benefit) {
154  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
155  typeConverter, patterns.getContext(), benefit);
156 }
157 
159  const TypeConverter &typeConverter, ConversionTarget &target) {
160  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
161  [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
162 }
163 
165  const TypeConverter &typeConverter, RewritePatternSet &patterns,
166  ConversionTarget &target, PatternBenefit benefit) {
167  populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
168  populateCFStructuralTypeConversionTarget(typeConverter, target);
169 }
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
void populateCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of CF operations based on the provided type conver...
void populateCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Similar to populateCFStructuralTypeConversionsAndLegality but does not populate the conversion target...
void populateCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for CF structural type conversions and sets up the provided ConversionTarget with ...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
Definition: XeGPUUtils.cpp:32
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns