MLIR  18.0.0git
Detensorize.cpp
Go to the documentation of this file.
1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/OpDefinition.h"
19 #include <iterator>
20 #include <memory>
21 #include <utility>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LINALGDETENSORIZE
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
32  ValueRange inputs, Location loc) {
33  assert(inputs.size() == 1);
34  auto inputType = inputs[0].getType();
35  if (isa<TensorType>(inputType))
36  return nullptr;
37 
38  // A detensored value is converted back by creating a new tensor from its
39  // element(s).
40  return builder.create<tensor::FromElementsOp>(
41  loc, RankedTensorType::get({}, inputType), inputs[0]);
42 }
43 
44 namespace {
45 /// Defines the criteria a TensorType must follow in order to be considered
46 /// "detensorable".
47 ///
48 /// NOTE: For now, only 0-D tensors are supported.
49 ///
50 /// Returns true if tensorType can be detensored.
51 bool canBeDetensored(TensorType tensorType) {
52  return tensorType.hasRank() && tensorType.getRank() == 0;
53 }
54 
55 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
56  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
57  return genericOp &&
58  llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
59  return !typeConverter.isLegal(opOperand.get().getType());
60  });
61 }
62 
63 /// A conversion pattern for detensoring `linalg.generic` ops.
64 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
65 public:
68  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
69  ConversionPatternRewriter &rewriter) const override {
70  Block *originalBlock = op->getBlock();
71 
72  // Gather some information about the op before inlining its region.
73  Block *opEntryBlock = &*op.getRegion().begin();
74  YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
75 
76  // Split the op's region before the op. This way, we have a clear insertion
77  // point in which the op can be inlined.
78  Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
79  rewriter.inlineRegionBefore(op.getRegion(), newBlock);
80  // Now that op's region is inlined, the operands of its YieldOp are mapped
81  // to the materialized target values. Therefore, we can replace the op's
82  // uses with those of its YielOp's operands.
83  rewriter.replaceOp(op, yieldOp->getOperands());
84 
85  // No need for these intermediate blocks, merge them into 1.
86  rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
87  rewriter.mergeBlocks(newBlock, originalBlock, {});
88 
89  rewriter.eraseOp(&*Block::iterator(yieldOp));
90 
91  return success();
92  }
93 };
94 
95 /// A conversion pattern for detensoring internal (non-entry) blocks within a
96 /// function.
97 struct FunctionNonEntryBlockConversion
98  : public OpInterfaceConversionPattern<FunctionOpInterface> {
99  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
100  DenseSet<BlockArgument> blockArgsToDetensor)
101  : OpInterfaceConversionPattern(converter, ctx),
102  blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
103 
105  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
106  ConversionPatternRewriter &rewriter) const override {
107  rewriter.startRootUpdate(op);
108  Region &region = op.getFunctionBody();
110 
111  for (Block &block : llvm::drop_begin(region, 1)) {
112  conversions.emplace_back(block.getNumArguments());
113  TypeConverter::SignatureConversion &back = conversions.back();
114 
115  for (BlockArgument blockArgument : block.getArguments()) {
116  int idx = blockArgument.getArgNumber();
117 
118  if (blockArgsToDetensor.count(blockArgument))
119  back.addInputs(idx, {getTypeConverter()->convertType(
120  block.getArgumentTypes()[idx])});
121  else
122  back.addInputs(idx, {block.getArgumentTypes()[idx]});
123  }
124  }
125 
126  if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
127  conversions))) {
128  rewriter.cancelRootUpdate(op);
129  return failure();
130  }
131 
132  rewriter.finalizeRootUpdate(op);
133  return success();
134  }
135 
136 private:
137  const DenseSet<BlockArgument> blockArgsToDetensor;
138 };
139 
140 class DetensorizeTypeConverter : public TypeConverter {
141 public:
142  DetensorizeTypeConverter() {
143  addConversion([](Type type) { return type; });
144 
145  // A TensorType that can be detensored, is converted to the underlying
146  // element type.
147  addConversion([](TensorType tensorType) -> Type {
148  if (canBeDetensored(tensorType))
149  return tensorType.getElementType();
150 
151  return tensorType;
152  });
153 
154  // A tensor value is detensoried by extracting its element(s).
155  addTargetMaterialization([](OpBuilder &builder, Type type,
156  ValueRange inputs, Location loc) -> Value {
157  return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
158  });
159 
160  addSourceMaterialization(sourceMaterializationCallback);
161  addArgumentMaterialization(sourceMaterializationCallback);
162  }
163 };
164 
165 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
166 struct LinalgDetensorize
167  : public impl::LinalgDetensorizeBase<LinalgDetensorize> {
168  LinalgDetensorize() = default;
169 
170  class CostModel {
171  public:
172  virtual ~CostModel() = default;
173 
174  /// A cost model algorithm computes the following outputs:
175  ///
176  /// - opsToDetensor: the list of linalg ops that should be
177  /// detensored.
178  ///
179  /// - blockArgsToDetensor: since the operands and results of detensored
180  /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
181  /// from a BB argument and a linalg op's output can be passed to successor
182  /// BBs), we need to maintain the sub-set of arguments that should be
183  /// detensored (i.e. converted by typeConverter) for each affected BB.
184  ///
185  /// Example:
186  ///
187  /// For the following snippet:
188  /// ...
189  /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
190  /// %7 = tensor.empty() : tensor<i32>
191  /// %8 = linalg.generic #attrs
192  /// ins(%6, %6 : tensor<i32>, tensor<i32>)
193  /// outs(%7 : tensor<i32>) {
194  /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
195  /// %9 = arith.addi %arg0, %arg1 : i32
196  /// linalg.yield %9 : i32
197  /// } -> tensor<i32>
198  /// %10 = "some.op"(%9)
199  /// br ^bb2(%8 : tensor<i32>)
200  /// ...
201  ///
202  /// if the cost model decides that the linalg.generic op should be
203  /// detensored, then:
204  /// - opsToDetensor should be = {linalg.generic{add}}.
205  /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
206  virtual void compute(FunctionOpInterface func,
207  DetensorizeTypeConverter typeConverter,
208  DenseSet<Operation *> &opsToDetensor,
209  DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
210 
211  /// From the blockArgsToDetensor set computed by a CostModel
212  /// implementation, this method computes the corresponding branch op
213  /// detensoring. The result is a map from a branch op to a subset of indices
214  /// of its operands. The indices specify which of the branch op's operands
215  /// should be detensored.
216  ///
217  /// For the previous example, this method would compute: {bb2 -> {0}}.
218  static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
219  const DenseSet<BlockArgument> &blockArgsToDetensor) {
220  DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
221 
222  for (auto blockArgumentElem : blockArgsToDetensor) {
223  Block *block = blockArgumentElem.getOwner();
224 
225  for (PredecessorIterator pred = block->pred_begin();
226  pred != block->pred_end(); ++pred) {
227  BranchOpInterface terminator =
228  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
229  auto blockOperands =
230  terminator.getSuccessorOperands(pred.getSuccessorIndex());
231 
232  if (blockOperands.empty() ||
233  blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
234  continue;
235 
236  detensorableBranchOps[terminator].insert(
237  blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
238  }
239  }
240 
241  return detensorableBranchOps;
242  }
243  };
244 
245  /// Detensorize linalg ops involved in control-flow within a function.
246  ///
247  /// This model starts from BranchOps and CondBranchOps within a function. For
248  /// each such branch, the model then walks the use-def chain for the branch's
249  /// condition backwards in order to understand where the condition's value
250  /// comes from. If the condition value is (indirectly) computed by a linalg op
251  /// that can be detensored, the model then continues walking the use-def chain
252  /// in order to understand where the linalg op's operands come from. This
253  /// leads to discovering a "detensoring component". A detensoring component is
254  /// the set of operations + block arguments that are involved in control-flow
255  /// AND can be detensored.
256  class ControlFlowDetectionModel : public CostModel {
257  public:
258  void compute(FunctionOpInterface func,
259  DetensorizeTypeConverter typeConverter,
260  DenseSet<Operation *> &opsToDetensor,
261  DenseSet<BlockArgument> &blockArgsToDetensor) override {
262  SmallVector<Value> workList;
263 
264  func->walk([&](cf::CondBranchOp condBr) {
265  llvm::append_range(workList, condBr.getOperands());
266  });
267 
268  func->walk([&](cf::BranchOp br) {
269  llvm::append_range(workList, br.getOperands());
270  });
271 
272  DenseSet<Value> visitedValues;
273  DenseSet<Operation *> visitedOps;
274 
275  // For a (to-be-detesored) value, check if it "escapes" the block by being
276  // passed to terminator. If it does, then workList is updated with the
277  // corresponding argument to the successor block.
278  auto updateWorkListWithSuccessorArguments =
279  [&](Value value, BranchOpInterface terminator) {
280  if (!terminator)
281  return;
282 
283  for (auto operandIdx :
284  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
285  Value operand = terminator->getOperand(operandIdx);
286 
287  if (operand == value) {
288  auto succBlockArg =
289  terminator.getSuccessorBlockArgument(operandIdx);
290 
291  if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
292  workList.push_back(*succBlockArg);
293  }
294  }
295  };
296 
297  while (!workList.empty()) {
298  Value currentItem = workList.pop_back_val();
299 
300  if (!visitedValues.insert(currentItem).second)
301  continue;
302 
303  // 1 - Look forward:
304  // 1.1 - If currentItem escapes to one or more successors, add
305  // the corresponding successor arguments to workList.
306  updateWorkListWithSuccessorArguments(
307  currentItem, dyn_cast<BranchOpInterface>(
308  currentItem.getParentBlock()->getTerminator()));
309 
310  // 1.2 - For each user of currentItem, add the defined values to
311  // workList. This way, the user ops can be inspected later if they are
312  // detensorable and if so, their operands will be added to workList to
313  // potentially discover other parts of the detensorable component.
314  for (auto *user : currentItem.getUsers())
315  llvm::append_range(workList, user->getResults());
316 
317  // 2 - Look backward:
318  // 2.1 - The current item is defined by a block argument. If the owner
319  // block is a non-entry one, then:
320  // * Add the argument to blockArgsToDetensor.
321  // * Walk the use-def chain backwards to add each predecessor's
322  // terminator-operands corresponding to currentItem to workList.
323  if (dyn_cast<BlockArgument>(currentItem)) {
324  BlockArgument currentItemBlockArgument =
325  cast<BlockArgument>(currentItem);
326  Block *ownerBlock = currentItemBlockArgument.getOwner();
327 
328  // Function arguments are not detensored/converted.
329  if (&*ownerBlock->getParent()->begin() == ownerBlock)
330  continue;
331 
332  // This inner-block argument is involved in control-flow, it should be
333  // detensored.
334  blockArgsToDetensor.insert(currentItemBlockArgument);
335 
336  for (PredecessorIterator pred = ownerBlock->pred_begin();
337  pred != ownerBlock->pred_end(); ++pred) {
338  BranchOpInterface predTerminator =
339  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
340 
341  // TODO: For now, we give up if any of the control-flow components
342  // in a function is not detensorable. Fix that.
343  if (!predTerminator) {
344  opsToDetensor.clear();
345  blockArgsToDetensor.clear();
346  return;
347  }
348 
349  auto ownerBlockOperands =
350  predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
351 
352  if (ownerBlockOperands.empty() ||
353  ownerBlockOperands.isOperandProduced(
354  currentItemBlockArgument.getArgNumber()))
355  continue;
356 
357  // For each predecessor, add the value it passes to that argument to
358  // workList to find out how it's computed.
359  workList.push_back(
360  ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
361  }
362 
363  continue;
364  }
365 
366  Operation *currentItemDefiningOp = currentItem.getDefiningOp();
367 
368  if (!visitedOps.insert(currentItemDefiningOp).second)
369  continue;
370 
371  // 2.2 - The current item is computed by a GenericOp. If the op should
372  // be detensored, then:
373  // * Add it to opsToDetensor.
374  // * Add its operands to workList to discover other parts of the
375  // potentially detensorable component.
376  if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
377  // The op was encountered already, no need to inspect it again.
378  if (opsToDetensor.count(genericOp))
379  continue;
380 
381  // The op should not be detensored, give up on it but continue with
382  // discovering the rest of the control-flow component.
383  if (!shouldBeDetensored(genericOp, typeConverter)) {
384  continue;
385  }
386 
387  opsToDetensor.insert(genericOp);
388  llvm::append_range(workList, genericOp.getInputs());
389  continue;
390  }
391 
392  // 2.3 - The current item is the result of a FromElementsOp, it will be
393  // trivially detensored later as part of canonicalization patterns
394  // applied at the end of detensoring.
395  //
396  // Note: No need to check whether the result type of this op is
397  // detensorable since if it wasn't we wouldn't reach that point in the
398  // work list.
399  if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
400  continue;
401 
402  // 2.4 - The current item is the result of a scalar op, add all its
403  // operands to the work list.
404  if (llvm::all_of(
405  currentItemDefiningOp->getResultTypes(),
406  [&](Type resultType) { return resultType.isIntOrFloat(); }))
407  llvm::append_range(workList, currentItemDefiningOp->getOperands());
408  }
409 
410  // Since the cost model gives up on some ops (see the details of step 2.2
411  // above), block arguments that correspond to the values produced by those
412  // ops should not be detensored as well.
413 
414  DenseSet<BlockArgument> blockArgsToRemove;
415 
416  for (auto &blockArg : blockArgsToDetensor) {
417  Block *block = blockArg.getParentBlock();
418 
419  // For the potentially detensorable block argument, find the
420  // correpsonding operands in predecessor blocks.
421  for (PredecessorIterator pred = block->pred_begin();
422  pred != block->pred_end(); ++pred) {
423  BranchOpInterface terminator =
424  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
425  auto blockOperands =
426  terminator.getSuccessorOperands(pred.getSuccessorIndex());
427 
428  if (blockOperands.empty() ||
429  blockOperands.isOperandProduced(blockArg.getArgNumber()))
430  continue;
431 
432  Operation *definingOp =
433  blockOperands[blockArg.getArgNumber()].getDefiningOp();
434 
435  // If the operand is defined by a GenericOp that will not be
436  // detensored, then do not detensor the corresponding block argument.
437  if (isa_and_nonnull<GenericOp>(definingOp) &&
438  opsToDetensor.count(definingOp) == 0) {
439  blockArgsToRemove.insert(blockArg);
440  break;
441  }
442  }
443  }
444 
445  for (auto &blockArg : blockArgsToRemove) {
446  blockArgsToDetensor.erase(blockArg);
447  }
448  }
449  };
450 
451  /// Detensorize everything that can detensored.
452  class AggressiveDetensoringModel : public CostModel {
453  public:
454  void compute(FunctionOpInterface func,
455  DetensorizeTypeConverter typeConverter,
456  DenseSet<Operation *> &opsToDetensor,
457  DenseSet<BlockArgument> &blockArgsToDetensor) override {
458  func->walk([&](GenericOp genericOp) {
459  if (shouldBeDetensored(genericOp, typeConverter))
460  opsToDetensor.insert(genericOp);
461  });
462 
463  for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
464  for (BlockArgument blockArgument : block.getArguments())
465  blockArgsToDetensor.insert(blockArgument);
466  }
467  };
468 
469  void runOnOperation() override {
470  MLIRContext *context = &getContext();
471  DetensorizeTypeConverter typeConverter;
472  RewritePatternSet patterns(context);
473  ConversionTarget target(*context);
474  DenseSet<Operation *> opsToDetensor;
475  DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
476  DenseSet<BlockArgument> blockArgsToDetensor;
477  FunctionOpInterface funcOp = getOperation();
478 
479  if (funcOp.getFunctionBody().empty())
480  return;
481 
482  // Make sure the entry block of the function doesn't contain any Linalg ops.
483  // Otherwise, it may lead to the signature of the block being changed by the
484  // dialect conversion below, which would make the function op invalid
485  // because its type shouldn't change.
486  IRRewriter rewriter(funcOp->getContext());
487  Block *entryBlock = &funcOp.getFunctionBody().front();
488  Block *postEntryBlock =
489  rewriter.splitBlock(entryBlock, entryBlock->begin());
490  rewriter.setInsertionPointToStart(entryBlock);
491  auto branch =
492  rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
493 
494  if (aggressiveMode.getValue()) {
495  AggressiveDetensoringModel costModel;
496  costModel.compute(funcOp, typeConverter, opsToDetensor,
497  blockArgsToDetensor);
498  } else {
499  ControlFlowDetectionModel costModel;
500  costModel.compute(funcOp, typeConverter, opsToDetensor,
501  blockArgsToDetensor);
502  }
503 
504  detensorableBranchOps =
505  CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
506 
507  target.addDynamicallyLegalOp<GenericOp>(
508  [&](GenericOp op) { return !opsToDetensor.count(op); });
509 
510  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
511  // A function is legal if all of its non-entry blocks are legal. We
512  // don't legalize the entry block (i.e. the function's signature)
513  // since detensoring can't happen along external calling convention
514  // boundaries, which we conservatively approximate as all function
515  // signatures.
516  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
517  Region &body = funcOp.getFunctionBody();
518  return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
519  return !llvm::any_of(
520  blockArgsToDetensor, [&](BlockArgument blockArgument) {
521  return blockArgument.getOwner() == &block &&
522  !typeConverter.isLegal(blockArgument.getType());
523  });
524  });
525  }
526 
528  isLegalForReturnOpTypeConversionPattern(op, typeConverter,
529  /*returnOpAlwaysLegal*/ true))
530  return true;
531 
532  if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
533  if (!detensorableBranchOps.count(branchOp))
534  return true;
535 
536  for (auto operandIdx : detensorableBranchOps[branchOp])
537  if (!typeConverter.isLegal(
538  branchOp->getOperand(operandIdx).getType()))
539  return false;
540 
541  return true;
542  }
543 
544  return false;
545  });
546 
547  patterns.add<DetensorizeGenericOp>(typeConverter, context);
548  patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
549  blockArgsToDetensor);
550  // Since non-entry block arguments get detensorized, we also need to
551  // update the control flow inside the function to reflect the correct
552  // types.
553  auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
554  int operandIdx) -> bool {
555  return detensorableBranchOps.count(branchOp) &&
556  detensorableBranchOps[branchOp].count(operandIdx);
557  };
558 
559  populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
560  shouldConvertBranchOperand);
561 
562  if (failed(
563  applyFullConversion(getOperation(), target, std::move(patterns))))
564  signalPassFailure();
565 
566  RewritePatternSet canonPatterns(context);
567  tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
568  if (failed(applyPatternsAndFoldGreedily(getOperation(),
569  std::move(canonPatterns))))
570  signalPassFailure();
571 
572  // Get rid of the dummy entry block we created in the beginning to work
573  // around dialect conversion signature rewriting.
574  rewriter.eraseOp(branch);
575  rewriter.mergeBlocks(postEntryBlock, entryBlock);
576  }
577 };
578 } // namespace
579 
580 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
581  return std::make_unique<LinalgDetensorize>();
582 }
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
Definition: Detensorize.cpp:31
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Definition: Value.h:315
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:324
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
pred_iterator pred_begin()
Definition: Block.h:226
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator begin()
Definition: Block.h:136
pred_iterator pred_end()
Definition: Block.h:229
Location getUnknownLoc()
Definition: Builders.cpp:27
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
This class describes a specific conversion target.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Implement a predecessor iterator for blocks.
Definition: BlockSupport.h:51
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
iterator begin()
Definition: Region.h:55
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:91
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
user_range getUsers() const
Definition: Value.h:224
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
std::unique_ptr< Pass > createLinalgDetensorizePass()
Create a pass to convert Linalg operations to equivalent operations that work on primitive types,...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isLegalForReturnOpTypeConversionPattern(Operation *op, TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26