MLIR  22.0.0git
ControlFlowToSPIRV.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/Support/FormatVariadic.h"
22 
23 #define DEBUG_TYPE "cf-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 /// Legailze target block arguments.
28 static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
29  PatternRewriter &rewriter,
30  const TypeConverter &converter) {
31  auto builder = OpBuilder::atBlockBegin(&block);
32  for (unsigned i = 0; i < block.getNumArguments(); ++i) {
33  BlockArgument arg = block.getArgument(i);
34  if (converter.isLegal(arg.getType()))
35  continue;
36  Type ty = arg.getType();
37  Type newTy = converter.convertType(ty);
38  if (!newTy) {
39  return rewriter.notifyMatchFailure(
40  op, llvm::formatv("failed to legalize type for argument {0})", arg));
41  }
42  unsigned argNum = arg.getArgNumber();
43  Location loc = arg.getLoc();
44  Value newArg = block.insertArgument(argNum, newTy, loc);
45  Value convertedValue = converter.materializeSourceConversion(
46  builder, op->getLoc(), ty, newArg);
47  if (!convertedValue) {
48  return rewriter.notifyMatchFailure(
49  op, llvm::formatv("failed to cast new argument {0} to type {1})",
50  newArg, ty));
51  }
52  arg.replaceAllUsesWith(convertedValue);
53  block.eraseArgument(argNum + 1);
54  }
55  return success();
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Operation conversion
60 //===----------------------------------------------------------------------===//
61 
62 namespace {
63 /// Converts cf.br to spirv.Branch.
64 struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
66 
67  LogicalResult
68  matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
69  ConversionPatternRewriter &rewriter) const override {
70  if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
71  *getTypeConverter())))
72  return failure();
73 
74  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
75  adaptor.getDestOperands());
76  return success();
77  }
78 };
79 
80 /// Converts cf.cond_br to spirv.BranchConditional.
81 struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
83 
84  LogicalResult
85  matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
86  ConversionPatternRewriter &rewriter) const override {
87  if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
88  *getTypeConverter())))
89  return failure();
90 
91  if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
92  *getTypeConverter())))
93  return failure();
94 
95  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
96  op, adaptor.getCondition(), op.getTrueDest(),
97  adaptor.getTrueDestOperands(), op.getFalseDest(),
98  adaptor.getFalseDestOperands());
99  return success();
100  }
101 };
102 } // namespace
103 
104 //===----------------------------------------------------------------------===//
105 // Pattern population
106 //===----------------------------------------------------------------------===//
107 
109  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
110  MLIRContext *context = patterns.getContext();
111 
112  patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
113 }
static LogicalResult legalizeBlockArguments(Block &block, Operation *op, PatternRewriter &rewriter, const TypeConverter &converter)
Legalize target block arguments.
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:187
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: Block.cpp:193
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:238
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
void populateControlFlowToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating ControlFlow ops to SPIR-V ops.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns