MLIR  16.0.0git
Detensorize.cpp
Go to the documentation of this file.
1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
15 #include "mlir/IR/OpDefinition.h"
18 #include <iterator>
19 #include <memory>
20 #include <utility>
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
26  ValueRange inputs, Location loc) {
27  assert(inputs.size() == 1);
28  auto inputType = inputs[0].getType();
29  if (inputType.isa<TensorType>())
30  return nullptr;
31 
32  // A detensored value is converted back by creating a new tensor from its
33  // element(s).
34  return builder.create<tensor::FromElementsOp>(
35  loc, RankedTensorType::get({}, inputType), inputs[0]);
36 }
37 
38 namespace {
39 /// Defines the criteria a TensorType must follow in order to be considered
40 /// "detensorable".
41 ///
42 /// NOTE: For now, only 0-D tensors are supported.
43 ///
44 /// Returns true if tensorType can be detensored.
45 bool canBeDetensored(TensorType tensorType) {
46  return tensorType.hasRank() && tensorType.getRank() == 0;
47 }
48 
49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
50  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
51  return genericOp &&
52  llvm::all_of(
53  genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
54  return !typeConverter.isLegal(opOperand->get().getType());
55  });
56 }
57 
58 /// A conversion pattern for detensoring `linalg.generic` ops.
///
/// The rewrite inlines the op's region into the block that contains the op,
/// replaces the op's results with the operands of the region's terminator,
/// and then erases that terminator.
59 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
60 public:
 // NOTE(review): the lines that normally precede `matchAndRewrite` (e.g. its
 // return type) are not present in this listing.
63  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const override {
65  Block *originalBlock = op->getBlock();
66 
67  // Gather some information about the op before inlining its region.
68  Block *opEntryBlock = &*op.getRegion().begin();
 // NOTE(review): the dyn_cast result is dereferenced below without a null
 // check; presumably the region's terminator is always a YieldOp - confirm.
69  YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
70 
71  // Split the op's block before the op. This way, we have a clear insertion
72  // point in which the op can be inlined.
73  Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
74  rewriter.inlineRegionBefore(op.getRegion(), newBlock);
75  // Now that op's region is inlined, the operands of its YieldOp are mapped
76  // to the materialized target values. Therefore, we can replace the op's
77  // uses with those of its YieldOp's operands.
78  rewriter.replaceOp(op, yieldOp->getOperands());
79 
80  // No need for these intermediate blocks, merge them into 1. The inlined
 // entry block's arguments are replaced by the adaptor's (converted) operands.
81  rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
82  rewriter.mergeBlocks(newBlock, originalBlock, {});
83 
 // The terminator of the inlined region is no longer needed once the op's
 // results have been replaced by its operands.
84  rewriter.eraseOp(&*Block::iterator(yieldOp));
85 
86  return success();
87  }
88 };
89 
90 /// A conversion pattern for detensoring internal (non-entry) blocks within a
91 /// function.
///
/// Only the block arguments contained in `blockArgsToDetensor` have their
/// types converted; every other argument keeps its original type.
92 struct FunctionNonEntryBlockConversion
93  : public OpInterfaceConversionPattern<FunctionOpInterface> {
 // Takes the set of block arguments to convert by value and moves it into the
 // pattern, so the pattern owns its own copy.
94  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
95  DenseSet<BlockArgument> blockArgsToDetensor)
96  : OpInterfaceConversionPattern(converter, ctx),
97  blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
98 
 // NOTE(review): the line carrying the return type of `matchAndRewrite` is not
 // present in this listing.
100  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
101  ConversionPatternRewriter &rewriter) const override {
 // The function op is modified in place; the update is either finalized or
 // cancelled below.
102  rewriter.startRootUpdate(op);
103  Region &region = op.getBody();
 // NOTE(review): `conversions` is declared on a line missing from this
 // listing; presumably a container of TypeConverter::SignatureConversion with
 // one entry per non-entry block - confirm.
105 
 // Build one signature conversion per non-entry block (the entry block is
 // skipped via drop_begin).
106  for (Block &block : llvm::drop_begin(region, 1)) {
107  conversions.emplace_back(block.getNumArguments());
108  TypeConverter::SignatureConversion &back = conversions.back();
109 
110  for (BlockArgument blockArgument : block.getArguments()) {
111  int idx = blockArgument.getArgNumber();
112 
 // Selected arguments get the converted type, the others are remapped to
 // their unchanged type.
113  if (blockArgsToDetensor.count(blockArgument))
114  back.addInputs(idx, {getTypeConverter()->convertType(
115  block.getArgumentTypes()[idx])});
116  else
117  back.addInputs(idx, {block.getArgumentTypes()[idx]});
118  }
119  }
120 
 // On failure the in-place update of the function op is rolled back.
121  if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
122  conversions))) {
123  rewriter.cancelRootUpdate(op);
124  return failure();
125  }
126 
127  rewriter.finalizeRootUpdate(op);
128  return success();
129  }
130 
131 private:
 // Block arguments whose types should be converted by this pattern.
