MLIR 23.0.0git
SCFToSPIRV.cpp
Go to the documentation of this file.
1//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SCF dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
19#include "llvm/Support/FormatVariadic.h"
20
21using namespace mlir;
22
23//===----------------------------------------------------------------------===//
24// Context
25//===----------------------------------------------------------------------===//
26
27namespace mlir {
29 // Map between the spirv region control flow operation (spirv.mlir.loop or
30 // spirv.mlir.selection) to the VariableOp created to store the region
31 // results. The order of the VariableOp matches the order of the results.
33};
34} // namespace mlir
35
/// We use ScfToSPIRVContext to store information about the lowering of scf
/// regions that is needed later on. When we lower scf.for/scf.if, we create
/// VariableOps to store the results. We need to keep track of the created
/// VariableOps because we insert stores into them when lowering scf.yield.
/// Those StoreOps cannot be created earlier, as they may use a different type
/// than the yield operands.
43 impl = std::make_unique<::ScfToSPIRVContextImpl>();
44}
45
47
48namespace {
49
50//===----------------------------------------------------------------------===//
51// Helper Functions
52//===----------------------------------------------------------------------===//
53
54/// Replaces SCF op outputs with SPIR-V variable loads.
55/// We create VariableOp to handle the results value of the control flow region.
56/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
57/// after the loop we load the value from the allocation and use it as the SCF
58/// op result.
59template <typename ScfOp, typename OpTy>
60void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
61 ConversionPatternRewriter &rewriter,
62 ScfToSPIRVContextImpl *scfToSPIRVContext,
63 ArrayRef<Type> returnTypes) {
64
65 Location loc = scfOp.getLoc();
66 auto &allocas = scfToSPIRVContext->outputVars[newOp];
67 // Clearing the allocas is necessary in case a dialect conversion path failed
68 // previously, and this is the second attempt of this conversion.
69 allocas.clear();
70 SmallVector<Value, 8> resultValue;
71 for (Type convertedType : returnTypes) {
72 auto pointerType =
73 spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
74 rewriter.setInsertionPoint(newOp);
75 auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
76 spirv::StorageClass::Function,
77 /*initializer=*/nullptr);
78 allocas.push_back(alloc);
79 rewriter.setInsertionPointAfter(newOp);
80 Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
81 resultValue.push_back(loadResult);
82 }
83 rewriter.replaceOp(scfOp, resultValue);
84}
85
86Region::iterator getBlockIt(Region &region, unsigned index) {
87 return std::next(region.begin(), index);
88}
89
90//===----------------------------------------------------------------------===//
91// Conversion Patterns
92//===----------------------------------------------------------------------===//
93
94/// Common class for all vector to GPU patterns.
95template <typename OpTy>
96class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
97public:
98 SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
99 ScfToSPIRVContextImpl *scfToSPIRVContext)
100 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
101 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
102
103protected:
104 ScfToSPIRVContextImpl *scfToSPIRVContext;
105 // FIXME: We explicitly keep a reference of the type converter here instead of
106 // passing it to OpConversionPattern during construction. This effectively
107 // bypasses the conversion framework's automation on type conversion. This is
108 // needed right now because the conversion framework will unconditionally
109 // legalize all types used by SCF ops upon discovering them, for example, the
110 // types of loop carried values. We use SPIR-V variables for those loop
111 // carried values. Depending on the available capabilities, the SPIR-V
112 // variable can be different, for example, cooperative matrix or normal
113 // variable. We'd like to detach the conversion of the loop carried values
114 // from the SCF ops (which is mainly a region). So we need to "mark" types
115 // used by SCF ops as legal, if to use the conversion framework for type
116 // conversion. There isn't a straightforward way to do that yet, as when
117 // converting types, ops aren't taken into consideration. Therefore, we just
118 // bypass the framework's type conversion for now.
119 const SPIRVTypeConverter &typeConverter;
120};
121
122//===----------------------------------------------------------------------===//
123// scf::ForOp
124//===----------------------------------------------------------------------===//
125
126/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
/// Lowering outline (as implemented below): the new spirv.mlir.loop gets a
/// header block holding the induction variable plus one block argument per
/// iter_arg; the scf.for body is inlined after it; the entry block branches to
/// the header with the lower bound and init values; and the loop results are
/// returned through Function-storage variables via replaceSCFOutputValue.
127struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
128 using SCFToSPIRVPattern::SCFToSPIRVPattern;
129
130 LogicalResult
131 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter) const override {
133 // scf::ForOp can be lowered to the structured control flow represented by
134 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
135 // latch and the merge block the exit block. The resulting spirv::LoopOp has
136 // a single back edge from the continue to header block, and a single exit
137 // from header to merge.
138 auto loc = forOp.getLoc();
139 auto loopControl = spirv::LoopControl::None;
// Forward an optional spirv.LoopControl attribute from the scf.for to the new
// spirv.mlir.loop; without one, LoopControl::None is used.
140 if (auto attr = forOp->getAttrOfType<spirv::LoopControlAttr>(
142 loopControl = attr.getValue();
143 auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
144 loopOp.addEntryAndMergeBlock(rewriter);
145
146 OpBuilder::InsertionGuard guard(rewriter);
147 // Create the block for the header.
148 Block *header = rewriter.createBlock(&loopOp.getBody(),
149 getBlockIt(loopOp.getBody(), 1));
150 rewriter.setInsertionPointAfter(loopOp);
151
152 // Create the new induction variable to use.
153 Value adapLowerBound = adaptor.getLowerBound();
154 BlockArgument newIndVar =
155 header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
156 for (Value arg : adaptor.getInitArgs())
157 header->addArgument(arg.getType(), arg.getLoc());
158 Block *body = forOp.getBody();
159
160 // Apply signature conversion to the body of the forOp. It has a single
161 // block, with argument which is the induction variable. That has to be
162 // replaced with the new induction variable.
163 TypeConverter::SignatureConversion signatureConverter(
164 body->getNumArguments());
165 signatureConverter.remapInput(0, newIndVar);
// Header argument 0 is the new induction variable; for i >= 1, header
// argument i mirrors body block argument i (the iter_args, in order).
166 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
167 signatureConverter.remapInput(i, header->getArgument(i));
168 body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
169 signatureConverter);
170
171 // Move the blocks from the forOp into the loopOp. This is the body of the
172 // loopOp.
173 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
174 getBlockIt(loopOp.getBody(), 2));
175
// Initial header arguments: the lower bound followed by the init values.
176 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
177 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
178 // Branch into it from the entry.
179 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
180 spirv::BranchOp::create(rewriter, loc, header, args);
181
182 // Generate the rest of the loop header.
183 rewriter.setInsertionPointToEnd(header);
184 auto *mergeBlock = loopOp.getMergeBlock();
// The bound check uses an unsigned or a signed comparison according to the
// scf.for's getUnsignedCmp() flag.
185 Value cmpOp;
186 if (forOp.getUnsignedCmp()) {
187 cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
188 newIndVar, adaptor.getUpperBound());
189 } else {
190 cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
191 newIndVar, adaptor.getUpperBound());
192 }
193
// Enter the body while IV < upper bound; otherwise leave via the merge block.
194 spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
195 ArrayRef<Value>(), mergeBlock,
196 ArrayRef<Value>());
197
198 // Generate instructions to increment the step of the induction variable and
199 // branch to the header.
// NOTE(review): getContinueBlock() is presumably the block right before the
// merge block, i.e. the inlined scf.for body, which still ends with scf.yield.
// The back-edge branch created below carries only the incremented IV;
// TerminatorOpConversion later rebuilds the terminator of the scf.yield's
// block with the yielded values appended. Confirm this pairing if the body
// layout changes.
200 Block *continueBlock = loopOp.getContinueBlock();
201 rewriter.setInsertionPointToEnd(continueBlock);
202
203 // Add the step to the induction variable and branch to the header.
204 Value updatedIndVar = spirv::IAddOp::create(
205 rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep());
206 spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
207
208 // Infer the return types from the init operands. Vector type may get
209 // converted to CooperativeMatrix or to Vector type, to avoid having complex
210 // extra logic to figure out the right type we just infer it from the Init
211 // operands.
212 SmallVector<Type, 8> initTypes;
213 for (auto arg : adaptor.getInitArgs())
214 initTypes.push_back(arg.getType());
215 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
216 initTypes);
217 return success();
218 }
219};
220
221//===----------------------------------------------------------------------===//
222// scf::IfOp
223//===----------------------------------------------------------------------===//
224
225/// Pattern to convert a scf::IfOp within kernel functions into
226/// spirv::SelectionOp.
/// Lowering outline (as implemented below): a spirv.mlir.selection is created
/// with a header block ending in spirv.BranchConditional and a merge block
/// ending in spirv.mlir.merge. The `then`/`else` regions are inlined between
/// them, each ending with a branch to the merge block. Results are returned
/// through Function-storage variables via replaceSCFOutputValue.
/// NOTE(review): unlike the other patterns in this file, this struct is not
/// marked `final`.
227struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
228 using SCFToSPIRVPattern::SCFToSPIRVPattern;
229
230 LogicalResult
231 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter) const override {
233 // When lowering `scf::IfOp` we explicitly create a selection header block
234 // before the control flow diverges and a merge block where control flow
235 // subsequently converges.
236 auto loc = ifOp.getLoc();
237
238 // Compute return types.
// Each result type goes through the explicit SPIR-V converter; a type it
// cannot convert fails the match before any IR is created.
239 SmallVector<Type, 8> returnTypes;
240 for (auto result : ifOp.getResults()) {
241 auto convertedType = typeConverter.convertType(result.getType());
242 if (!convertedType)
243 return rewriter.notifyMatchFailure(
244 loc,
245 llvm::formatv("failed to convert type '{0}'", result.getType()));
246
247 returnTypes.push_back(convertedType);
248 }
249
250 // Create `spirv.selection` operation, selection header block and merge
251 // block.
// An optional spirv.SelectionControl attribute on the scf.if is forwarded;
// without one, SelectionControl::None is used.
252 auto selectionControl = spirv::SelectionControl::None;
253 if (auto attr = ifOp->getAttrOfType<spirv::SelectionControlAttr>(
255 selectionControl = attr.getValue();
256 auto selectionOp =
257 spirv::SelectionOp::create(rewriter, loc, selectionControl);
// The merge block is appended at the end of the selection body and is
// terminated by spirv.mlir.merge.
258 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
259 selectionOp.getBody().end());
260 spirv::MergeOp::create(rewriter, loc);
261
262 OpBuilder::InsertionGuard guard(rewriter);
// The header block is inserted before the selection body's current first
// block.
263 auto *selectionHeaderBlock =
264 rewriter.createBlock(&selectionOp.getBody().front());
265
266 // Inline `then` region before the merge block and branch to it.
// The branch is appended after the region's existing scf.yield terminator;
// TerminatorOpConversion later emits the result stores and erases that
// scf.yield.
267 auto &thenRegion = ifOp.getThenRegion();
268 auto *thenBlock = &thenRegion.front();
269 rewriter.setInsertionPointToEnd(&thenRegion.back());
270 spirv::BranchOp::create(rewriter, loc, mergeBlock);
271 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
272
// Without an `else` region, the false edge goes straight to the merge block.
273 auto *elseBlock = mergeBlock;
274 // If `else` region is not empty, inline that region before the merge block
275 // and branch to it.
276 if (!ifOp.getElseRegion().empty()) {
277 auto &elseRegion = ifOp.getElseRegion();
278 elseBlock = &elseRegion.front();
279 rewriter.setInsertionPointToEnd(&elseRegion.back());
280 spirv::BranchOp::create(rewriter, loc, mergeBlock);
281 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
282 }
283
284 // Create a `spirv.BranchConditional` operation for selection header block.
285 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
286 spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
287 thenBlock, ArrayRef<Value>(), elseBlock,
288 ArrayRef<Value>());
289
// Results come back through variables created around the selection op.
290 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
291 returnTypes);
292 return success();
293 }
294};
295
296//===----------------------------------------------------------------------===//
297// scf::YieldOp
298//===----------------------------------------------------------------------===//
299
300struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
301public:
302 using SCFToSPIRVPattern::SCFToSPIRVPattern;
303
304 LogicalResult
305 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter) const override {
307 ValueRange operands = adaptor.getOperands();
308
309 Operation *parent = terminatorOp->getParentOp();
310
311 // TODO: Implement conversion for the remaining `scf` ops.
312 if (parent->getDialect()->getNamespace() ==
313 scf::SCFDialect::getDialectNamespace() &&
314 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
315 return rewriter.notifyMatchFailure(
316 terminatorOp,
317 llvm::formatv("conversion not supported for parent op: '{0}'",
318 parent->getName()));
319
320 // If the region return values, store each value into the associated
321 // VariableOp created during lowering of the parent region.
322 if (!operands.empty()) {
323 auto &allocas = scfToSPIRVContext->outputVars[parent];
324 if (allocas.size() != operands.size())
325 return failure();
326
327 auto loc = terminatorOp.getLoc();
328 for (unsigned i = 0, e = operands.size(); i < e; i++)
329 spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
330 if (isa<spirv::LoopOp>(parent)) {
331 // For loops we also need to update the branch jumping back to the
332 // header.
333 auto br = cast<spirv::BranchOp>(
334 rewriter.getInsertionBlock()->getTerminator());
335 SmallVector<Value, 8> args(br.getBlockArguments());
336 args.append(operands.begin(), operands.end());
337 rewriter.setInsertionPoint(br);
338 spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
339 args);
340 rewriter.eraseOp(br);
341 }
342 }
343 rewriter.eraseOp(terminatorOp);
344 return success();
345 }
346};
347
348//===----------------------------------------------------------------------===//
349// scf::WhileOp
350//===----------------------------------------------------------------------===//
351
/// Lowering outline (as implemented below): the scf.while `before` region
/// becomes the loop header and the `after` region the loop body of a new
/// spirv.mlir.loop. scf.condition becomes a spirv.BranchConditional (to the
/// body or to the merge block), and scf.yield becomes a back-edge spirv.Branch
/// to the header. The op results are carried through Function-storage
/// variables that the header stores into before every conditional branch.
352struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
353 using SCFToSPIRVPattern::SCFToSPIRVPattern;
354
355 LogicalResult
356 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
357 ConversionPatternRewriter &rewriter) const override {
358 auto loc = whileOp.getLoc();
// An optional spirv.LoopControl attribute on the scf.while is forwarded;
// without one, LoopControl::None is used.
359 auto loopControl = spirv::LoopControl::None;
360 if (auto attr = whileOp->getAttrOfType<spirv::LoopControlAttr>(
362 loopControl = attr.getValue();
363 auto loopOp = spirv::LoopOp::create(rewriter, loc, loopControl);
364 loopOp.addEntryAndMergeBlock(rewriter);
365
366 Region &beforeRegion = whileOp.getBefore();
367 Region &afterRegion = whileOp.getAfter();
368
// Convert the block-argument types of both regions with the explicit SPIR-V
// converter before the regions are moved.
369 if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
370 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
371 return rewriter.notifyMatchFailure(whileOp,
372 "Failed to convert region types");
373
374 OpBuilder::InsertionGuard guard(rewriter);
375
376 Block &entryBlock = *loopOp.getEntryBlock();
377 Block &beforeBlock = beforeRegion.front();
378 Block &afterBlock = afterRegion.front();
379 Block &mergeBlock = *loopOp.getMergeBlock();
380
// Collect the converted (remapped) operands of both terminators up front;
// the match fails if any of them is unavailable.
381 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
382 SmallVector<Value> condArgs;
383 if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
384 return failure();
385
386 Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
387 if (!conditionVal)
388 return failure();
389
390 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
391 SmallVector<Value> yieldArgs;
392 if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
393 return failure();
394
395 // Move the while before block as the initial loop header block.
396 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
397 getBlockIt(loopOp.getBody(), 1));
398
399 // Move the while after block as the initial loop body block.
400 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
401 getBlockIt(loopOp.getBody(), 2));
402
403 // Jump from the loop entry block to the loop header block.
404 rewriter.setInsertionPointToEnd(&entryBlock);
405 spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
406
407 auto condLoc = cond.getLoc();
408
// Unlike the scf.for / scf.if lowerings, the result variables here are created
// locally and are not recorded in scfToSPIRVContext->outputVars.
409 SmallVector<Value> resultValues(condArgs.size());
410
411 // For other SCF ops, the scf.yield op yields the value for the whole SCF
412 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
413 // local variables. But for the scf.while op, the scf.yield op yields a
414 // value for the before region, which may not match the whole op's
415 // result. Instead, the scf.condition op returns values matching the whole
416 // op's results. So we need to create/load/store variables according to
417 // that.
418 for (const auto &it : llvm::enumerate(condArgs)) {
419 auto res = it.value();
420 auto i = it.index();
421 auto pointerType =
422 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
423
424 // Create local variables before the scf.while op.
425 rewriter.setInsertionPoint(loopOp);
426 auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
427 spirv::StorageClass::Function,
428 /*initializer=*/nullptr);
429
430 // Load the final result values after the scf.while op.
431 rewriter.setInsertionPointAfter(loopOp);
432 auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
433 resultValues[i] = loadResult;
434
435 // Store the current iteration's result value.
436 rewriter.setInsertionPointToEnd(&beforeBlock);
437 spirv::StoreOp::create(rewriter, condLoc, alloc, res);
438 }
439
// The stores above were appended to the end of the before block; the
// conditional branch is created there as well (after them) and replaces
// scf.condition.
440 rewriter.setInsertionPointToEnd(&beforeBlock);
441 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
442 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, ValueRange());
443
444 // Convert the scf.yield op to a branch back to the header block.
445 rewriter.setInsertionPointToEnd(&afterBlock);
446 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
447 yieldArgs);
448
449 rewriter.replaceOp(whileOp, resultValues);
450 return success();
451 }
452};
453} // namespace
454
455//===----------------------------------------------------------------------===//
456// Public API
457//===----------------------------------------------------------------------===//
458
460 ScfToSPIRVContext &scfToSPIRVContext,
461 RewritePatternSet &patterns) {
462 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
463 WhileOpConversion>(patterns.getContext(), typeConverter,
464 scfToSPIRVContext.getImpl());
465}
return success()
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
StringRef getNamespace() const
Definition Dialect.h:54
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
BlockListType::iterator iterator
Definition Region.h:52
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
static PointerType get(Type pointeeType, StorageClass storageClass)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
StringRef getLoopControlAttrName()
Returns the attribute name for specifying loop control.
StringRef getSelectionControlAttrName()
Returns the attribute name for specifying selection control.
Include the generated interface declarations.
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...
ScfToSPIRVContextImpl * getImpl()
Definition SCFToSPIRV.h:29