19#include "llvm/Support/FormatVariadic.h"
43 impl = std::make_unique<::ScfToSPIRVContextImpl>();
59template <
typename ScfOp,
typename OpTy>
60void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
61 ConversionPatternRewriter &rewriter,
66 auto &allocas = scfToSPIRVContext->
outputVars[newOp];
71 for (
Type convertedType : returnTypes) {
74 rewriter.setInsertionPoint(newOp);
75 auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
76 spirv::StorageClass::Function,
78 allocas.push_back(alloc);
79 rewriter.setInsertionPointAfter(newOp);
80 Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
81 resultValue.push_back(loadResult);
83 rewriter.replaceOp(scfOp, resultValue);
95template <
typename OpTy>
96class SCFToSPIRVPattern :
public OpConversionPattern<OpTy> {
98 SCFToSPIRVPattern(MLIRContext *context,
const SPIRVTypeConverter &converter,
99 ScfToSPIRVContextImpl *scfToSPIRVContext)
100 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
101 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
104 ScfToSPIRVContextImpl *scfToSPIRVContext;
119 const SPIRVTypeConverter &typeConverter;
127struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
128 using SCFToSPIRVPattern::SCFToSPIRVPattern;
131 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const override {
138 auto loc = forOp.getLoc();
139 auto loopControl = spirv::LoopControl::None;
140 if (
auto attr = forOp->getAttrOfType<spirv::LoopControlAttr>(
142 loopControl = attr.getValue();
143 auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
144 loopOp.addEntryAndMergeBlock(rewriter);
146 OpBuilder::InsertionGuard guard(rewriter);
148 Block *header = rewriter.createBlock(&loopOp.getBody(),
149 getBlockIt(loopOp.getBody(), 1));
150 rewriter.setInsertionPointAfter(loopOp);
153 Value adapLowerBound = adaptor.getLowerBound();
154 BlockArgument newIndVar =
156 for (Value arg : adaptor.getInitArgs())
158 Block *body = forOp.getBody();
163 TypeConverter::SignatureConversion signatureConverter(
165 signatureConverter.remapInput(0, newIndVar);
167 signatureConverter.remapInput(i, header->
getArgument(i));
168 body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
173 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
174 getBlockIt(loopOp.getBody(), 2));
176 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
177 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
179 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
180 spirv::BranchOp::create(rewriter, loc, header, args);
183 rewriter.setInsertionPointToEnd(header);
184 auto *mergeBlock = loopOp.getMergeBlock();
186 if (forOp.getUnsignedCmp()) {
187 cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
188 newIndVar, adaptor.getUpperBound());
190 cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
191 newIndVar, adaptor.getUpperBound());
194 spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
195 ArrayRef<Value>(), mergeBlock,
200 Block *continueBlock = loopOp.getContinueBlock();
201 rewriter.setInsertionPointToEnd(continueBlock);
204 Value updatedIndVar = spirv::IAddOp::create(
205 rewriter, loc, newIndVar.
getType(), newIndVar, adaptor.getStep());
206 spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
212 SmallVector<Type, 8> initTypes;
213 for (
auto arg : adaptor.getInitArgs())
214 initTypes.push_back(arg.getType());
215 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
227struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
228 using SCFToSPIRVPattern::SCFToSPIRVPattern;
231 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter)
const override {
236 auto loc = ifOp.getLoc();
239 SmallVector<Type, 8> returnTypes;
240 for (
auto result : ifOp.getResults()) {
241 auto convertedType = typeConverter.convertType(
result.getType());
243 return rewriter.notifyMatchFailure(
245 llvm::formatv(
"failed to convert type '{0}'",
result.getType()));
247 returnTypes.push_back(convertedType);
252 auto selectionOp = spirv::SelectionOp::create(
253 rewriter, loc, spirv::SelectionControl::None);
254 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
255 selectionOp.getBody().end());
256 spirv::MergeOp::create(rewriter, loc);
258 OpBuilder::InsertionGuard guard(rewriter);
259 auto *selectionHeaderBlock =
260 rewriter.createBlock(&selectionOp.getBody().front());
263 auto &thenRegion = ifOp.getThenRegion();
264 auto *thenBlock = &thenRegion.front();
265 rewriter.setInsertionPointToEnd(&thenRegion.back());
266 spirv::BranchOp::create(rewriter, loc, mergeBlock);
267 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
269 auto *elseBlock = mergeBlock;
272 if (!ifOp.getElseRegion().empty()) {
273 auto &elseRegion = ifOp.getElseRegion();
274 elseBlock = &elseRegion.front();
275 rewriter.setInsertionPointToEnd(&elseRegion.back());
276 spirv::BranchOp::create(rewriter, loc, mergeBlock);
277 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
281 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
282 spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
283 thenBlock, ArrayRef<Value>(), elseBlock,
286 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
296struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
298 using SCFToSPIRVPattern::SCFToSPIRVPattern;
301 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
302 ConversionPatternRewriter &rewriter)
const override {
305 Operation *parent = terminatorOp->getParentOp();
309 scf::SCFDialect::getDialectNamespace() &&
310 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
311 return rewriter.notifyMatchFailure(
313 llvm::formatv(
"conversion not supported for parent op: '{0}'",
318 if (!operands.empty()) {
319 auto &allocas = scfToSPIRVContext->
outputVars[parent];
320 if (allocas.size() != operands.size())
323 auto loc = terminatorOp.getLoc();
324 for (
unsigned i = 0, e = operands.size(); i < e; i++)
325 spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
326 if (isa<spirv::LoopOp>(parent)) {
329 auto br = cast<spirv::BranchOp>(
330 rewriter.getInsertionBlock()->getTerminator());
331 SmallVector<Value, 8> args(br.getBlockArguments());
332 args.append(operands.begin(), operands.end());
333 rewriter.setInsertionPoint(br);
334 spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
336 rewriter.eraseOp(br);
339 rewriter.eraseOp(terminatorOp);
348struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
349 using SCFToSPIRVPattern::SCFToSPIRVPattern;
352 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const override {
354 auto loc = whileOp.getLoc();
355 auto loopControl = spirv::LoopControl::None;
356 if (
auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
358 loopControl = attr.getValue();
359 auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
360 loopOp.addEntryAndMergeBlock(rewriter);
362 Region &beforeRegion = whileOp.getBefore();
363 Region &afterRegion = whileOp.getAfter();
365 if (
failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
366 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
367 return rewriter.notifyMatchFailure(whileOp,
368 "Failed to convert region types");
370 OpBuilder::InsertionGuard guard(rewriter);
372 Block &entryBlock = *loopOp.getEntryBlock();
375 Block &mergeBlock = *loopOp.getMergeBlock();
377 auto cond = cast<scf::ConditionOp>(beforeBlock.
getTerminator());
378 SmallVector<Value> condArgs;
379 if (
failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
382 Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
387 SmallVector<Value> yieldArgs;
388 if (
failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
392 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
393 getBlockIt(loopOp.getBody(), 1));
396 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
397 getBlockIt(loopOp.getBody(), 2));
400 rewriter.setInsertionPointToEnd(&entryBlock);
401 spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
403 auto condLoc = cond.getLoc();
405 SmallVector<Value> resultValues(condArgs.size());
414 for (
const auto &it : llvm::enumerate(condArgs)) {
415 auto res = it.value();
421 rewriter.setInsertionPoint(loopOp);
422 auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
423 spirv::StorageClass::Function,
427 rewriter.setInsertionPointAfter(loopOp);
428 auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
429 resultValues[i] = loadResult;
432 rewriter.setInsertionPointToEnd(&beforeBlock);
433 spirv::StoreOp::create(rewriter, condLoc, alloc, res);
436 rewriter.setInsertionPointToEnd(&beforeBlock);
437 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
438 cond, conditionVal, &afterBlock, condArgs, &mergeBlock,
ValueRange());
441 rewriter.setInsertionPointToEnd(&afterBlock);
442 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
445 rewriter.replaceOp(whileOp, resultValues);
458 patterns.
add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
459 WhileOpConversion>(patterns.
getContext(), typeConverter,
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OperationName getName()
The name of an operation is the key identifier for it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
StringRef getLoopControlAttrName()
Returns the attribute name for specifying loop control.
Include the generated interface declarations.
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...
ScfToSPIRVContextImpl * getImpl()