20 #include "llvm/Support/FormatVariadic.h"
44 impl = std::make_unique<::ScfToSPIRVContextImpl>();
60 template <
typename ScfOp,
typename OpTy>
61 void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
67 auto &allocas = scfToSPIRVContext->
outputVars[newOp];
72 for (
Type convertedType : returnTypes) {
76 auto alloc = rewriter.
create<spirv::VariableOp>(
77 loc, pointerType, spirv::StorageClass::Function,
79 allocas.push_back(alloc);
81 Value loadResult = rewriter.
create<spirv::LoadOp>(loc, alloc);
82 resultValue.push_back(loadResult);
88 return std::next(region.
begin(), index);
96 template <
typename OpTy>
102 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
128 struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
129 using SCFToSPIRVPattern::SCFToSPIRVPattern;
132 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
139 auto loc = forOp.getLoc();
141 loopOp.addEntryAndMergeBlock(rewriter);
146 getBlockIt(loopOp.getBody(), 1));
150 Value adapLowerBound = adaptor.getLowerBound();
153 for (
Value arg : adaptor.getInitArgs())
155 Block *body = forOp.getBody();
162 signatureConverter.remapInput(0, newIndVar);
164 signatureConverter.remapInput(i, header->
getArgument(i));
171 getBlockIt(loopOp.getBody(), 2));
174 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
177 rewriter.
create<spirv::BranchOp>(loc, header, args);
181 auto *mergeBlock = loopOp.getMergeBlock();
182 auto cmpOp = rewriter.
create<spirv::SLessThanOp>(
183 loc, rewriter.
getI1Type(), newIndVar, adaptor.getUpperBound());
185 rewriter.
create<spirv::BranchConditionalOp>(
190 Block *continueBlock = loopOp.getContinueBlock();
194 Value updatedIndVar = rewriter.
create<spirv::IAddOp>(
195 loc, newIndVar.
getType(), newIndVar, adaptor.getStep());
196 rewriter.
create<spirv::BranchOp>(loc, header, updatedIndVar);
203 for (
auto arg : adaptor.getInitArgs())
204 initTypes.push_back(arg.getType());
205 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
217 struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
218 using SCFToSPIRVPattern::SCFToSPIRVPattern;
221 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
226 auto loc = ifOp.getLoc();
230 for (
auto result : ifOp.getResults()) {
231 auto convertedType = typeConverter.convertType(result.getType());
235 llvm::formatv(
"failed to convert type '{0}'", result.getType()));
237 returnTypes.push_back(convertedType);
244 auto *mergeBlock = rewriter.
createBlock(&selectionOp.getBody(),
245 selectionOp.getBody().end());
246 rewriter.
create<spirv::MergeOp>(loc);
249 auto *selectionHeaderBlock =
250 rewriter.
createBlock(&selectionOp.getBody().front());
253 auto &thenRegion = ifOp.getThenRegion();
254 auto *thenBlock = &thenRegion.
front();
256 rewriter.
create<spirv::BranchOp>(loc, mergeBlock);
259 auto *elseBlock = mergeBlock;
262 if (!ifOp.getElseRegion().empty()) {
263 auto &elseRegion = ifOp.getElseRegion();
264 elseBlock = &elseRegion.front();
266 rewriter.
create<spirv::BranchOp>(loc, mergeBlock);
272 rewriter.
create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
276 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
286 struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
288 using SCFToSPIRVPattern::SCFToSPIRVPattern;
291 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
299 scf::SCFDialect::getDialectNamespace() &&
300 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
303 llvm::formatv(
"conversion not supported for parent op: '{0}'",
308 if (!operands.empty()) {
309 auto &allocas = scfToSPIRVContext->
outputVars[parent];
310 if (allocas.size() != operands.size())
313 auto loc = terminatorOp.getLoc();
314 for (
unsigned i = 0, e = operands.size(); i < e; i++)
315 rewriter.
create<spirv::StoreOp>(loc, allocas[i], operands[i]);
316 if (isa<spirv::LoopOp>(parent)) {
319 auto br = cast<spirv::BranchOp>(
322 args.append(operands.begin(), operands.end());
324 rewriter.
create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
329 rewriter.
eraseOp(terminatorOp);
338 struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
339 using SCFToSPIRVPattern::SCFToSPIRVPattern;
342 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
344 auto loc = whileOp.getLoc();
346 loopOp.addEntryAndMergeBlock(rewriter);
348 Region &beforeRegion = whileOp.getBefore();
349 Region &afterRegion = whileOp.getAfter();
354 "Failed to convert region types");
358 Block &entryBlock = *loopOp.getEntryBlock();
361 Block &mergeBlock = *loopOp.getMergeBlock();
363 auto cond = cast<scf::ConditionOp>(beforeBlock.
getTerminator());
379 getBlockIt(loopOp.getBody(), 1));
383 getBlockIt(loopOp.getBody(), 2));
387 rewriter.
create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
389 auto condLoc = cond.
getLoc();
401 auto res = it.value();
408 auto alloc = rewriter.
create<spirv::VariableOp>(
409 condLoc, pointerType, spirv::StorageClass::Function,
414 auto loadResult = rewriter.
create<spirv::LoadOp>(condLoc, alloc);
415 resultValues[i] = loadResult;
419 rewriter.
create<spirv::StoreOp>(condLoc, alloc, res);
424 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
431 rewriter.
replaceOp(whileOp, resultValues);
444 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
445 WhileOpConversion>(
patterns.getContext(), typeConverter,
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContextImpl * getImpl()
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...