132  const DenseSet<BlockArgument> blockArgsToDetensor;
133 };
134 
135 class DetensorizeTypeConverter : public TypeConverter {
136 public:
137  DetensorizeTypeConverter() {
138  addConversion([](Type type) { return type; });
139 
140  // A TensorType that can be detensored, is converted to the underlying
141  // element type.
142  addConversion([](TensorType tensorType) -> Type {
143  if (canBeDetensored(tensorType))
144  return tensorType.getElementType();
145 
146  return tensorType;
147  });
148 
149  // A tensor value is detensoried by extracting its element(s).
150  addTargetMaterialization([](OpBuilder &builder, Type type,
151  ValueRange inputs, Location loc) -> Value {
152  return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
153  });
154 
155  addSourceMaterialization(sourceMaterializationCallback);
156  addArgumentMaterialization(sourceMaterializationCallback);
157  }
158 };
159 
160 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
161 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
162  LinalgDetensorize() = default;
163 
164  class CostModel {
165  public:
166  virtual ~CostModel() = default;
167 
168  /// A cost model algorithm computes the following outputs:
169  ///
170  /// - opsToDetensor: the list of linalg ops that should be
171  /// detensored.
172  ///
173  /// - blockArgsToDetensor: since the operands and results of detensored
174  /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
175  /// from a BB argument and a linalg op's output can be passed to successor
176  /// BBs), we need to maintain the sub-set of arguments that should be
177  /// detensored (i.e. converted by typeConverter) for each affected BB.
178  ///
179  /// Example:
180  ///
181  /// For the following snippet:
182  /// ...
183  /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
184  /// %7 = linalg.init_tensor [] : tensor<i32>
185  /// %8 = linalg.generic #attrs
186  /// ins(%6, %6 : tensor<i32>, tensor<i32>)
187  /// outs(%7 : tensor<i32>) {
188  /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
189  /// %9 = arith.addi %arg0, %arg1 : i32
190  /// linalg.yield %9 : i32
191  /// } -> tensor<i32>
192  /// %10 = "some.op"(%9)
193  /// br ^bb2(%8 : tensor<i32>)
194  /// ...
195  ///
196  /// if the cost model decides that the linalg.generic op should be
197  /// detensored, then:
198  /// - opsToDetensor should be = {linalg.generic{add}}.
199  /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
200  virtual void compute(FunctionOpInterface func,
201  DetensorizeTypeConverter typeConverter,
202  DenseSet<Operation *> &opsToDetensor,
203  DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
204 
205  /// From the blockArgsToDetensor set computed by a CostModel
206  /// implementation, this method computes the corresponding branch op
207  /// detensoring. The result is a map from a branch op to a subset of indices
208  /// of its operands. The indices specify which of the branch op's operands
209  /// should be detensored.
210  ///
211  /// For the previous example, this method would compute: {bb2 -> {0}}.
212  static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
213  const DenseSet<BlockArgument> &blockArgsToDetensor) {
214  DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
215 
216  for (auto blockArgumentElem : blockArgsToDetensor) {
217  Block *block = blockArgumentElem.getOwner();
218 
219  for (PredecessorIterator pred = block->pred_begin();
220  pred != block->pred_end(); ++pred) {
221  BranchOpInterface terminator =
222  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
223  auto blockOperands =
224  terminator.getSuccessorOperands(pred.getSuccessorIndex());
225 
226  if (blockOperands.empty() ||
227  blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
228  continue;
229 
230  detensorableBranchOps[terminator].insert(
231  blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
232  }
233  }
234 
235  return detensorableBranchOps;
236  }
237  };
238 
239  /// Detensorize linalg ops involved in control-flow within a function.
240  ///
241  /// This model starts from BranchOps and CondBranchOps within a function. For
242  /// each such branch, the model then walks the use-def chain for the branch's
243  /// condition backwards in order to understand where the condition's value
244  /// comes from. If the condition value is (indirectly) computed by a linalg op
245  /// that can be detensored, the model then continues walking the use-def chain
246  /// in order to understand where the linalg op's operands come from. This
247  /// leads to discovering a "detensoring component". A detensoring component is
248  /// the set of operations + block arguments that are involved in control-flow
249  /// AND can be detensored.
250  class ControlFlowDetectionModel : public CostModel {
251  public:
252  void compute(FunctionOpInterface func,
253  DetensorizeTypeConverter typeConverter,
254  DenseSet<Operation *> &opsToDetensor,
255  DenseSet<BlockArgument> &blockArgsToDetensor) override {
 // Values still to be inspected; processed in LIFO order (pop_back_val).
256  SmallVector<Value> workList;
257 
 // Seed the work list with every operand of every cf.cond_br and cf.br op.
258  func->walk([&](cf::CondBranchOp condBr) {
259  llvm::append_range(workList, condBr.getOperands());
260  });
261 
262  func->walk([&](cf::BranchOp br) {
263  llvm::append_range(workList, br.getOperands());
264  });
265 
 // Guards against processing the same value / defining op more than once.
266  DenseSet<Value> visitedValues;
267  DenseSet<Operation *> visitedOps;
268 
269  // For a (to-be-detensored) value, check if it "escapes" the block by being
270  // passed to terminator. If it does, then workList is updated with the
271  // corresponding argument to the successor block.
272  auto updateWorkListWithSuccessorArguments =
273  [&](Value value, BranchOpInterface terminator) {
274  if (!terminator)
275  return;
276 
277  for (auto operandIdx :
278  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
279  Value operand = terminator->getOperand(operandIdx);
280 
281  if (operand == value) {
282  auto succBlockArg =
283  terminator.getSuccessorBlockArgument(operandIdx);
284 
285  if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
286  workList.push_back(*succBlockArg);
287  }
288  }
289  };
290 
291  while (!workList.empty()) {
292  Value currentItem = workList.pop_back_val();
293 
294  if (!visitedValues.insert(currentItem).second)
295  continue;
296 
297  // 1 - Look forward:
298  // 1.1 - If currentItem escapes to one or more successors, add
299  // the corresponding successor arguments to workList.
300  updateWorkListWithSuccessorArguments(
301  currentItem, dyn_cast<BranchOpInterface>(
302  currentItem.getParentBlock()->getTerminator()));
303 
304  // 1.2 - For each user of currentItem, add the defined values to
305  // workList. This way, the user ops can be inspected later if they are
306  // detensorable and if so, their operands will be added to workList to
307  // potentially discover other parts of the detensorable component.
308  for (auto *user : currentItem.getUsers())
309  llvm::append_range(workList, user->getResults());
310 
311  // 2 - Look backward:
312  // 2.1 - The current item is defined by a block argument. If the owner
313  // block is a non-entry one, then:
314  // * Add the argument to blockArgsToDetensor.
315  // * Walk the use-def chain backwards to add each predecessor's
316  // terminator-operands corresponding to currentItem to workList.
317  if (currentItem.dyn_cast<BlockArgument>()) {
318  BlockArgument currentItemBlockArgument =
319  currentItem.cast<BlockArgument>();
320  Block *ownerBlock = currentItemBlockArgument.getOwner();
321 
322  // Function arguments are not detensored/converted.
323  if (&*ownerBlock->getParent()->begin() == ownerBlock)
324  continue;
325 
326  // This inner-block argument is involved in control-flow, it should be
327  // detensored.
328  blockArgsToDetensor.insert(currentItemBlockArgument);
329 
330  for (PredecessorIterator pred = ownerBlock->pred_begin();
331  pred != ownerBlock->pred_end(); ++pred) {
332  BranchOpInterface predTerminator =
333  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
334 
335  // TODO: For now, we give up if any of the control-flow components
336  // in a function is not detensorable. Fix that.
 // Giving up clears BOTH output sets, including anything they held
 // before this call.
337  if (!predTerminator) {
338  opsToDetensor.clear();
339  blockArgsToDetensor.clear();
340  return;
341  }
342 
343  auto ownerBlockOperands =
344  predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
345 
346  if (ownerBlockOperands.empty() ||
347  ownerBlockOperands.isOperandProduced(
348  currentItemBlockArgument.getArgNumber()))
349  continue;
350 
351  // For each predecessor, add the value it passes to that argument to
352  // workList to find out how it's computed.
353  workList.push_back(
354  ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
355  }
356 
357  continue;
358  }
359 
 // From here on currentItem is an op result (block arguments were handled
 // and skipped above).
360  Operation *currentItemDefiningOp = currentItem.getDefiningOp();
361 
362  if (!visitedOps.insert(currentItemDefiningOp).second)
363  continue;
364 
365  // 2.2 - The current item is computed by a GenericOp. If the op should
366  // be detensored, then:
367  // * Add it to opsToDetensor.
368  // * Add its operands to workList to discover other parts of the
369  // potentially detensorable component.
370  if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
371  // The op was encountered already, no need to inspect it again.
372  if (opsToDetensor.count(genericOp))
373  continue;
374 
375  // The op should not be detensored, give up on it but continue with
376  // discovering the rest of the control-flow component.
377  if (!shouldBeDetensored(genericOp, typeConverter)) {
378  continue;
379  }
380 
 // Only the op's inputs (not its outputs) are followed further.
381  opsToDetensor.insert(genericOp);
382  llvm::append_range(workList, genericOp.getInputs());
383  continue;
384  }
385 
386  // 2.3 - The current item is the result of a FromElementsOp, it will be
387  // trivially detensored later as part of canonicalization patterns
388  // applied at the end of detensoring.
389  //
390  // Note: No need to check whether the result type of this op is
391  // detensorable since if it wasn't we wouldn't reach that point in the
392  // work list.
393  if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
394  continue;
395 
396  // 2.4 - The current item is the result of a scalar op, add all its
397  // operands to the work list. An op counts as scalar here when ALL of its
 // result types are int or float.
398  if (llvm::all_of(
399  currentItemDefiningOp->getResultTypes(),
400  [&](Type resultType) { return resultType.isIntOrFloat(); }))
401  llvm::append_range(workList, currentItemDefiningOp->getOperands());
402  }
403 
404  // Since the cost model gives up on some ops (see the details of step 2.2
405  // above), block arguments that correspond to the values produced by those
406  // ops should not be detensored as well.
407 
 // Removals are collected first and applied after the loop, so the set is
 // not modified while it is being iterated.
408  DenseSet<BlockArgument> blockArgsToRemove;
409 
410  for (auto &blockArg : blockArgsToDetensor) {
411  Block *block = blockArg.getParentBlock();
412 
413  // For the potentially detensorable block argument, find the
414  // corresponding operands in predecessor blocks.
415  for (PredecessorIterator pred = block->pred_begin();
416  pred != block->pred_end(); ++pred) {
 // NOTE(review): the dyn_cast result is used without a null check;
 // presumably safe because the work-list loop above returns early when a
 // predecessor terminator is not a BranchOpInterface - confirm for
 // arguments that were already in the set on entry.
417  BranchOpInterface terminator =
418  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
419  auto blockOperands =
420  terminator.getSuccessorOperands(pred.getSuccessorIndex());
421 
422  if (blockOperands.empty() ||
423  blockOperands.isOperandProduced(blockArg.getArgNumber()))
424  continue;
425 
426  Operation *definingOp =
427  blockOperands[blockArg.getArgNumber()].getDefiningOp();
428 
429  // If the operand is defined by a GenericOp that will not be
430  // detensored, then do not detensor the corresponding block argument.
431  if (isa_and_nonnull<GenericOp>(definingOp) &&
432  opsToDetensor.count(definingOp) == 0) {
433  blockArgsToRemove.insert(blockArg);
434  break;
435  }
436  }
437  }
438 
439  for (auto &blockArg : blockArgsToRemove) {
440  blockArgsToDetensor.erase(blockArg);
441  }
442  }
443  };
444 
445  /// Detensorize everything that can detensored.
446  class AggressiveDetensoringModel : public CostModel {
447  public:
448  void compute(FunctionOpInterface func,
449  DetensorizeTypeConverter typeConverter,
450  DenseSet<Operation *> &opsToDetensor,
451  DenseSet<BlockArgument> &blockArgsToDetensor) override {
452  func->walk([&](GenericOp genericOp) {
453  if (shouldBeDetensored(genericOp, typeConverter))
454  opsToDetensor.insert(genericOp);
455  });
456 
457  for (Block &block : llvm::drop_begin(func.getBody(), 1))
458  for (BlockArgument blockArgument : block.getArguments())
459  blockArgsToDetensor.insert(blockArgument);
460  }
461  };
462 
 // Pass entry point: runs a cost model to pick the ops / block arguments to
 // detensor, applies a full dialect conversion with the detensoring patterns,
 // and finally runs the tensor.from_elements canonicalization patterns.
463  void runOnOperation() override {
464  MLIRContext *context = &getContext();
465  DetensorizeTypeConverter typeConverter;
466  RewritePatternSet patterns(context);
467  ConversionTarget target(*context);
468  DenseSet<Operation *> opsToDetensor;
469  DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
470  DenseSet<BlockArgument> blockArgsToDetensor;
471  FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
472 
 // Select the cost model from the `aggressiveMode` pass option.
473  if (aggressiveMode.getValue()) {
474  AggressiveDetensoringModel costModel;
475  costModel.compute(funcOp, typeConverter, opsToDetensor,
476  blockArgsToDetensor);
477  } else {
478  ControlFlowDetectionModel costModel;
479  costModel.compute(funcOp, typeConverter, opsToDetensor,
480  blockArgsToDetensor);
481  }
482 
483  detensorableBranchOps =
484  CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
485 
 // A linalg.generic op is legal unless the cost model selected it.
486  target.addDynamicallyLegalOp<GenericOp>(
487  [&](GenericOp op) { return !opsToDetensor.count(op); });
488 
489  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
490  // A function is legal if all of its non-entry blocks are legal. We
491  // don't legalize the entry block (i.e. the function's signature)
492  // since detensoring can't happen along external calling convention
493  // boundaries, which we conservatively approximate as all function
494  // signatures.
495  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
496  Region &body = funcOp.getBody();
497  return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
498  return !llvm::any_of(
499  blockArgsToDetensor, [&](BlockArgument blockArgument) {
500  return blockArgument.getOwner() == &block &&
501  !typeConverter.isLegal(blockArgument.getType());
502  });
503  });
504  }
505 
 // NOTE(review): the line that opens this condition is absent from this
 // listing; presumably it starts an `if (` that combines another predicate
 // with the call below - confirm against the real file.
