MLIR 22.0.0git
SCFToSPIRV.cpp
Go to the documentation of this file.
1//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SCF dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// Context
24//===----------------------------------------------------------------------===//
25
26namespace mlir {
28 // Map between the spirv region control flow operation (spirv.mlir.loop or
29 // spirv.mlir.selection) to the VariableOp created to store the region
30 // results. The order of the VariableOp matches the order of the results.
32};
33} // namespace mlir
34
35/// We use ScfToSPIRVContext to store information about the lowering of the scf
36/// region that need to be used later on. When we lower scf.for/scf.if we create
37/// VariableOp to store the results. We need to keep track of the VariableOp
38/// created as we need to insert stores into them when lowering Yield. Those
39/// StoreOp cannot be created earlier as they may use a different type than
40/// yield operands.
42 impl = std::make_unique<::ScfToSPIRVContextImpl>();
43}
44
46
47namespace {
48
49//===----------------------------------------------------------------------===//
50// Helper Functions
51//===----------------------------------------------------------------------===//
52
53/// Replaces SCF op outputs with SPIR-V variable loads.
54/// We create VariableOp to handle the results value of the control flow region.
55/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
56/// after the loop we load the value from the allocation and use it as the SCF
57/// op result.
58template <typename ScfOp, typename OpTy>
59void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
60 ConversionPatternRewriter &rewriter,
61 ScfToSPIRVContextImpl *scfToSPIRVContext,
62 ArrayRef<Type> returnTypes) {
63
64 Location loc = scfOp.getLoc();
65 auto &allocas = scfToSPIRVContext->outputVars[newOp];
66 // Clearing the allocas is necessary in case a dialect conversion path failed
67 // previously, and this is the second attempt of this conversion.
68 allocas.clear();
69 SmallVector<Value, 8> resultValue;
70 for (Type convertedType : returnTypes) {
71 auto pointerType =
72 spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
73 rewriter.setInsertionPoint(newOp);
74 auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
75 spirv::StorageClass::Function,
76 /*initializer=*/nullptr);
77 allocas.push_back(alloc);
78 rewriter.setInsertionPointAfter(newOp);
79 Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
80 resultValue.push_back(loadResult);
81 }
82 rewriter.replaceOp(scfOp, resultValue);
83}
84
85Region::iterator getBlockIt(Region &region, unsigned index) {
86 return std::next(region.begin(), index);
87}
88
89//===----------------------------------------------------------------------===//
90// Conversion Patterns
91//===----------------------------------------------------------------------===//
92
93/// Common class for all vector to GPU patterns.
94template <typename OpTy>
95class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
96public:
97 SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
98 ScfToSPIRVContextImpl *scfToSPIRVContext)
99 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
100 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
101
102protected:
103 ScfToSPIRVContextImpl *scfToSPIRVContext;
104 // FIXME: We explicitly keep a reference of the type converter here instead of
105 // passing it to OpConversionPattern during construction. This effectively
106 // bypasses the conversion framework's automation on type conversion. This is
107 // needed right now because the conversion framework will unconditionally
108 // legalize all types used by SCF ops upon discovering them, for example, the
109 // types of loop carried values. We use SPIR-V variables for those loop
110 // carried values. Depending on the available capabilities, the SPIR-V
111 // variable can be different, for example, cooperative matrix or normal
112 // variable. We'd like to detach the conversion of the loop carried values
113 // from the SCF ops (which is mainly a region). So we need to "mark" types
114 // used by SCF ops as legal, if to use the conversion framework for type
115 // conversion. There isn't a straightforward way to do that yet, as when
116 // converting types, ops aren't taken into consideration. Therefore, we just
117 // bypass the framework's type conversion for now.
118 const SPIRVTypeConverter &typeConverter;
119};
120
121//===----------------------------------------------------------------------===//
122// scf::ForOp
123//===----------------------------------------------------------------------===//
124
125/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
126struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
127 using SCFToSPIRVPattern::SCFToSPIRVPattern;
128
129 LogicalResult
130 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter) const override {
132 // scf::ForOp can be lowered to the structured control flow represented by
133 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
134 // latch and the merge block the exit block. The resulting spirv::LoopOp has
135 // a single back edge from the continue to header block, and a single exit
136 // from header to merge.
137 auto loc = forOp.getLoc();
138 auto loopOp =
139 spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
140 loopOp.addEntryAndMergeBlock(rewriter);
141
142 OpBuilder::InsertionGuard guard(rewriter);
143 // Create the block for the header.
144 Block *header = rewriter.createBlock(&loopOp.getBody(),
145 getBlockIt(loopOp.getBody(), 1));
146 rewriter.setInsertionPointAfter(loopOp);
147
148 // Create the new induction variable to use.
149 Value adapLowerBound = adaptor.getLowerBound();
150 BlockArgument newIndVar =
151 header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
152 for (Value arg : adaptor.getInitArgs())
153 header->addArgument(arg.getType(), arg.getLoc());
154 Block *body = forOp.getBody();
155
156 // Apply signature conversion to the body of the forOp. It has a single
157 // block, with argument which is the induction variable. That has to be
158 // replaced with the new induction variable.
159 TypeConverter::SignatureConversion signatureConverter(
160 body->getNumArguments());
161 signatureConverter.remapInput(0, newIndVar);
162 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
163 signatureConverter.remapInput(i, header->getArgument(i));
164 body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
165 signatureConverter);
166
167 // Move the blocks from the forOp into the loopOp. This is the body of the
168 // loopOp.
169 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
170 getBlockIt(loopOp.getBody(), 2));
171
172 SmallVector<Value, 8> args(1, adaptor.getLowerBound());
173 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
174 // Branch into it from the entry.
175 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
176 spirv::BranchOp::create(rewriter, loc, header, args);
177
178 // Generate the rest of the loop header.
179 rewriter.setInsertionPointToEnd(header);
180 auto *mergeBlock = loopOp.getMergeBlock();
181 Value cmpOp;
182 if (forOp.getUnsignedCmp()) {
183 cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
184 newIndVar, adaptor.getUpperBound());
185 } else {
186 cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
187 newIndVar, adaptor.getUpperBound());
188 }
189
190 spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
191 ArrayRef<Value>(), mergeBlock,
192 ArrayRef<Value>());
193
194 // Generate instructions to increment the step of the induction variable and
195 // branch to the header.
196 Block *continueBlock = loopOp.getContinueBlock();
197 rewriter.setInsertionPointToEnd(continueBlock);
198
199 // Add the step to the induction variable and branch to the header.
200 Value updatedIndVar = spirv::IAddOp::create(
201 rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep());
202 spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);
203
204 // Infer the return types from the init operands. Vector type may get
205 // converted to CooperativeMatrix or to Vector type, to avoid having complex
206 // extra logic to figure out the right type we just infer it from the Init
207 // operands.
208 SmallVector<Type, 8> initTypes;
209 for (auto arg : adaptor.getInitArgs())
210 initTypes.push_back(arg.getType());
211 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
212 initTypes);
213 return success();
214 }
215};
216
217//===----------------------------------------------------------------------===//
218// scf::IfOp
219//===----------------------------------------------------------------------===//
220
221/// Pattern to convert a scf::IfOp within kernel functions into
222/// spirv::SelectionOp.
223struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
224 using SCFToSPIRVPattern::SCFToSPIRVPattern;
225
226 LogicalResult
227 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override {
229 // When lowering `scf::IfOp` we explicitly create a selection header block
230 // before the control flow diverges and a merge block where control flow
231 // subsequently converges.
232 auto loc = ifOp.getLoc();
233
234 // Compute return types.
235 SmallVector<Type, 8> returnTypes;
236 for (auto result : ifOp.getResults()) {
237 auto convertedType = typeConverter.convertType(result.getType());
238 if (!convertedType)
239 return rewriter.notifyMatchFailure(
240 loc,
241 llvm::formatv("failed to convert type '{0}'", result.getType()));
242
243 returnTypes.push_back(convertedType);
244 }
245
246 // Create `spirv.selection` operation, selection header block and merge
247 // block.
248 auto selectionOp = spirv::SelectionOp::create(
249 rewriter, loc, spirv::SelectionControl::None);
250 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
251 selectionOp.getBody().end());
252 spirv::MergeOp::create(rewriter, loc);
253
254 OpBuilder::InsertionGuard guard(rewriter);
255 auto *selectionHeaderBlock =
256 rewriter.createBlock(&selectionOp.getBody().front());
257
258 // Inline `then` region before the merge block and branch to it.
259 auto &thenRegion = ifOp.getThenRegion();
260 auto *thenBlock = &thenRegion.front();
261 rewriter.setInsertionPointToEnd(&thenRegion.back());
262 spirv::BranchOp::create(rewriter, loc, mergeBlock);
263 rewriter.inlineRegionBefore(thenRegion, mergeBlock);
264
265 auto *elseBlock = mergeBlock;
266 // If `else` region is not empty, inline that region before the merge block
267 // and branch to it.
268 if (!ifOp.getElseRegion().empty()) {
269 auto &elseRegion = ifOp.getElseRegion();
270 elseBlock = &elseRegion.front();
271 rewriter.setInsertionPointToEnd(&elseRegion.back());
272 spirv::BranchOp::create(rewriter, loc, mergeBlock);
273 rewriter.inlineRegionBefore(elseRegion, mergeBlock);
274 }
275
276 // Create a `spirv.BranchConditional` operation for selection header block.
277 rewriter.setInsertionPointToEnd(selectionHeaderBlock);
278 spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
279 thenBlock, ArrayRef<Value>(), elseBlock,
280 ArrayRef<Value>());
281
282 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
283 returnTypes);
284 return success();
285 }
286};
287
288//===----------------------------------------------------------------------===//
289// scf::YieldOp
290//===----------------------------------------------------------------------===//
291
292struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
293public:
294 using SCFToSPIRVPattern::SCFToSPIRVPattern;
295
296 LogicalResult
297 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter) const override {
299 ValueRange operands = adaptor.getOperands();
300
301 Operation *parent = terminatorOp->getParentOp();
302
303 // TODO: Implement conversion for the remaining `scf` ops.
304 if (parent->getDialect()->getNamespace() ==
305 scf::SCFDialect::getDialectNamespace() &&
306 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
307 return rewriter.notifyMatchFailure(
308 terminatorOp,
309 llvm::formatv("conversion not supported for parent op: '{0}'",
310 parent->getName()));
311
312 // If the region return values, store each value into the associated
313 // VariableOp created during lowering of the parent region.
314 if (!operands.empty()) {
315 auto &allocas = scfToSPIRVContext->outputVars[parent];
316 if (allocas.size() != operands.size())
317 return failure();
318
319 auto loc = terminatorOp.getLoc();
320 for (unsigned i = 0, e = operands.size(); i < e; i++)
321 spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]);
322 if (isa<spirv::LoopOp>(parent)) {
323 // For loops we also need to update the branch jumping back to the
324 // header.
325 auto br = cast<spirv::BranchOp>(
326 rewriter.getInsertionBlock()->getTerminator());
327 SmallVector<Value, 8> args(br.getBlockArguments());
328 args.append(operands.begin(), operands.end());
329 rewriter.setInsertionPoint(br);
330 spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(),
331 args);
332 rewriter.eraseOp(br);
333 }
334 }
335 rewriter.eraseOp(terminatorOp);
336 return success();
337 }
338};
339
340//===----------------------------------------------------------------------===//
341// scf::WhileOp
342//===----------------------------------------------------------------------===//
343
344struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
345 using SCFToSPIRVPattern::SCFToSPIRVPattern;
346
347 LogicalResult
348 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 auto loc = whileOp.getLoc();
351 auto loopOp =
352 spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
353 loopOp.addEntryAndMergeBlock(rewriter);
354
355 Region &beforeRegion = whileOp.getBefore();
356 Region &afterRegion = whileOp.getAfter();
357
358 if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
359 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
360 return rewriter.notifyMatchFailure(whileOp,
361 "Failed to convert region types");
362
363 OpBuilder::InsertionGuard guard(rewriter);
364
365 Block &entryBlock = *loopOp.getEntryBlock();
366 Block &beforeBlock = beforeRegion.front();
367 Block &afterBlock = afterRegion.front();
368 Block &mergeBlock = *loopOp.getMergeBlock();
369
370 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
371 SmallVector<Value> condArgs;
372 if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
373 return failure();
374
375 Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
376 if (!conditionVal)
377 return failure();
378
379 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
380 SmallVector<Value> yieldArgs;
381 if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
382 return failure();
383
384 // Move the while before block as the initial loop header block.
385 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
386 getBlockIt(loopOp.getBody(), 1));
387
388 // Move the while after block as the initial loop body block.
389 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
390 getBlockIt(loopOp.getBody(), 2));
391
392 // Jump from the loop entry block to the loop header block.
393 rewriter.setInsertionPointToEnd(&entryBlock);
394 spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());
395
396 auto condLoc = cond.getLoc();
397
398 SmallVector<Value> resultValues(condArgs.size());
399
400 // For other SCF ops, the scf.yield op yields the value for the whole SCF
401 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
402 // local variables. But for the scf.while op, the scf.yield op yields a
403 // value for the before region, which may not matching the whole op's
404 // result. Instead, the scf.condition op returns values matching the whole
405 // op's results. So we need to create/load/store variables according to
406 // that.
407 for (const auto &it : llvm::enumerate(condArgs)) {
408 auto res = it.value();
409 auto i = it.index();
410 auto pointerType =
411 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
412
413 // Create local variables before the scf.while op.
414 rewriter.setInsertionPoint(loopOp);
415 auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
416 spirv::StorageClass::Function,
417 /*initializer=*/nullptr);
418
419 // Load the final result values after the scf.while op.
420 rewriter.setInsertionPointAfter(loopOp);
421 auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
422 resultValues[i] = loadResult;
423
424 // Store the current iteration's result value.
425 rewriter.setInsertionPointToEnd(&beforeBlock);
426 spirv::StoreOp::create(rewriter, condLoc, alloc, res);
427 }
428
429 rewriter.setInsertionPointToEnd(&beforeBlock);
430 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
431 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, ValueRange());
432
433 // Convert the scf.yield op to a branch back to the header block.
434 rewriter.setInsertionPointToEnd(&afterBlock);
435 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
436 yieldArgs);
437
438 rewriter.replaceOp(whileOp, resultValues);
439 return success();
440 }
441};
442} // namespace
443
444//===----------------------------------------------------------------------===//
445// Public API
446//===----------------------------------------------------------------------===//
447
449 ScfToSPIRVContext &scfToSPIRVContext,
451 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
452 WhileOpConversion>(patterns.getContext(), typeConverter,
453 scfToSPIRVContext.getImpl());
454}
return success()
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
StringRef getNamespace() const
Definition Dialect.h:54
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
BlockListType::iterator iterator
Definition Region.h:52
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
static PointerType get(Type pointeeType, StorageClass storageClass)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns)
Collects a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations withi...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
DenseMap< Operation *, SmallVector< spirv::VariableOp, 8 > > outputVars
ScfToSPIRVContext()
We use ScfToSPIRVContext to store information about the lowering of the scf region that need to be us...
ScfToSPIRVContextImpl * getImpl()
Definition SCFToSPIRV.h:29