20 #include "llvm/Support/FormatVariadic.h"
44 impl = std::make_unique<::ScfToSPIRVContextImpl>();
60 template <
typename ScfOp,
typename OpTy>
61 void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
67 auto &allocas = scfToSPIRVContext->
outputVars[newOp];
72 for (
Type convertedType : returnTypes) {
76 auto alloc = rewriter.
create<spirv::VariableOp>(
77 loc, pointerType, spirv::StorageClass::Function,
79 allocas.push_back(alloc);
81 Value loadResult = rewriter.
create<spirv::LoadOp>(loc, alloc);
82 resultValue.push_back(loadResult);
88 return std::next(region.
begin(), index);
96 template <
typename OpTy>
102 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
128 struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
129 using SCFToSPIRVPattern::SCFToSPIRVPattern;
132 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
139 auto loc = forOp.getLoc();
141 loopOp.addEntryAndMergeBlock(rewriter);
146 getBlockIt(loopOp.getBody(), 1));
150 Value adapLowerBound = adaptor.getLowerBound();
153 for (
Value arg : adaptor.getInitArgs())
155 Block *body = forOp.getBody();
162 signatureConverter.remapInput(0, newIndVar);
164 signatureConverter.remapInput(i, header->
getArgument(i));
171 getBlockIt(loopOp.getBody(), 2));
174 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
177 rewriter.
create<spirv::BranchOp>(loc, header, args);
181 auto *mergeBlock = loopOp.getMergeBlock();
182 auto cmpOp = rewriter.
create<spirv::SLessThanOp>(
183 loc, rewriter.
getI1Type(), newIndVar, adaptor.getUpperBound());
185 rewriter.
create<spirv::BranchConditionalOp>(
190 Block *continueBlock = loopOp.getContinueBlock();
194 Value updatedIndVar = rewriter.
create<spirv::IAddOp>(
195 loc, newIndVar.
getType(), newIndVar, adaptor.getStep());
196 rewriter.
create<spirv::BranchOp>(loc, header, updatedIndVar);
203 for (
auto arg : adaptor.getInitArgs())
204 initTypes.push_back(arg.getType());
205 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
217 struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
218 using SCFToSPIRVPattern::SCFToSPIRVPattern;
221 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
226 auto loc = ifOp.getLoc();
232 auto *mergeBlock = rewriter.
createBlock(&selectionOp.getBody(),
233 selectionOp.getBody().end());
234 rewriter.
create<spirv::MergeOp>(loc);
237 auto *selectionHeaderBlock =
238 rewriter.
createBlock(&selectionOp.getBody().front());
241 auto &thenRegion = ifOp.getThenRegion();
242 auto *thenBlock = &thenRegion.
front();
244 rewriter.
create<spirv::BranchOp>(loc, mergeBlock);
247 auto *elseBlock = mergeBlock;
250 if (!ifOp.getElseRegion().empty()) {
251 auto &elseRegion = ifOp.getElseRegion();
252 elseBlock = &elseRegion.front();
254 rewriter.
create<spirv::BranchOp>(loc, mergeBlock);
260 rewriter.
create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
265 for (
auto result : ifOp.getResults()) {
266 auto convertedType = typeConverter.convertType(result.getType());
270 llvm::formatv(
"failed to convert type '{0}'", result.getType()));
272 returnTypes.push_back(convertedType);
274 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
284 struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
286 using SCFToSPIRVPattern::SCFToSPIRVPattern;
289 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
297 scf::SCFDialect::getDialectNamespace() &&
298 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
301 llvm::formatv(
"conversion not supported for parent op: '{0}'",
306 if (!operands.empty()) {
307 auto &allocas = scfToSPIRVContext->
outputVars[parent];
308 if (allocas.size() != operands.size())
311 auto loc = terminatorOp.getLoc();
312 for (
unsigned i = 0, e = operands.size(); i < e; i++)
313 rewriter.
create<spirv::StoreOp>(loc, allocas[i], operands[i]);
314 if (isa<spirv::LoopOp>(parent)) {
317 auto br = cast<spirv::BranchOp>(
320 args.append(operands.begin(), operands.end());
322 rewriter.
create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
327 rewriter.
eraseOp(terminatorOp);
336 struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
337 using SCFToSPIRVPattern::SCFToSPIRVPattern;
340 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
342 auto loc = whileOp.getLoc();
344 loopOp.addEntryAndMergeBlock(rewriter);
346 Region &beforeRegion = whileOp.getBefore();
347 Region &afterRegion = whileOp.getAfter();
352 "Failed to convert region types");
356 Block &entryBlock = *loopOp.getEntryBlock();
359 Block &mergeBlock = *loopOp.getMergeBlock();
361 auto cond = cast<scf::ConditionOp>(beforeBlock.
getTerminator());
377 getBlockIt(loopOp.getBody(), 1));
381 getBlockIt(loopOp.getBody(), 2));
385 rewriter.
create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
387 auto condLoc = cond.
getLoc();
399 auto res = it.value();
406 auto alloc = rewriter.
create<spirv::VariableOp>(
407 condLoc, pointerType, spirv::StorageClass::Function,
412 auto loadResult = rewriter.
create<spirv::LoadOp>(condLoc, alloc);
413 resultValues[i] = loadResult;
417 rewriter.
create<spirv::StoreOp>(condLoc, alloc, res);
422 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
429 rewriter.
replaceOp(whileOp, resultValues);
442 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
443 WhileOpConversion>(
patterns.getContext(), typeConverter,
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContextImpl * getImpl()
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...