MLIR  20.0.0git
ControlFlowToSPIRV.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "../SPIRVCommon/Pattern.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "cf-to-spirv-pattern"
27 
28 using namespace mlir;
29 
30 /// Legailze target block arguments.
31 static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
32  PatternRewriter &rewriter,
33  const TypeConverter &converter) {
34  auto builder = OpBuilder::atBlockBegin(&block);
35  for (unsigned i = 0; i < block.getNumArguments(); ++i) {
36  BlockArgument arg = block.getArgument(i);
37  if (converter.isLegal(arg.getType()))
38  continue;
39  Type ty = arg.getType();
40  Type newTy = converter.convertType(ty);
41  if (!newTy) {
42  return rewriter.notifyMatchFailure(
43  op, llvm::formatv("failed to legalize type for argument {0})", arg));
44  }
45  unsigned argNum = arg.getArgNumber();
46  Location loc = arg.getLoc();
47  Value newArg = block.insertArgument(argNum, newTy, loc);
48  Value convertedValue = converter.materializeSourceConversion(
49  builder, op->getLoc(), ty, newArg);
50  if (!convertedValue) {
51  return rewriter.notifyMatchFailure(
52  op, llvm::formatv("failed to cast new argument {0} to type {1})",
53  newArg, ty));
54  }
55  arg.replaceAllUsesWith(convertedValue);
56  block.eraseArgument(argNum + 1);
57  }
58  return success();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Operation conversion
63 //===----------------------------------------------------------------------===//
64 
65 namespace {
66 /// Converts cf.br to spirv.Branch.
67 struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
69 
70  LogicalResult
71  matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override {
73  if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
74  *getTypeConverter())))
75  return failure();
76 
77  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
78  adaptor.getDestOperands());
79  return success();
80  }
81 };
82 
83 /// Converts cf.cond_br to spirv.BranchConditional.
84 struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
86 
87  LogicalResult
88  matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override {
90  if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
91  *getTypeConverter())))
92  return failure();
93 
94  if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
95  *getTypeConverter())))
96  return failure();
97 
98  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
99  op, adaptor.getCondition(), op.getTrueDest(),
100  adaptor.getTrueDestOperands(), op.getFalseDest(),
101  adaptor.getFalseDestOperands());
102  return success();
103  }
104 };
105 } // namespace
106 
107 //===----------------------------------------------------------------------===//
108 // Pattern population
109 //===----------------------------------------------------------------------===//
110 
112  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
113  MLIRContext *context = patterns.getContext();
114 
115  patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
116 }
static LogicalResult legalizeBlockArguments(Block &block, Operation *op, PatternRewriter &rewriter, const TypeConverter &converter)
Legalize target block arguments.
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:189
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: Block.cpp:195
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:249
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:173
void populateControlFlowToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating ControlFlow ops to SPIR-V ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns