MLIR 22.0.0git
ControlFlowToSPIRV.cpp
Go to the documentation of this file.
1//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/AffineMap.h"
21#include "llvm/Support/FormatVariadic.h"
22
23#define DEBUG_TYPE "cf-to-spirv-pattern"
24
25using namespace mlir;
26
27/// Legailze target block arguments.
28static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
29 PatternRewriter &rewriter,
30 const TypeConverter &converter) {
31 auto builder = OpBuilder::atBlockBegin(&block);
32 for (unsigned i = 0; i < block.getNumArguments(); ++i) {
33 BlockArgument arg = block.getArgument(i);
34 if (converter.isLegal(arg.getType()))
35 continue;
36 Type ty = arg.getType();
37 Type newTy = converter.convertType(ty);
38 if (!newTy) {
39 return rewriter.notifyMatchFailure(
40 op, llvm::formatv("failed to legalize type for argument {0})", arg));
41 }
42 unsigned argNum = arg.getArgNumber();
43 Location loc = arg.getLoc();
44 Value newArg = block.insertArgument(argNum, newTy, loc);
45 Value convertedValue = converter.materializeSourceConversion(
46 builder, op->getLoc(), ty, newArg);
47 if (!convertedValue) {
48 return rewriter.notifyMatchFailure(
49 op, llvm::formatv("failed to cast new argument {0} to type {1})",
50 newArg, ty));
51 }
52 arg.replaceAllUsesWith(convertedValue);
53 block.eraseArgument(argNum + 1);
54 }
55 return success();
56}
57
58//===----------------------------------------------------------------------===//
59// Operation conversion
60//===----------------------------------------------------------------------===//
61
62namespace {
63/// Converts cf.br to spirv.Branch.
64struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
65 using Base::Base;
66
67 LogicalResult
68 matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
69 ConversionPatternRewriter &rewriter) const override {
70 if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
71 *getTypeConverter())))
72 return failure();
73
74 rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
75 adaptor.getDestOperands());
76 return success();
77 }
78};
79
80/// Converts cf.cond_br to spirv.BranchConditional.
81struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
82 using Base::Base;
83
84 LogicalResult
85 matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter) const override {
87 if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
88 *getTypeConverter())))
89 return failure();
90
91 if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
92 *getTypeConverter())))
93 return failure();
94
95 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
96 op, adaptor.getCondition(), op.getTrueDest(),
97 adaptor.getTrueDestOperands(), op.getFalseDest(),
98 adaptor.getFalseDestOperands());
99 return success();
100 }
101};
102} // namespace
103
104//===----------------------------------------------------------------------===//
105// Pattern population
106//===----------------------------------------------------------------------===//
107
109 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
110 MLIRContext *context = patterns.getContext();
111
112 patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
113}
return success()
static LogicalResult legalizeBlockArguments(Block &block, Operation *op, PatternRewriter &rewriter, const TypeConverter &converter)
Legalize target block arguments.
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition Block.cpp:187
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:193
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:240
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
void populateControlFlowToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating ControlFlow ops to SPIR-V ops.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns