18 #include "llvm/Support/FormatVariadic.h"
42 impl = std::make_unique<::ScfToSPIRVContextImpl>();
58 template <
typename ScfOp,
typename OpTy>
59 void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
65 auto &allocas = scfToSPIRVContext->
outputVars[newOp];
70 for (
Type convertedType : returnTypes) {
74 auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
75 spirv::StorageClass::Function,
77 allocas.push_back(alloc);
79 Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
80 resultValue.push_back(loadResult);
86 return std::next(region.
begin(), index);
94 template <
typename OpTy>
100 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
126 struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
127 using SCFToSPIRVPattern::SCFToSPIRVPattern;
130 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
137 auto loc = forOp.getLoc();
140 loopOp.addEntryAndMergeBlock(rewriter);
145 getBlockIt(loopOp.getBody(), 1));
149 Value adapLowerBound = adaptor.getLowerBound();
152 for (
Value arg : adaptor.getInitArgs())
154 Block *body = forOp.getBody();
161 signatureConverter.remapInput(0, newIndVar);
163 signatureConverter.remapInput(i, header->
getArgument(i));
170 getBlockIt(loopOp.getBody(), 2));
173 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
176 spirv::BranchOp::create(rewriter, loc, header, args);
180 auto *mergeBlock = loopOp.getMergeBlock();
182 if (forOp.getUnsignedCmp()) {
183 cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.
getI1Type(),
184 newIndVar, adaptor.getUpperBound());
186 cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.
getI1Type(),
187 newIndVar, adaptor.getUpperBound());
190 spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
196 Block *continueBlock = loopOp.getContinueBlock();
200 Value updatedIndVar = spirv::IAddOp::create(
201 rewriter, loc, newIndVar.
getType(), newIndVar, adaptor.getStep());
202 spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
209 for (
auto arg : adaptor.getInitArgs())
210 initTypes.push_back(arg.getType());
211 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
223 struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
224 using SCFToSPIRVPattern::SCFToSPIRVPattern;
227 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
232 auto loc = ifOp.getLoc();
236 for (
auto result : ifOp.getResults()) {
237 auto convertedType = typeConverter.convertType(result.getType());
241 llvm::formatv(
"failed to convert type '{0}'", result.getType()));
243 returnTypes.push_back(convertedType);
248 auto selectionOp = spirv::SelectionOp::create(
250 auto *mergeBlock = rewriter.
createBlock(&selectionOp.getBody(),
251 selectionOp.getBody().end());
252 spirv::MergeOp::create(rewriter, loc);
255 auto *selectionHeaderBlock =
256 rewriter.
createBlock(&selectionOp.getBody().front());
259 auto &thenRegion = ifOp.getThenRegion();
260 auto *thenBlock = &thenRegion.
front();
262 spirv::BranchOp::create(rewriter, loc, mergeBlock);
265 auto *elseBlock = mergeBlock;
268 if (!ifOp.getElseRegion().empty()) {
269 auto &elseRegion = ifOp.getElseRegion();
270 elseBlock = &elseRegion.front();
272 spirv::BranchOp::create(rewriter, loc, mergeBlock);
278 spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
282 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
292 struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
294 using SCFToSPIRVPattern::SCFToSPIRVPattern;
297 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
305 scf::SCFDialect::getDialectNamespace() &&
306 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
309 llvm::formatv(
"conversion not supported for parent op: '{0}'",
314 if (!operands.empty()) {
315 auto &allocas = scfToSPIRVContext->
outputVars[parent];
316 if (allocas.size() != operands.size())
319 auto loc = terminatorOp.getLoc();
320 for (
unsigned i = 0, e = operands.size(); i < e; i++)
321 spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
322 if (isa<spirv::LoopOp>(parent)) {
325 auto br = cast<spirv::BranchOp>(
328 args.append(operands.begin(), operands.end());
330 spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
335 rewriter.
eraseOp(terminatorOp);
344 struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
345 using SCFToSPIRVPattern::SCFToSPIRVPattern;
348 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
350 auto loc = whileOp.getLoc();
353 loopOp.addEntryAndMergeBlock(rewriter);
355 Region &beforeRegion = whileOp.getBefore();
356 Region &afterRegion = whileOp.getAfter();
361 "Failed to convert region types");
365 Block &entryBlock = *loopOp.getEntryBlock();
368 Block &mergeBlock = *loopOp.getMergeBlock();
370 auto cond = cast<scf::ConditionOp>(beforeBlock.
getTerminator());
386 getBlockIt(loopOp.getBody(), 1));
390 getBlockIt(loopOp.getBody(), 2));
394 spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
396 auto condLoc = cond.getLoc();
408 auto res = it.value();
415 auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
416 spirv::StorageClass::Function,
421 auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
422 resultValues[i] = loadResult;
426 spirv::StoreOp::create(rewriter, condLoc, alloc, res);
431 cond, conditionVal, &afterBlock, condArgs, &mergeBlock,
ValueRange());
438 rewriter.
replaceOp(whileOp, resultValues);
451 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
452 WhileOpConversion>(
patterns.getContext(), typeConverter,
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContextImpl * getImpl()
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...