18#include "llvm/Support/FormatVariadic.h"
42 impl = std::make_unique<::ScfToSPIRVContextImpl>();
58template <
typename ScfOp,
typename OpTy>
59void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
60 ConversionPatternRewriter &rewriter,
65 auto &allocas = scfToSPIRVContext->
outputVars[newOp];
70 for (
Type convertedType : returnTypes) {
73 rewriter.setInsertionPoint(newOp);
74 auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
75 spirv::StorageClass::Function,
77 allocas.push_back(alloc);
78 rewriter.setInsertionPointAfter(newOp);
79 Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
80 resultValue.push_back(loadResult);
82 rewriter.replaceOp(scfOp, resultValue);
94template <
typename OpTy>
95class SCFToSPIRVPattern :
public OpConversionPattern<OpTy> {
97 SCFToSPIRVPattern(MLIRContext *context,
const SPIRVTypeConverter &converter,
98 ScfToSPIRVContextImpl *scfToSPIRVContext)
99 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
100 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
103 ScfToSPIRVContextImpl *scfToSPIRVContext;
118 const SPIRVTypeConverter &typeConverter;
126struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
127 using SCFToSPIRVPattern::SCFToSPIRVPattern;
130 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter)
const override {
137 auto loc = forOp.getLoc();
139 spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
140 loopOp.addEntryAndMergeBlock(rewriter);
142 OpBuilder::InsertionGuard guard(rewriter);
144 Block *header = rewriter.createBlock(&loopOp.getBody(),
145 getBlockIt(loopOp.getBody(), 1));
146 rewriter.setInsertionPointAfter(loopOp);
149 Value adapLowerBound = adaptor.getLowerBound();
150 BlockArgument newIndVar =
152 for (Value arg : adaptor.getInitArgs())
154 Block *body = forOp.getBody();
159 TypeConverter::SignatureConversion signatureConverter(
161 signatureConverter.remapInput(0, newIndVar);
163 signatureConverter.remapInput(i, header->
getArgument(i));
164 body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
169 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
170 getBlockIt(loopOp.getBody(), 2));
172 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
173 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
175 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
176 spirv::BranchOp::create(rewriter, loc, header, args);
179 rewriter.setInsertionPointToEnd(header);
180 auto *mergeBlock = loopOp.getMergeBlock();
182 if (forOp.getUnsignedCmp()) {
183 cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
184 newIndVar, adaptor.getUpperBound());
186 cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
187 newIndVar, adaptor.getUpperBound());
190 spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
191 ArrayRef<Value>(), mergeBlock,
196 Block *continueBlock = loopOp.getContinueBlock();
197 rewriter.setInsertionPointToEnd(continueBlock);
200 Value updatedIndVar = spirv::IAddOp::create(
201 rewriter, loc, newIndVar.
getType(), newIndVar, adaptor.getStep());
202 spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
208 SmallVector<Type, 8> initTypes;
209 for (
auto arg : adaptor.getInitArgs())
210 initTypes.push_back(arg.getType());
211 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
223struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
224 using SCFToSPIRVPattern::SCFToSPIRVPattern;
227 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter)
const override {
232 auto loc = ifOp.getLoc();
235 SmallVector<Type, 8> returnTypes;
236 for (
auto result : ifOp.getResults()) {
237 auto convertedType = typeConverter.convertType(
result.getType());
239 return rewriter.notifyMatchFailure(
241 llvm::formatv(
"failed to convert type '{0}'",
result.getType()));
243 returnTypes.push_back(convertedType);
248 auto selectionOp = spirv::SelectionOp::create(
249 rewriter, loc, spirv::SelectionControl::None);
250 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
251 selectionOp.getBody().end());
252 spirv::MergeOp::create(rewriter, loc);
254 OpBuilder::InsertionGuard guard(rewriter);
255 auto *selectionHeaderBlock =
256 rewriter.createBlock(&selectionOp.getBody().front());
259 auto &thenRegion = ifOp.getThenRegion();
260 auto *thenBlock = &thenRegion.front();
261 rewriter.setInsertionPointToEnd(&thenRegion.back());
262 spirv::BranchOp::create(rewriter, loc, mergeBlock);
263 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
265 auto *elseBlock = mergeBlock;
268 if (!ifOp.getElseRegion().empty()) {
269 auto &elseRegion = ifOp.getElseRegion();
270 elseBlock = &elseRegion.front();
271 rewriter.setInsertionPointToEnd(&elseRegion.back());
272 spirv::BranchOp::create(rewriter, loc, mergeBlock);
273 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
277 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
278 spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
279 thenBlock, ArrayRef<Value>(), elseBlock,
282 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
292struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
294 using SCFToSPIRVPattern::SCFToSPIRVPattern;
297 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter)
const override {
301 Operation *parent = terminatorOp->getParentOp();
305 scf::SCFDialect::getDialectNamespace() &&
306 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
307 return rewriter.notifyMatchFailure(
309 llvm::formatv(
"conversion not supported for parent op: '{0}'",
314 if (!operands.empty()) {
315 auto &allocas = scfToSPIRVContext->
outputVars[parent];
316 if (allocas.size() != operands.size())
319 auto loc = terminatorOp.getLoc();
320 for (
unsigned i = 0, e = operands.size(); i < e; i++)
321 spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
322 if (isa<spirv::LoopOp>(parent)) {
325 auto br = cast<spirv::BranchOp>(
326 rewriter.getInsertionBlock()->getTerminator());
327 SmallVector<Value, 8> args(br.getBlockArguments());
328 args.append(operands.begin(), operands.end());
329 rewriter.setInsertionPoint(br);
330 spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
332 rewriter.eraseOp(br);
335 rewriter.eraseOp(terminatorOp);
344struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
345 using SCFToSPIRVPattern::SCFToSPIRVPattern;
348 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter)
const override {
350 auto loc = whileOp.getLoc();
352 spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
353 loopOp.addEntryAndMergeBlock(rewriter);
355 Region &beforeRegion = whileOp.getBefore();
356 Region &afterRegion = whileOp.getAfter();
358 if (
failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
359 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
360 return rewriter.notifyMatchFailure(whileOp,
361 "Failed to convert region types");
363 OpBuilder::InsertionGuard guard(rewriter);
365 Block &entryBlock = *loopOp.getEntryBlock();
368 Block &mergeBlock = *loopOp.getMergeBlock();
370 auto cond = cast<scf::ConditionOp>(beforeBlock.
getTerminator());
371 SmallVector<Value> condArgs;
372 if (
failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
375 Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
380 SmallVector<Value> yieldArgs;
381 if (
failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
385 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
386 getBlockIt(loopOp.getBody(), 1));
389 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
390 getBlockIt(loopOp.getBody(), 2));
393 rewriter.setInsertionPointToEnd(&entryBlock);
394 spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
396 auto condLoc = cond.getLoc();
398 SmallVector<Value> resultValues(condArgs.size());
407 for (
const auto &it : llvm::enumerate(condArgs)) {
408 auto res = it.value();
414 rewriter.setInsertionPoint(loopOp);
415 auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
416 spirv::StorageClass::Function,
420 rewriter.setInsertionPointAfter(loopOp);
421 auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
422 resultValues[i] = loadResult;
425 rewriter.setInsertionPointToEnd(&beforeBlock);
426 spirv::StoreOp::create(rewriter, condLoc, alloc, res);
429 rewriter.setInsertionPointToEnd(&beforeBlock);
430 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
431 cond, conditionVal, &afterBlock, condArgs, &mergeBlock,
ValueRange());
434 rewriter.setInsertionPointToEnd(&afterBlock);
435 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
438 rewriter.replaceOp(whileOp, resultValues);
451 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
452 WhileOpConversion>(
patterns.getContext(), typeConverter,
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OperationName getName()
The name of an operation is the key identifier for it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...
ScfToSPIRVContextImpl * getImpl()