MLIR  14.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "../PassDetail.h"
21 #include "mlir/Dialect/SCF/SCF.h"
24 #include "mlir/IR/SymbolTable.h"
26 
27 using namespace mlir;
28 
29 /// Matches a block containing a "simple" reduction. The expected shape of the
30 /// block is as follows.
31 ///
32 /// ^bb(%arg0, %arg1):
33 /// %0 = OpTy(%arg0, %arg1)
34 /// scf.reduce.return %0
35 template <typename... OpTy>
36 static bool matchSimpleReduction(Block &block) {
37  if (block.empty() || llvm::hasSingleElement(block) ||
38  std::next(block.begin(), 2) != block.end())
39  return false;
40 
41  if (block.getNumArguments() != 2)
42  return false;
43 
44  SmallVector<Operation *, 4> combinerOps;
45  Value reducedVal = matchReduction({block.getArguments()[1]},
46  /*redPos=*/0, combinerOps);
47 
48  if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
49  combinerOps.size() != 1)
50  return false;
51 
52  return isa<OpTy...>(combinerOps[0]) &&
53  isa<scf::ReduceReturnOp>(block.back()) &&
54  block.front().getOperands() == block.getArguments();
55 }
56 
57 /// Matches a block containing a select-based min/max reduction. The types of
58 /// select and compare operations are provided as template arguments. The
59 /// comparison predicates suitable for min and max are provided as function
60 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
61 /// compute the minimum and unset if it computes the maximum, otherwise it
62 /// remains unmodified. The expected shape of the block is as follows.
63 ///
64 /// ^bb(%arg0, %arg1):
65 /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
66 /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
67 /// scf.reduce.return %1
68 template <
69  typename CompareOpTy, typename SelectOpTy,
70  typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
71 static bool
72 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
73  ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
75  "only std and llvm select ops are supported");
76 
77  // Expect exactly three operations in the block.
78  if (block.empty() || llvm::hasSingleElement(block) ||
79  std::next(block.begin(), 2) == block.end() ||
80  std::next(block.begin(), 3) != block.end())
81  return false;
82 
83  // Check op kinds.
84  auto compare = dyn_cast<CompareOpTy>(block.front());
85  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
86  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
87  if (!compare || !select || !terminator)
88  return false;
89 
90  // Block arguments must be compared.
91  if (compare->getOperands() != block.getArguments())
92  return false;
93 
94  // Detect whether the comparison is less-than or greater-than, otherwise bail.
95  bool isLess;
96  if (llvm::find(lessThanPredicates, compare.getPredicate()) !=
97  lessThanPredicates.end()) {
98  isLess = true;
99  } else if (llvm::find(greaterThanPredicates, compare.getPredicate()) !=
100  greaterThanPredicates.end()) {
101  isLess = false;
102  } else {
103  return false;
104  }
105 
106  if (select.getCondition() != compare.getResult())
107  return false;
108 
109  // Detect if the operands are swapped between cmpf and select. Match the
110  // comparison type with the requested type or with the opposite of the
111  // requested type if the operands are swapped. Use generic accessors because
112  // std and LLVM versions of select have different operand names but identical
113  // positions.
114  constexpr unsigned kTrueValue = 1;
115  constexpr unsigned kFalseValue = 2;
116  bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
117  select.getOperand(kFalseValue) == compare.getRhs();
118  bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
119  select.getOperand(kFalseValue) == compare.getLhs();
120  if (!sameOperands && !swappedOperands)
121  return false;
122 
123  if (select.getResult() != terminator.getResult())
124  return false;
125 
126  // The reduction is a min if it uses less-than predicates with same operands
127  // or greather-than predicates with swapped operands. Similarly for max.
128  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
129  return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
130 }
131 
132 /// Returns the float semantics for the given float type.
133 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
134  if (type.isF16())
135  return llvm::APFloat::IEEEhalf();
136  if (type.isF32())
137  return llvm::APFloat::IEEEsingle();
138  if (type.isF64())
139  return llvm::APFloat::IEEEdouble();
140  if (type.isF128())
141  return llvm::APFloat::IEEEquad();
142  if (type.isBF16())
143  return llvm::APFloat::BFloat();
144  if (type.isF80())
145  return llvm::APFloat::x87DoubleExtended();
146  llvm_unreachable("unknown float type");
147 }
148 
149 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
150 /// (otherwise) for the given float type.
151 static Attribute minMaxValueForFloat(Type type, bool min) {
152  auto fltType = type.cast<FloatType>();
153  return FloatAttr::get(
154  type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
155 }
156 
157 /// Returns an attribute with the signed integer minimum (if `min` is set) or
158 /// the maximum value (otherwise) for the given integer type, regardless of its
159 /// signedness semantics (only the width is considered).
161  auto intType = type.cast<IntegerType>();
162  unsigned bitwidth = intType.getWidth();
163  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
164  : llvm::APInt::getSignedMaxValue(bitwidth));
165 }
166 
167 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
168 /// the maximum value (otherwise) for the given integer type, regardless of its
169 /// signedness semantics (only the width is considered).
171  auto intType = type.cast<IntegerType>();
172  unsigned bitwidth = intType.getWidth();
173  return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth)
174  : llvm::APInt::getAllOnesValue(bitwidth));
175 }
176 
177 /// Creates an OpenMP reduction declaration and inserts it into the provided
178 /// symbol table. The declaration has a constant initializer with the neutral
179 /// value `initValue`, and the reduction combiner carried over from `reduce`.
180 static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
181  SymbolTable &symbolTable,
182  scf::ReduceOp reduce,
183  Attribute initValue) {
184  OpBuilder::InsertionGuard guard(builder);
185  auto decl = builder.create<omp::ReductionDeclareOp>(
186  reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType());
187  symbolTable.insert(decl);
188 
189  Type type = reduce.getOperand().getType();
190  builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(),
191  {type}, {reduce.getOperand().getLoc()});
192  builder.setInsertionPointToEnd(&decl.initializerRegion().back());
193  Value init =
194  builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
195  builder.create<omp::YieldOp>(reduce.getLoc(), init);
196 
197  Operation *terminator = &reduce.getRegion().front().back();
198  assert(isa<scf::ReduceReturnOp>(terminator) &&
199  "expected reduce op to be terminated by redure return");
200  builder.setInsertionPoint(terminator);
201  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
202  terminator->getOperands());
203  builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(),
204  decl.reductionRegion().end());
205  return decl;
206 }
207 
208 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
209 /// using llvm.atomicrmw of the given kind.
210 static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
211  LLVM::AtomicBinOp atomicKind,
212  omp::ReductionDeclareOp decl,
213  scf::ReduceOp reduce) {
214  OpBuilder::InsertionGuard guard(builder);
215  Type type = reduce.getOperand().getType();
216  Type ptrType = LLVM::LLVMPointerType::get(type);
217  Location reduceOperandLoc = reduce.getOperand().getLoc();
218  builder.createBlock(&decl.atomicReductionRegion(),
219  decl.atomicReductionRegion().end(), {ptrType, ptrType},
220  {reduceOperandLoc, reduceOperandLoc});
221  Block *atomicBlock = &decl.atomicReductionRegion().back();
222  builder.setInsertionPointToEnd(atomicBlock);
223  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(),
224  atomicBlock->getArgument(1));
225  builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind,
226  atomicBlock->getArgument(0), loaded,
227  LLVM::AtomicOrdering::monotonic);
228  builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
229  return decl;
230 }
231 
232 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
233 /// reduction and returns it. Recognizes common reductions in order to identify
234 /// the neutral value, necessary for the OpenMP declaration. If the reduction
235 /// cannot be recognized, returns null.
236 static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
237  scf::ReduceOp reduce) {
238  Operation *container = SymbolTable::getNearestSymbolTable(reduce);
239  SymbolTable symbolTable(container);
240 
241  // Insert reduction declarations in the symbol-table ancestor before the
242  // ancestor of the current insertion point.
243  Operation *insertionPoint = reduce;
244  while (insertionPoint->getParentOp() != container)
245  insertionPoint = insertionPoint->getParentOp();
246  OpBuilder::InsertionGuard guard(builder);
247  builder.setInsertionPoint(insertionPoint);
248 
249  assert(llvm::hasSingleElement(reduce.getRegion()) &&
250  "expected reduction region to have a single element");
251 
252  // Match simple binary reductions that can be expressed with atomicrmw.
253  Type type = reduce.getOperand().getType();
254  Block &reduction = reduce.getRegion().front();
255  if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
256  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
257  builder.getFloatAttr(type, 0.0));
258  return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
259  }
260  if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
261  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
262  builder.getIntegerAttr(type, 0));
263  return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
264  }
265  if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
266  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
267  builder.getIntegerAttr(type, 0));
268  return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
269  }
270  if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
271  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
272  builder.getIntegerAttr(type, 0));
273  return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
274  }
275  if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
276  omp::ReductionDeclareOp decl = createDecl(
277  builder, symbolTable, reduce,
278  builder.getIntegerAttr(
279  type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
280  return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
281  }
282 
283  // Match simple binary reductions that cannot be expressed with atomicrmw.
284  // TODO: add atomic region using cmpxchg (which needs atomic load to be
285  // available as an op).
286  if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
287  return createDecl(builder, symbolTable, reduce,
288  builder.getFloatAttr(type, 1.0));
289  }
290 
291  // Match select-based min/max reductions.
292  bool isMin;
293  if (matchSelectReduction<arith::CmpFOp, SelectOp>(
294  reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
295  {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
296  matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
297  reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
298  {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
299  return createDecl(builder, symbolTable, reduce,
300  minMaxValueForFloat(type, !isMin));
301  }
302  if (matchSelectReduction<arith::CmpIOp, SelectOp>(
303  reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
304  {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
305  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
306  reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
307  {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
308  omp::ReductionDeclareOp decl = createDecl(
309  builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
310  return addAtomicRMW(builder,
312  decl, reduce);
313  }
314  if (matchSelectReduction<arith::CmpIOp, SelectOp>(
315  reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
316  {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
317  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
318  reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
319  {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
320  omp::ReductionDeclareOp decl = createDecl(
321  builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
322  return addAtomicRMW(
323  builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
324  decl, reduce);
325  }
326 
327  return nullptr;
328 }
329 
330 namespace {
331 
332 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
334 
335  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
336  PatternRewriter &rewriter) const override {
337  // Replace SCF yield with OpenMP yield.
338  {
339  OpBuilder::InsertionGuard guard(rewriter);
340  rewriter.setInsertionPointToEnd(parallelOp.getBody());
341  assert(llvm::hasSingleElement(parallelOp.getRegion()) &&
342  "expected scf.parallel to have one block");
343  rewriter.replaceOpWithNewOp<omp::YieldOp>(
344  parallelOp.getBody()->getTerminator(), ValueRange());
345  }
346 
347  // Declare reductions.
348  // TODO: consider checking it here is already a compatible reduction
349  // declaration and use it instead of redeclaring.
350  SmallVector<Attribute> reductionDeclSymbols;
351  for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
352  omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
353  if (!decl)
354  return failure();
355  reductionDeclSymbols.push_back(
356  SymbolRefAttr::get(rewriter.getContext(), decl.sym_name()));
357  }
358 
359  // Allocate reduction variables. Make sure the we don't overflow the stack
360  // with local `alloca`s by saving and restoring the stack pointer.
361  Location loc = parallelOp.getLoc();
362  Value one = rewriter.create<LLVM::ConstantOp>(
363  loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
364  SmallVector<Value> reductionVariables;
365  reductionVariables.reserve(parallelOp.getNumReductions());
366  Value token = rewriter.create<LLVM::StackSaveOp>(
367  loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
368  for (Value init : parallelOp.getInitVals()) {
369  assert((LLVM::isCompatibleType(init.getType()) ||
370  init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
371  "cannot create a reduction variable if the type is not an LLVM "
372  "pointer element");
373  Value storage = rewriter.create<LLVM::AllocaOp>(
374  loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
375  rewriter.create<LLVM::StoreOp>(loc, init, storage);
376  reductionVariables.push_back(storage);
377  }
378 
379  // Replace the reduction operations contained in this loop. Must be done
380  // here rather than in a separate pattern to have access to the list of
381  // reduction variables.
382  for (auto pair :
383  llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
384  OpBuilder::InsertionGuard guard(rewriter);
385  scf::ReduceOp reduceOp = std::get<0>(pair);
386  rewriter.setInsertionPoint(reduceOp);
387  rewriter.replaceOpWithNewOp<omp::ReductionOp>(
388  reduceOp, reduceOp.getOperand(), std::get<1>(pair));
389  }
390 
391  // Create the parallel wrapper.
392  auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
393  {
394  OpBuilder::InsertionGuard guard(rewriter);
395  rewriter.createBlock(&ompParallel.region());
396 
397  // Replace SCF yield with OpenMP yield.
398  {
399  OpBuilder::InsertionGuard innerGuard(rewriter);
400  rewriter.setInsertionPointToEnd(parallelOp.getBody());
401  assert(llvm::hasSingleElement(parallelOp.getRegion()) &&
402  "expected scf.parallel to have one block");
403  rewriter.replaceOpWithNewOp<omp::YieldOp>(
404  parallelOp.getBody()->getTerminator(), ValueRange());
405  }
406 
407  // Replace the loop.
408  auto loop = rewriter.create<omp::WsLoopOp>(
409  parallelOp.getLoc(), parallelOp.getLowerBound(),
410  parallelOp.getUpperBound(), parallelOp.getStep());
411  rewriter.create<omp::TerminatorOp>(loc);
412 
413  rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(),
414  loop.region().begin());
415  if (!reductionVariables.empty()) {
416  loop.reductionsAttr(
417  ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
418  loop.reduction_varsMutable().append(reductionVariables);
419  }
420  }
421 
422  // Load loop results.
423  SmallVector<Value> results;
424  results.reserve(reductionVariables.size());
425  for (Value variable : reductionVariables) {
426  Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
427  results.push_back(res);
428  }
429  rewriter.replaceOp(parallelOp, results);
430 
431  rewriter.create<LLVM::StackRestoreOp>(loc, token);
432  return success();
433  }
434 };
435 
436 /// Applies the conversion patterns in the given function.
437 static LogicalResult applyPatterns(ModuleOp module) {
438  ConversionTarget target(*module.getContext());
439  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
440  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect>();
441 
442  RewritePatternSet patterns(module.getContext());
443  patterns.add<ParallelOpLowering>(module.getContext());
444  FrozenRewritePatternSet frozen(std::move(patterns));
445  return applyPartialConversion(module, target, frozen);
446 }
447 
448 /// A pass converting SCF operations to OpenMP operations.
449 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
450  /// Pass entry point.
451  void runOnOperation() override {
452  if (failed(applyPatterns(getOperation())))
453  signalPassFailure();
454  }
455 };
456 
457 } // namespace
458 
459 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
460  return std::make_unique<SCFToOpenMPPass>();
461 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
bool isF32() const
Definition: Types.cpp:23
iterator begin()
Definition: Block.h:134
MLIRContext * getContext() const
Definition: Builders.h:54
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Operation & back()
Definition: Block.h:143
static Value min(ImplicitLocOpBuilder &builder, Value a, Value b)
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:247
This class represents a frozen set of patterns that can be processed by a pattern applicator...
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
std::unique_ptr< OperationPass< ModuleOp > > createConvertSCFToOpenMPPass()
Operation & front()
Definition: Block.h:144
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
Definition: SCFToOpenMP.cpp:72
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static LLVMPointerType get(Type pointee, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:165
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
static constexpr const bool value
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
Definition: SCFToOpenMP.cpp:36
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:752
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool isF80() const
Definition: Types.cpp:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
iterator end()
Definition: Block.h:135
unsigned getNumArguments()
Definition: Block.h:119
bool isF16() const
Definition: Types.cpp:22
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation *> &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
bool isF128() const
Definition: Types.cpp:26
BlockArgListType getArguments()
Definition: Block.h:76
This class represents an argument of a Block.
Definition: Value.h:298
static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
void addIllegalOp()
Register the given operation as illegal, i.e.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
bool isF64() const
Definition: Types.cpp:24
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
bool isa() const
Definition: Value.h:89
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it...
int compare(Fraction x, Fraction y)
Three-way comparison between two fractions.
Definition: Fraction.h:46
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, scf::ReduceOp reduce)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm...
This class describes a specific conversion target.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:429
bool isBF16() const
Definition: Types.cpp:21
U cast() const
Definition: Types.h:250
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)