MLIR  16.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Pass/Pass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMP
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 /// Matches a block containing a "simple" reduction. The expected shape of the
36 /// block is as follows.
37 ///
38 /// ^bb(%arg0, %arg1):
39 /// %0 = OpTy(%arg0, %arg1)
40 /// scf.reduce.return %0
41 template <typename... OpTy>
42 static bool matchSimpleReduction(Block &block) {
43  if (block.empty() || llvm::hasSingleElement(block) ||
44  std::next(block.begin(), 2) != block.end())
45  return false;
46 
47  if (block.getNumArguments() != 2)
48  return false;
49 
50  SmallVector<Operation *, 4> combinerOps;
51  Value reducedVal = matchReduction({block.getArguments()[1]},
52  /*redPos=*/0, combinerOps);
53 
54  if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
55  combinerOps.size() != 1)
56  return false;
57 
58  return isa<OpTy...>(combinerOps[0]) &&
59  isa<scf::ReduceReturnOp>(block.back()) &&
60  block.front().getOperands() == block.getArguments();
61 }
62 
63 /// Matches a block containing a select-based min/max reduction. The types of
64 /// select and compare operations are provided as template arguments. The
65 /// comparison predicates suitable for min and max are provided as function
66 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
67 /// compute the minimum and unset if it computes the maximum, otherwise it
68 /// remains unmodified. The expected shape of the block is as follows.
69 ///
70 /// ^bb(%arg0, %arg1):
71 /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
72 /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
73 /// scf.reduce.return %1
74 template <
75  typename CompareOpTy, typename SelectOpTy,
76  typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
77 static bool
78 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
79  ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
80  static_assert(
82  "only arithmetic and llvm select ops are supported");
83 
84  // Expect exactly three operations in the block.
85  if (block.empty() || llvm::hasSingleElement(block) ||
86  std::next(block.begin(), 2) == block.end() ||
87  std::next(block.begin(), 3) != block.end())
88  return false;
89 
90  // Check op kinds.
91  auto compare = dyn_cast<CompareOpTy>(block.front());
92  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
93  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
94  if (!compare || !select || !terminator)
95  return false;
96 
97  // Block arguments must be compared.
98  if (compare->getOperands() != block.getArguments())
99  return false;
100 
101  // Detect whether the comparison is less-than or greater-than, otherwise bail.
102  bool isLess;
103  if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
104  isLess = true;
105  } else if (llvm::is_contained(greaterThanPredicates,
106  compare.getPredicate())) {
107  isLess = false;
108  } else {
109  return false;
110  }
111 
112  if (select.getCondition() != compare.getResult())
113  return false;
114 
115  // Detect if the operands are swapped between cmpf and select. Match the
116  // comparison type with the requested type or with the opposite of the
117  // requested type if the operands are swapped. Use generic accessors because
118  // std and LLVM versions of select have different operand names but identical
119  // positions.
120  constexpr unsigned kTrueValue = 1;
121  constexpr unsigned kFalseValue = 2;
122  bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
123  select.getOperand(kFalseValue) == compare.getRhs();
124  bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
125  select.getOperand(kFalseValue) == compare.getLhs();
126  if (!sameOperands && !swappedOperands)
127  return false;
128 
129  if (select.getResult() != terminator.getResult())
130  return false;
131 
132  // The reduction is a min if it uses less-than predicates with same operands
133  // or greather-than predicates with swapped operands. Similarly for max.
134  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
135  return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
136 }
137 
138 /// Returns the float semantics for the given float type.
139 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
140  if (type.isF16())
141  return llvm::APFloat::IEEEhalf();
142  if (type.isF32())
143  return llvm::APFloat::IEEEsingle();
144  if (type.isF64())
145  return llvm::APFloat::IEEEdouble();
146  if (type.isF128())
147  return llvm::APFloat::IEEEquad();
148  if (type.isBF16())
149  return llvm::APFloat::BFloat();
150  if (type.isF80())
151  return llvm::APFloat::x87DoubleExtended();
152  llvm_unreachable("unknown float type");
153 }
154 
155 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
156 /// (otherwise) for the given float type.
157 static Attribute minMaxValueForFloat(Type type, bool min) {
158  auto fltType = type.cast<FloatType>();
159  return FloatAttr::get(
160  type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
161 }
162 
163 /// Returns an attribute with the signed integer minimum (if `min` is set) or
164 /// the maximum value (otherwise) for the given integer type, regardless of its
165 /// signedness semantics (only the width is considered).
167  auto intType = type.cast<IntegerType>();
168  unsigned bitwidth = intType.getWidth();
169  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
170  : llvm::APInt::getSignedMaxValue(bitwidth));
171 }
172 
173 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
174 /// the maximum value (otherwise) for the given integer type, regardless of its
175 /// signedness semantics (only the width is considered).
177  auto intType = type.cast<IntegerType>();
178  unsigned bitwidth = intType.getWidth();
179  return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth)
180  : llvm::APInt::getAllOnesValue(bitwidth));
181 }
182 
183 /// Creates an OpenMP reduction declaration and inserts it into the provided
184 /// symbol table. The declaration has a constant initializer with the neutral
185 /// value `initValue`, and the reduction combiner carried over from `reduce`.
186 static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
187  SymbolTable &symbolTable,
188  scf::ReduceOp reduce,
189  Attribute initValue) {
190  OpBuilder::InsertionGuard guard(builder);
191  auto decl = builder.create<omp::ReductionDeclareOp>(
192  reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType());
193  symbolTable.insert(decl);
194 
195  Type type = reduce.getOperand().getType();
196  builder.createBlock(&decl.getInitializerRegion(),
197  decl.getInitializerRegion().end(), {type},
198  {reduce.getOperand().getLoc()});
199  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
200  Value init =
201  builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
202  builder.create<omp::YieldOp>(reduce.getLoc(), init);
203 
204  Operation *terminator = &reduce.getRegion().front().back();
205  assert(isa<scf::ReduceReturnOp>(terminator) &&
206  "expected reduce op to be terminated by redure return");
207  builder.setInsertionPoint(terminator);
208  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
209  terminator->getOperands());
210  builder.inlineRegionBefore(reduce.getRegion(), decl.getReductionRegion(),
211  decl.getReductionRegion().end());
212  return decl;
213 }
214 
215 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
216 /// using llvm.atomicrmw of the given kind.
217 static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
218  LLVM::AtomicBinOp atomicKind,
219  omp::ReductionDeclareOp decl,
220  scf::ReduceOp reduce) {
221  OpBuilder::InsertionGuard guard(builder);
222  Type type = reduce.getOperand().getType();
223  Type ptrType = LLVM::LLVMPointerType::get(type);
224  Location reduceOperandLoc = reduce.getOperand().getLoc();
225  builder.createBlock(&decl.getAtomicReductionRegion(),
226  decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
227  {reduceOperandLoc, reduceOperandLoc});
228  Block *atomicBlock = &decl.getAtomicReductionRegion().back();
229  builder.setInsertionPointToEnd(atomicBlock);
230  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(),
231  atomicBlock->getArgument(1));
232  builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind,
233  atomicBlock->getArgument(0), loaded,
234  LLVM::AtomicOrdering::monotonic);
235  builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
236  return decl;
237 }
238 
239 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
240 /// reduction and returns it. Recognizes common reductions in order to identify
241 /// the neutral value, necessary for the OpenMP declaration. If the reduction
242 /// cannot be recognized, returns null.
243 static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
244  scf::ReduceOp reduce) {
245  Operation *container = SymbolTable::getNearestSymbolTable(reduce);
246  SymbolTable symbolTable(container);
247 
248  // Insert reduction declarations in the symbol-table ancestor before the
249  // ancestor of the current insertion point.
250  Operation *insertionPoint = reduce;
251  while (insertionPoint->getParentOp() != container)
252  insertionPoint = insertionPoint->getParentOp();
253  OpBuilder::InsertionGuard guard(builder);
254  builder.setInsertionPoint(insertionPoint);
255 
256  assert(llvm::hasSingleElement(reduce.getRegion()) &&
257  "expected reduction region to have a single element");
258 
259  // Match simple binary reductions that can be expressed with atomicrmw.
260  Type type = reduce.getOperand().getType();
261  Block &reduction = reduce.getRegion().front();
262  if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
263  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
264  builder.getFloatAttr(type, 0.0));
265  return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
266  }
267  if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
268  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
269  builder.getIntegerAttr(type, 0));
270  return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
271  }
272  if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
273  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
274  builder.getIntegerAttr(type, 0));
275  return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
276  }
277  if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
278  omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
279  builder.getIntegerAttr(type, 0));
280  return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
281  }
282  if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
283  omp::ReductionDeclareOp decl = createDecl(
284  builder, symbolTable, reduce,
285  builder.getIntegerAttr(
286  type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
287  return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
288  }
289 
290  // Match simple binary reductions that cannot be expressed with atomicrmw.
291  // TODO: add atomic region using cmpxchg (which needs atomic load to be
292  // available as an op).
293  if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
294  return createDecl(builder, symbolTable, reduce,
295  builder.getFloatAttr(type, 1.0));
296  }
297 
298  // Match select-based min/max reductions.
299  bool isMin;
300  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
301  reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
302  {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
303  matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
304  reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
305  {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
306  return createDecl(builder, symbolTable, reduce,
307  minMaxValueForFloat(type, !isMin));
308  }
309  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
310  reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
311  {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
312  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
313  reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
314  {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
315  omp::ReductionDeclareOp decl = createDecl(
316  builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
317  return addAtomicRMW(builder,
319  decl, reduce);
320  }
321  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
322  reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
323  {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
324  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
325  reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
326  {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
327  omp::ReductionDeclareOp decl = createDecl(
328  builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
329  return addAtomicRMW(
330  builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
331  decl, reduce);
332  }
333 
334  return nullptr;
335 }
336 
337 namespace {
338 
339 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
341 
342  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
343  PatternRewriter &rewriter) const override {
344  // Declare reductions.
345  // TODO: consider checking it here is already a compatible reduction
346  // declaration and use it instead of redeclaring.
347  SmallVector<Attribute> reductionDeclSymbols;
348  for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
349  omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
350  if (!decl)
351  return failure();
352  reductionDeclSymbols.push_back(
353  SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
354  }
355 
356  // Allocate reduction variables. Make sure the we don't overflow the stack
357  // with local `alloca`s by saving and restoring the stack pointer.
358  Location loc = parallelOp.getLoc();
359  Value one = rewriter.create<LLVM::ConstantOp>(
360  loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
361  SmallVector<Value> reductionVariables;
362  reductionVariables.reserve(parallelOp.getNumReductions());
363  for (Value init : parallelOp.getInitVals()) {
364  assert((LLVM::isCompatibleType(init.getType()) ||
365  init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
366  "cannot create a reduction variable if the type is not an LLVM "
367  "pointer element");
368  Value storage = rewriter.create<LLVM::AllocaOp>(
369  loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
370  rewriter.create<LLVM::StoreOp>(loc, init, storage);
371  reductionVariables.push_back(storage);
372  }
373 
374  // Replace the reduction operations contained in this loop. Must be done
375  // here rather than in a separate pattern to have access to the list of
376  // reduction variables.
377  for (auto pair :
378  llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
379  OpBuilder::InsertionGuard guard(rewriter);
380  scf::ReduceOp reduceOp = std::get<0>(pair);
381  rewriter.setInsertionPoint(reduceOp);
382  rewriter.replaceOpWithNewOp<omp::ReductionOp>(
383  reduceOp, reduceOp.getOperand(), std::get<1>(pair));
384  }
385 
386  // Create the parallel wrapper.
387  auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
388  {
389 
390  OpBuilder::InsertionGuard guard(rewriter);
391  rewriter.createBlock(&ompParallel.getRegion());
392 
393  // Replace the loop.
394  {
395  OpBuilder::InsertionGuard allocaGuard(rewriter);
396  auto loop = rewriter.create<omp::WsLoopOp>(
397  parallelOp.getLoc(), parallelOp.getLowerBound(),
398  parallelOp.getUpperBound(), parallelOp.getStep());
399  rewriter.create<omp::TerminatorOp>(loc);
400 
401  rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
402  loop.getRegion().begin());
403 
404  Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
405  loop.getRegion().begin()->begin());
406 
407  rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
408 
409  auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
410  TypeRange());
411  rewriter.create<omp::YieldOp>(loc, ValueRange());
412  Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
413  rewriter.mergeBlocks(ops, scopeBlock);
414  auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
415  rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
416  rewriter.replaceOpWithNewOp<memref::AllocaScopeReturnOp>(
417  oldYield, oldYield->getOperands());
418  if (!reductionVariables.empty()) {
419  loop.setReductionsAttr(
420  ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
421  loop.getReductionVarsMutable().append(reductionVariables);
422  }
423  }
424  }
425 
426  // Load loop results.
427  SmallVector<Value> results;
428  results.reserve(reductionVariables.size());
429  for (Value variable : reductionVariables) {
430  Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
431  results.push_back(res);
432  }
433  rewriter.replaceOp(parallelOp, results);
434 
435  return success();
436  }
437 };
438 
439 /// Applies the conversion patterns in the given function.
440 static LogicalResult applyPatterns(ModuleOp module) {
441  ConversionTarget target(*module.getContext());
442  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
443  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
444  memref::MemRefDialect>();
445 
446  RewritePatternSet patterns(module.getContext());
447  patterns.add<ParallelOpLowering>(module.getContext());
448  FrozenRewritePatternSet frozen(std::move(patterns));
449  return applyPartialConversion(module, target, frozen);
450 }
451 
452 /// A pass converting SCF operations to OpenMP operations.
453 struct SCFToOpenMPPass : public impl::ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
454  /// Pass entry point.
455  void runOnOperation() override {
456  if (failed(applyPatterns(getOperation())))
457  signalPassFailure();
458  }
459 };
460 
461 } // namespace
462 
463 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
464  return std::make_unique<SCFToOpenMPPass>();
465 }
static constexpr const bool value
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
Definition: SCFToOpenMP.cpp:42
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, scf::ReduceOp reduce)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
Definition: SCFToOpenMP.cpp:78
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
bool empty()
Definition: Block.h:137
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
Operation & back()
Definition: Block.h:141
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
iterator end()
Definition: Block.h:133
iterator begin()
Definition: Block.h:132
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:235
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:113
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
MLIRContext * getContext() const
Definition: Builders.h:54
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:388
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
Block & front()
Definition: Region.h:65
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:26
U cast() const
Definition: Types.h:280
bool isF32() const
Definition: Types.cpp:25
bool isF128() const
Definition: Types.cpp:28
bool isF16() const
Definition: Types.cpp:24
bool isF80() const
Definition: Types.cpp:27
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
bool isBF16() const
Definition: Types.cpp:23
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool isa() const
Definition: Value.h:90
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:839
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:59
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
std::unique_ptr< OperationPass< ModuleOp > > createConvertSCFToOpenMPPass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:356