MLIR  18.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Pass/Pass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 /// Matches a block containing a "simple" reduction. The expected shape of the
36 /// block is as follows.
37 ///
38 /// ^bb(%arg0, %arg1):
39 /// %0 = OpTy(%arg0, %arg1)
40 /// scf.reduce.return %0
41 template <typename... OpTy>
42 static bool matchSimpleReduction(Block &block) {
43  if (block.empty() || llvm::hasSingleElement(block) ||
44  std::next(block.begin(), 2) != block.end())
45  return false;
46 
47  if (block.getNumArguments() != 2)
48  return false;
49 
50  SmallVector<Operation *, 4> combinerOps;
51  Value reducedVal = matchReduction({block.getArguments()[1]},
52  /*redPos=*/0, combinerOps);
53 
54  if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
55  return false;
56 
57  return isa<OpTy...>(combinerOps[0]) &&
58  isa<scf::ReduceReturnOp>(block.back()) &&
59  block.front().getOperands() == block.getArguments();
60 }
61 
62 /// Matches a block containing a select-based min/max reduction. The types of
63 /// select and compare operations are provided as template arguments. The
64 /// comparison predicates suitable for min and max are provided as function
65 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
66 /// compute the minimum and unset if it computes the maximum, otherwise it
67 /// remains unmodified. The expected shape of the block is as follows.
68 ///
69 /// ^bb(%arg0, %arg1):
70 /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
71 /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
72 /// scf.reduce.return %1
73 template <
74  typename CompareOpTy, typename SelectOpTy,
75  typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
76 static bool
77 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
78  ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
79  static_assert(
80  llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
81  "only arithmetic and llvm select ops are supported");
82 
83  // Expect exactly three operations in the block.
84  if (block.empty() || llvm::hasSingleElement(block) ||
85  std::next(block.begin(), 2) == block.end() ||
86  std::next(block.begin(), 3) != block.end())
87  return false;
88 
89  // Check op kinds.
90  auto compare = dyn_cast<CompareOpTy>(block.front());
91  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
92  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
93  if (!compare || !select || !terminator)
94  return false;
95 
96  // Block arguments must be compared.
97  if (compare->getOperands() != block.getArguments())
98  return false;
99 
100  // Detect whether the comparison is less-than or greater-than, otherwise bail.
101  bool isLess;
102  if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103  isLess = true;
104  } else if (llvm::is_contained(greaterThanPredicates,
105  compare.getPredicate())) {
106  isLess = false;
107  } else {
108  return false;
109  }
110 
111  if (select.getCondition() != compare.getResult())
112  return false;
113 
114  // Detect if the operands are swapped between cmpf and select. Match the
115  // comparison type with the requested type or with the opposite of the
116  // requested type if the operands are swapped. Use generic accessors because
117  // std and LLVM versions of select have different operand names but identical
118  // positions.
119  constexpr unsigned kTrueValue = 1;
120  constexpr unsigned kFalseValue = 2;
121  bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
122  select.getOperand(kFalseValue) == compare.getRhs();
123  bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
124  select.getOperand(kFalseValue) == compare.getLhs();
125  if (!sameOperands && !swappedOperands)
126  return false;
127 
128  if (select.getResult() != terminator.getResult())
129  return false;
130 
131  // The reduction is a min if it uses less-than predicates with same operands
132  // or greather-than predicates with swapped operands. Similarly for max.
133  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
134  return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
135 }
136 
137 /// Returns the float semantics for the given float type.
138 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
139  if (type.isF16())
140  return llvm::APFloat::IEEEhalf();
141  if (type.isF32())
142  return llvm::APFloat::IEEEsingle();
143  if (type.isF64())
144  return llvm::APFloat::IEEEdouble();
145  if (type.isF128())
146  return llvm::APFloat::IEEEquad();
147  if (type.isBF16())
148  return llvm::APFloat::BFloat();
149  if (type.isF80())
150  return llvm::APFloat::x87DoubleExtended();
151  llvm_unreachable("unknown float type");
152 }
153 
154 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
155 /// (otherwise) for the given float type.
156 static Attribute minMaxValueForFloat(Type type, bool min) {
157  auto fltType = cast<FloatType>(type);
158  return FloatAttr::get(
159  type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
160 }
161 
162 /// Returns an attribute with the signed integer minimum (if `min` is set) or
163 /// the maximum value (otherwise) for the given integer type, regardless of its
164 /// signedness semantics (only the width is considered).
166  auto intType = cast<IntegerType>(type);
167  unsigned bitwidth = intType.getWidth();
168  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
169  : llvm::APInt::getSignedMaxValue(bitwidth));
170 }
171 
172 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
173 /// the maximum value (otherwise) for the given integer type, regardless of its
174 /// signedness semantics (only the width is considered).
176  auto intType = cast<IntegerType>(type);
177  unsigned bitwidth = intType.getWidth();
178  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
179  : llvm::APInt::getAllOnes(bitwidth));
180 }
181 
182 /// Creates an OpenMP reduction declaration and inserts it into the provided
183 /// symbol table. The declaration has a constant initializer with the neutral
184 /// value `initValue`, and the reduction combiner carried over from `reduce`.
185 static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
186  SymbolTable &symbolTable,
187  scf::ReduceOp reduce,
188  Attribute initValue) {
189  OpBuilder::InsertionGuard guard(builder);
190  auto decl = builder.create<omp::ReductionDeclareOp>(
191  reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType());
192  symbolTable.insert(decl);
193 
194  Type type = reduce.getOperand().getType();
195  builder.createBlock(&decl.getInitializerRegion(),
196  decl.getInitializerRegion().end(), {type},
197  {reduce.getOperand().getLoc()});
198  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
199  Value init =
200  builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
201  builder.create<omp::YieldOp>(reduce.getLoc(), init);
202 
203  Operation *terminator = &reduce.getRegion().front().back();
204  assert(isa<scf::ReduceReturnOp>(terminator) &&
205  "expected reduce op to be terminated by redure return");
206  builder.setInsertionPoint(terminator);
207  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
208  terminator->getOperands());
209  builder.inlineRegionBefore(reduce.getRegion(), decl.getReductionRegion(),
210  decl.getReductionRegion().end());
211  return decl;
212 }
213 
214 /// Returns an LLVM pointer type with the given element type, or an opaque
215 /// pointer if 'useOpaquePointers' is true.
216 static LLVM::LLVMPointerType getPointerType(Type elementType,
217  bool useOpaquePointers) {
218  if (useOpaquePointers)
219  return LLVM::LLVMPointerType::get(elementType.getContext());
220  return LLVM::LLVMPointerType::get(elementType);
221 }
222 
223 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
224 /// using llvm.atomicrmw of the given kind.
225 static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
226  LLVM::AtomicBinOp atomicKind,
227  omp::ReductionDeclareOp decl,
228  scf::ReduceOp reduce,
229  bool useOpaquePointers) {
230  OpBuilder::InsertionGuard guard(builder);
231  Type type = reduce.getOperand().getType();
232  Type ptrType = getPointerType(type, useOpaquePointers);
233  Location reduceOperandLoc = reduce.getOperand().getLoc();
234  builder.createBlock(&decl.getAtomicReductionRegion(),
235  decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
236  {reduceOperandLoc, reduceOperandLoc});
237  Block *atomicBlock = &decl.getAtomicReductionRegion().back();
238  builder.setInsertionPointToEnd(atomicBlock);
239  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
240  atomicBlock->getArgument(1));
241  builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
242  atomicBlock->getArgument(0), loaded,
243  LLVM::AtomicOrdering::monotonic);
244  builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
245  return decl;
246 }
247 
248 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
249 /// reduction and returns it. Recognizes common reductions in order to identify
250 /// the neutral value, necessary for the OpenMP declaration. If the reduction
251 /// cannot be recognized, returns null.
252 static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
253  scf::ReduceOp reduce,
254  bool useOpaquePointers) {
256  SymbolTable symbolTable(container);
257 
258  // Insert reduction declarations in the symbol-table ancestor before the
259  // ancestor of the current insertion point.
260  Operation *insertionPoint = reduce;
261  while (insertionPoint->getParentOp() != container)
262  insertionPoint = insertionPoint->getParentOp();
263  OpBuilder::InsertionGuard guard(builder);
264  builder.setInsertionPoint(insertionPoint);
265 
266  assert(llvm::hasSingleElement(reduce.getRegion()) &&
267  "expected reduction region to have a single element");
268 
269  // Match simple binary reductions that can be expressed with atomicrmw.
270  Type type = reduce.getOperand().getType();
271  Block &reduction = reduce.getRegion().front();
272  if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
273  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
274  builder.getFloatAttr(type, 0.0));
275  return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
276  useOpaquePointers);
277  }
278  if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
279  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
280  builder.getIntegerAttr(type, 0));
281  return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
282  useOpaquePointers);
283  }
284  if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
285  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
286  builder.getIntegerAttr(type, 0));
287  return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
288  useOpaquePointers);
289  }
290  if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
291  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
292  builder.getIntegerAttr(type, 0));
293  return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
294  useOpaquePointers);
295  }
296  if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
297  omp::ReductionDeclareOp decl = createDecl(
298  builder, symbolTable, reduce,
299  builder.getIntegerAttr(
300  type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
301  return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
302  useOpaquePointers);
303  }
304 
305  // Match simple binary reductions that cannot be expressed with atomicrmw.
306  // TODO: add atomic region using cmpxchg (which needs atomic load to be
307  // available as an op).
308  if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
309  return createDecl(builder, symbolTable, reduce,
310  builder.getFloatAttr(type, 1.0));
311  }
312  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
313  return createDecl(builder, symbolTable, reduce,
314  builder.getIntegerAttr(type, 1));
315  }
316 
317  // Match select-based min/max reductions.
318  bool isMin;
319  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
320  reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
321  {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
322  matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
323  reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
324  {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
325  return createDecl(builder, symbolTable, reduce,
326  minMaxValueForFloat(type, !isMin));
327  }
328  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
329  reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
330  {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
331  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
332  reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
333  {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
334  omp::ReductionDeclareOp decl = createDecl(
335  builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
336  return addAtomicRMW(builder,
338  decl, reduce, useOpaquePointers);
339  }
340  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
341  reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
342  {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
343  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
344  reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
345  {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
346  omp::ReductionDeclareOp decl = createDecl(
347  builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
348  return addAtomicRMW(
349  builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
350  decl, reduce, useOpaquePointers);
351  }
352 
353  return nullptr;
354 }
355 
356 namespace {
357 
358 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
359 
360  bool useOpaquePointers;
361 
362  ParallelOpLowering(MLIRContext *context, bool useOpaquePointers)
363  : OpRewritePattern<scf::ParallelOp>(context),
364  useOpaquePointers(useOpaquePointers) {}
365 
366  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
367  PatternRewriter &rewriter) const override {
368  // Declare reductions.
369  // TODO: consider checking it here is already a compatible reduction
370  // declaration and use it instead of redeclaring.
371  SmallVector<Attribute> reductionDeclSymbols;
372  for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
373  omp::ReductionDeclareOp decl =
374  declareReduction(rewriter, reduce, useOpaquePointers);
375  if (!decl)
376  return failure();
377  reductionDeclSymbols.push_back(
378  SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
379  }
380 
381  // Allocate reduction variables. Make sure the we don't overflow the stack
382  // with local `alloca`s by saving and restoring the stack pointer.
383  Location loc = parallelOp.getLoc();
384  Value one = rewriter.create<LLVM::ConstantOp>(
385  loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
386  SmallVector<Value> reductionVariables;
387  reductionVariables.reserve(parallelOp.getNumReductions());
388  for (Value init : parallelOp.getInitVals()) {
389  assert((LLVM::isCompatibleType(init.getType()) ||
390  isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
391  "cannot create a reduction variable if the type is not an LLVM "
392  "pointer element");
393  Value storage = rewriter.create<LLVM::AllocaOp>(
394  loc, getPointerType(init.getType(), useOpaquePointers),
395  init.getType(), one, 0);
396  rewriter.create<LLVM::StoreOp>(loc, init, storage);
397  reductionVariables.push_back(storage);
398  }
399 
400  // Replace the reduction operations contained in this loop. Must be done
401  // here rather than in a separate pattern to have access to the list of
402  // reduction variables.
403  for (auto pair :
404  llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
405  OpBuilder::InsertionGuard guard(rewriter);
406  scf::ReduceOp reduceOp = std::get<0>(pair);
407  rewriter.setInsertionPoint(reduceOp);
408  rewriter.replaceOpWithNewOp<omp::ReductionOp>(
409  reduceOp, reduceOp.getOperand(), std::get<1>(pair));
410  }
411 
412  // Create the parallel wrapper.
413  auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
414  {
415 
416  OpBuilder::InsertionGuard guard(rewriter);
417  rewriter.createBlock(&ompParallel.getRegion());
418 
419  // Replace the loop.
420  {
421  OpBuilder::InsertionGuard allocaGuard(rewriter);
422  auto loop = rewriter.create<omp::WsLoopOp>(
423  parallelOp.getLoc(), parallelOp.getLowerBound(),
424  parallelOp.getUpperBound(), parallelOp.getStep());
425  rewriter.create<omp::TerminatorOp>(loc);
426 
427  rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
428  loop.getRegion().begin());
429 
430  Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
431  loop.getRegion().begin()->begin());
432 
433  rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
434 
435  auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
436  TypeRange());
437  rewriter.create<omp::YieldOp>(loc, ValueRange());
438  Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
439  rewriter.mergeBlocks(ops, scopeBlock);
440  auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
441  rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
442  rewriter.replaceOpWithNewOp<memref::AllocaScopeReturnOp>(
443  oldYield, oldYield->getOperands());
444  if (!reductionVariables.empty()) {
445  loop.setReductionsAttr(
446  ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
447  loop.getReductionVarsMutable().append(reductionVariables);
448  }
449  }
450  }
451 
452  // Load loop results.
453  SmallVector<Value> results;
454  results.reserve(reductionVariables.size());
455  for (auto [variable, type] :
456  llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
457  Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
458  results.push_back(res);
459  }
460  rewriter.replaceOp(parallelOp, results);
461 
462  return success();
463  }
464 };
465 
466 /// Applies the conversion patterns in the given function.
467 static LogicalResult applyPatterns(ModuleOp module, bool useOpaquePointers) {
468  ConversionTarget target(*module.getContext());
469  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
470  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
471  memref::MemRefDialect>();
472 
473  RewritePatternSet patterns(module.getContext());
474  patterns.add<ParallelOpLowering>(module.getContext(), useOpaquePointers);
475  FrozenRewritePatternSet frozen(std::move(patterns));
476  return applyPartialConversion(module, target, frozen);
477 }
478 
479 /// A pass converting SCF operations to OpenMP operations.
480 struct SCFToOpenMPPass
481  : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
482 
483  using Base::Base;
484 
485  /// Pass entry point.
486  void runOnOperation() override {
487  if (failed(applyPatterns(getOperation(), useOpaquePointers)))
488  signalPassFailure();
489  }
490 };
491 
492 } // namespace
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2444
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static LLVM::LLVMPointerType getPointerType(Type elementType, bool useOpaquePointers)
Returns an LLVM pointer type with the given element type, or an opaque pointer if 'useOpaquePointers'...
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
Definition: SCFToOpenMP.cpp:42
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, scf::ReduceOp reduce, bool useOpaquePointers)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
Definition: SCFToOpenMP.cpp:77
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, bool useOpaquePointers)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:141
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation & back()
Definition: Block.h:145
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator end()
Definition: Block.h:137
iterator begin()
Definition: Block.h:136
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isF32() const
Definition: Types.cpp:51
bool isF128() const
Definition: Types.cpp:54
bool isF16() const
Definition: Types.cpp:49
bool isF80() const
Definition: Types.cpp:53
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
bool isBF16() const
Definition: Types.cpp:48
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:913
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:63
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357