MLIR  19.0.0git
ControlFlowToSPIRV.cpp
Go to the documentation of this file.
1 //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "../SPIRVCommon/Pattern.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 #define DEBUG_TYPE "cf-to-spirv-pattern"
28 
29 using namespace mlir;
30 
31 /// Legalize target block arguments.
   ///
   /// Walks the arguments of `block`. Each argument whose type `converter`
   /// does not report as legal is replaced by a new argument of the converted
   /// type. A source materialization (built with a builder positioned at the
   /// start of `block`) converts the new argument back to the original type,
   /// and all uses of the old argument are redirected to that value.
   /// Returns failure, reported against `op`, if a type cannot be converted or
   /// the materialization cannot be created.
   ///
   /// NOTE(review): the signature line (listing line 32) is missing from this
   /// listing; the full signature is
   /// `static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
   /// PatternRewriter &rewriter, const TypeConverter &converter)`.
33  PatternRewriter &rewriter,
34  const TypeConverter &converter) {
35  auto builder = OpBuilder::atBlockBegin(&block); // NOTE(review): plain OpBuilder, not `rewriter`; confirm the conversion driver is aware of ops it creates.
36  for (unsigned i = 0; i < block.getNumArguments(); ++i) { // Count is stable per iteration: one insert plus one erase.
37  BlockArgument arg = block.getArgument(i);
38  if (converter.isLegal(arg.getType()))
39  continue;
40  Type ty = arg.getType();
41  Type newTy = converter.convertType(ty);
42  if (!newTy) { // NOTE(review): arguments before `i` may already be rewritten when this failure is returned.
43  return rewriter.notifyMatchFailure(
44  op, llvm::formatv("failed to legalize type for argument {0})", arg)); // NOTE(review): trailing ')' in this message looks like a typo.
45  }
46  unsigned argNum = arg.getArgNumber();
47  Location loc = arg.getLoc();
48  Value newArg = block.insertArgument(argNum, newTy, loc); // Inserted before the old argument, which shifts to argNum + 1.
49  Value convertedValue = converter.materializeSourceConversion(
50  builder, op->getLoc(), ty, newArg); // Cast the new argument back to the original type for existing users.
51  if (!convertedValue) { // NOTE(review): the new argument was already inserted; the block is left modified on this failure path.
52  return rewriter.notifyMatchFailure(
53  op, llvm::formatv("failed to cast new argument {0} to type {1})", // NOTE(review): trailing ')' in this message looks like a typo.
54  newArg, ty));
55  }
56  arg.replaceAllUsesWith(convertedValue); // NOTE(review): direct IR mutation (also insert/erase here), not via `rewriter`; confirm rollback is not required.
57  block.eraseArgument(argNum + 1); // Remove the old argument, now without uses.
58  }
59  return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Operation conversion
64 //===----------------------------------------------------------------------===//
65 
66 namespace {
67 /// Converts cf.br to spirv.Branch.
   ///
   /// First legalizes the argument types of the destination block, then
   /// replaces the op with a spirv.Branch to the same destination, passing the
   /// adaptor's destination operands.
68 struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
70 
72  matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
73  ConversionPatternRewriter &rewriter) const override {
74  if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
75  *getTypeConverter())))
76  return failure();
77 
78  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
79  adaptor.getDestOperands()); // Operands come from the adaptor, not from `op`.
80  return success();
81  }
82 };
83 
84 /// Converts cf.cond_br to spirv.BranchConditional.
   ///
   /// Legalizes the argument types of the true and the false destination
   /// blocks (in that order), then replaces the op with a
   /// spirv.BranchConditional that uses the adaptor's condition and
   /// destination operands.
85 struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
87 
89  matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter) const override {
91  if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
92  *getTypeConverter())))
93  return failure();
94 
95  if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
96  *getTypeConverter()))) // NOTE(review): if this fails, the true destination has already been modified.
97  return failure();
98 
99  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
100  op, adaptor.getCondition(), op.getTrueDest(),
101  adaptor.getTrueDestOperands(), op.getFalseDest(),
102  adaptor.getFalseDestOperands());
103  return success();
104  }
105 };
106 } // namespace
107 
108 //===----------------------------------------------------------------------===//
109 // Pattern population
110 //===----------------------------------------------------------------------===//
111 
113  SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
114  MLIRContext *context = patterns.getContext(); // Patterns are created in the pattern set's context.
115 
116  patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context); // Register cf.br and cf.cond_br conversions, both using `typeConverter`.
117 }
static LogicalResult legalizeBlockArguments(Block &block, Operation *op, PatternRewriter &rewriter, const TypeConverter &converter)
Legalize target block arguments.
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:186
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: Block.cpp:192
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:242
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion from builtin types to SPIR-V types for shader interface.
Type conversion class.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:173
void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating ControlFlow ops to SPIR-V ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26