507  isLegalForReturnOpTypeConversionPattern(op, typeConverter,
508  /*returnOpAlwaysLegal*/ true))
509  return true;
510 
 // A branch op is legal when it was not selected, or when all of its
 // selected operands already have legal types.
511  if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
512  if (!detensorableBranchOps.count(branchOp))
513  return true;
514 
515  for (auto operandIdx : detensorableBranchOps[branchOp])
516  if (!typeConverter.isLegal(
517  branchOp->getOperand(operandIdx).getType()))
518  return false;
519 
520  return true;
521  }
522 
523  return false;
524  });
525 
526  patterns.add<DetensorizeGenericOp>(typeConverter, context);
527  patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
528  blockArgsToDetensor);
529  // Since non-entry block arguments get detensorized, we also need to
530  // update the control flow inside the function to reflect the correct
531  // types.
532  auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
533  int operandIdx) -> bool {
534  return detensorableBranchOps.count(branchOp) &&
535  detensorableBranchOps[branchOp].count(operandIdx);
536  };
537 
538  populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
539  shouldConvertBranchOperand);
540 
 // There is no early return after a failed conversion, so the
 // canonicalization below still runs in that case.
541  if (failed(
542  applyFullConversion(getOperation(), target, std::move(patterns))))
543  signalPassFailure();
544 
 // Apply the canonicalization patterns of tensor.from_elements greedily.
