19#include "llvm/Support/FormatVariadic.h"
43 impl = std::make_unique<::ScfToSPIRVContextImpl>();
59template <
typename ScfOp,
typename OpTy>
60void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
61 ConversionPatternRewriter &rewriter,
66 auto &allocas = scfToSPIRVContext->
outputVars[newOp];
71 for (
Type convertedType : returnTypes) {
74 rewriter.setInsertionPoint(newOp);
75 auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
76 spirv::StorageClass::Function,
78 allocas.push_back(alloc);
79 rewriter.setInsertionPointAfter(newOp);
80 Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
81 resultValue.push_back(loadResult);
83 rewriter.replaceOp(scfOp, resultValue);
95template <
typename OpTy>
96class SCFToSPIRVPattern :
public OpConversionPattern<OpTy> {
98 SCFToSPIRVPattern(MLIRContext *context,
const SPIRVTypeConverter &converter,
99 ScfToSPIRVContextImpl *scfToSPIRVContext)
100 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
101 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
104 ScfToSPIRVContextImpl *scfToSPIRVContext;
119 const SPIRVTypeConverter &typeConverter;
127struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
128 using SCFToSPIRVPattern::SCFToSPIRVPattern;
131 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const override {
138 auto loc = forOp.getLoc();
139 auto loopControl = spirv::LoopControl::None;
140 if (
auto attr = forOp->getAttrOfType<spirv::LoopControlAttr>(
142 loopControl = attr.getValue();
143 auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
144 loopOp.addEntryAndMergeBlock(rewriter);
146 OpBuilder::InsertionGuard guard(rewriter);
148 Block *header = rewriter.createBlock(&loopOp.getBody(),
149 getBlockIt(loopOp.getBody(), 1));
150 rewriter.setInsertionPointAfter(loopOp);
153 Value adapLowerBound = adaptor.getLowerBound();
154 BlockArgument newIndVar =
156 for (Value arg : adaptor.getInitArgs())
158 Block *body = forOp.getBody();
163 TypeConverter::SignatureConversion signatureConverter(
165 signatureConverter.remapInput(0, newIndVar);
167 signatureConverter.remapInput(i, header->
getArgument(i));
168 body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
173 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
174 getBlockIt(loopOp.getBody(), 2));
176 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
177 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
179 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
180 spirv::BranchOp::create(rewriter, loc, header, args);
183 rewriter.setInsertionPointToEnd(header);
184 auto *mergeBlock = loopOp.getMergeBlock();
186 if (forOp.getUnsignedCmp()) {
187 cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
188 newIndVar, adaptor.getUpperBound());
190 cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
191 newIndVar, adaptor.getUpperBound());
194 spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
195 ArrayRef<Value>(), mergeBlock,
200 Block *continueBlock = loopOp.getContinueBlock();
201 rewriter.setInsertionPointToEnd(continueBlock);
204 Value updatedIndVar = spirv::IAddOp::create(
205 rewriter, loc, newIndVar.
getType(), newIndVar, adaptor.getStep());
206 spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
212 SmallVector<Type, 8> initTypes;
213 for (
auto arg : adaptor.getInitArgs())
214 initTypes.push_back(arg.getType());
215 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
227struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
228 using SCFToSPIRVPattern::SCFToSPIRVPattern;
231 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter)
const override {
236 auto loc = ifOp.getLoc();
239 SmallVector<Type, 8> returnTypes;
240 for (
auto result : ifOp.getResults()) {
241 auto convertedType = typeConverter.convertType(
result.getType());
243 return rewriter.notifyMatchFailure(
245 llvm::formatv(
"failed to convert type '{0}'",
result.getType()));
247 returnTypes.push_back(convertedType);
252 auto selectionControl = spirv::SelectionControl::None;
253 if (
auto attr = ifOp->getAttrOfType<spirv::SelectionControlAttr>(
255 selectionControl = attr.getValue();
257 spirv::SelectionOp::create(rewriter, loc, selectionControl);
258 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
259 selectionOp.getBody().end());
260 spirv::MergeOp::create(rewriter, loc);
262 OpBuilder::InsertionGuard guard(rewriter);
263 auto *selectionHeaderBlock =
264 rewriter.createBlock(&selectionOp.getBody().front());
267 auto &thenRegion = ifOp.getThenRegion();
268 auto *thenBlock = &thenRegion.front();
269 rewriter.setInsertionPointToEnd(&thenRegion.back());
270 spirv::BranchOp::create(rewriter, loc, mergeBlock);
271 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
273 auto *elseBlock = mergeBlock;
276 if (!ifOp.getElseRegion().empty()) {
277 auto &elseRegion = ifOp.getElseRegion();
278 elseBlock = &elseRegion.front();
279 rewriter.setInsertionPointToEnd(&elseRegion.back());
280 spirv::BranchOp::create(rewriter, loc, mergeBlock);
281 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
285 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
286 spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
287 thenBlock, ArrayRef<Value>(), elseBlock,
290 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
300struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
302 using SCFToSPIRVPattern::SCFToSPIRVPattern;
305 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
309 Operation *parent = terminatorOp->getParentOp();
313 scf::SCFDialect::getDialectNamespace() &&
314 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
315 return rewriter.notifyMatchFailure(
317 llvm::formatv(
"conversion not supported for parent op: '{0}'",
322 if (!operands.empty()) {
323 auto &allocas = scfToSPIRVContext->
outputVars[parent];
324 if (allocas.size() != operands.size())
327 auto loc = terminatorOp.getLoc();
328 for (
unsigned i = 0, e = operands.size(); i < e; i++)
329 spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
330 if (isa<spirv::LoopOp>(parent)) {
333 auto br = cast<spirv::BranchOp>(
334 rewriter.getInsertionBlock()->getTerminator());
335 SmallVector<Value, 8> args(br.getBlockArguments());
336 args.append(operands.begin(), operands.end());
337 rewriter.setInsertionPoint(br);
338 spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
340 rewriter.eraseOp(br);
343 rewriter.eraseOp(terminatorOp);
352struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
353 using SCFToSPIRVPattern::SCFToSPIRVPattern;
356 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
357 ConversionPatternRewriter &rewriter)
const override {
358 auto loc = whileOp.getLoc();
359 auto loopControl = spirv::LoopControl::None;
360 if (
auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
362 loopControl = attr.getValue();
363 auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
364 loopOp.addEntryAndMergeBlock(rewriter);
366 Region &beforeRegion = whileOp.getBefore();
367 Region &afterRegion = whileOp.getAfter();
369 if (
failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
370 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
371 return rewriter.notifyMatchFailure(whileOp,
372 "Failed to convert region types");
374 OpBuilder::InsertionGuard guard(rewriter);
376 Block &entryBlock = *loopOp.getEntryBlock();
379 Block &mergeBlock = *loopOp.getMergeBlock();
381 auto cond = cast<scf::ConditionOp>(beforeBlock.
getTerminator());
382 SmallVector<Value> condArgs;
383 if (
failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
386 Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
391 SmallVector<Value> yieldArgs;
392 if (
failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
396 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
397 getBlockIt(loopOp.getBody(), 1));
400 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
401 getBlockIt(loopOp.getBody(), 2));
404 rewriter.setInsertionPointToEnd(&entryBlock);
405 spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
407 auto condLoc = cond.getLoc();
409 SmallVector<Value> resultValues(condArgs.size());
418 for (
const auto &it : llvm::enumerate(condArgs)) {
419 auto res = it.value();
425 rewriter.setInsertionPoint(loopOp);
426 auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
427 spirv::StorageClass::Function,
431 rewriter.setInsertionPointAfter(loopOp);
432 auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
433 resultValues[i] = loadResult;
436 rewriter.setInsertionPointToEnd(&beforeBlock);
437 spirv::StoreOp::create(rewriter, condLoc, alloc, res);
440 rewriter.setInsertionPointToEnd(&beforeBlock);
441 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
442 cond, conditionVal, &afterBlock, condArgs, &mergeBlock,
ValueRange());
445 rewriter.setInsertionPointToEnd(&afterBlock);
446 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
449 rewriter.replaceOp(whileOp, resultValues);
462 patterns.
add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
463 WhileOpConversion>(patterns.
getContext(), typeConverter,
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OperationName getName()
The name of an operation is the key identifier for it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
StringRef getLoopControlAttrName()
Returns the attribute name for specifying loop control.
StringRef getSelectionControlAttrName()
Returns the attribute name for specifying selection control.
Include the generated interface declarations.
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...
ScfToSPIRVContextImpl * getImpl()