545  RewritePatternSet canonPatterns(context);
546  tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
547  if (failed(applyPatternsAndFoldGreedily(getOperation(),
548  std::move(canonPatterns))))
549  signalPassFailure();
550  }
551 };
552 } // namespace
553 
554 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
555  return std::make_unique<LinalgDetensorize>();
556 }
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
Block represents an ordered list of Operations.
Definition: Block.h:29
pred_iterator pred_end()
Definition: Block.h:224
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool isLegalForReturnOpTypeConversionPattern(Operation *op, TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:312
user_range getUsers() const
Definition: Value.h:213
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:414
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region...
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:309
LogicalResult applyFullConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
OpListType::iterator iterator
Definition: Block.h:131
void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override
PatternRewriter hook for merging a block into another.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
U dyn_cast() const
Definition: Value.h:100
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
This class represents an argument of a Block.
Definition: Value.h:300
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
Type getType() const
Return the type of this value.
Definition: Value.h:118
Type conversion class.
type_range getType() const
Definition: ValueRange.cpp:46
Implement a predecessor iterator for blocks.
Definition: BlockSupport.h:49
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Type getElementType() const
Returns the element type of this tensor type.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class represents an operand of an operation.
Definition: Value.h:251
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
Definition: Detensorize.cpp:25
This class implements a pattern rewriter for use with ConversionPatterns.
U cast() const
Definition: Value.h:108
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
This class describes a specific conversion target.
std::unique_ptr< Pass > createLinalgDetensorizePass()
Create a pass to convert Linalg operations to equivalent operations that work on primitive types...
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
result_type_range getResultTypes()
Definition: Operation.h:345
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
pred_iterator pred_begin()
Definition: Block.h:221