MLIR  14.0.0git
Detensorize.cpp
Go to the documentation of this file.
1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
14 #include "mlir/IR/OpDefinition.h"
17 #include <iterator>
18 #include <memory>
19 #include <utility>
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
25  ValueRange inputs, Location loc) {
26  assert(inputs.size() == 1);
27  auto inputType = inputs[0].getType();
28  if (inputType.isa<TensorType>())
29  return nullptr;
30 
31  // A detensored value is converted back by creating a new tensor from its
32  // element(s).
33  return builder.create<tensor::FromElementsOp>(
34  loc, RankedTensorType::get({}, inputType), inputs[0]);
35 }
36 
37 namespace {
38 /// Defines the criteria a TensorType must follow in order to be considered
39 /// "detensorable".
40 ///
41 /// NOTE: For now, only 0-D tensors are supported.
42 ///
43 /// Returns true if tensorType can be detensored.
44 bool canBeDetensored(TensorType tensorType) {
45  return tensorType.hasRank() && tensorType.getRank() == 0;
46 }
47 
48 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
49  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
50  return genericOp &&
51  llvm::all_of(
52  genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
53  return !typeConverter.isLegal(opOperand->get().getType());
54  });
55 }
56 
57 /// A conversion patttern for detensoring `linalg.generic` ops.
58 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
59 public:
62  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
63  ConversionPatternRewriter &rewriter) const override {
64  Block *originalBlock = op->getBlock();
65 
66  // Gather some information about the op before inling its region.
67  Block *opEntryBlock = &*op.region().begin();
68  YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
69 
70  // Split the op's region before the op. This way, we have a clear insertion
71  // point in which the op can be inlined.
72  Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
73  rewriter.inlineRegionBefore(op.region(), newBlock);
74  // Now that op's region is inlined, the operands of its YieldOp are mapped
75  // to the materialized target values. Therefore, we can replace the op's
76  // uses with those of its YielOp's operands.
77  rewriter.replaceOp(op, yieldOp->getOperands());
78 
79  // No need for these intermediate blocks, merge them into 1.
80  rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
81  rewriter.mergeBlocks(newBlock, originalBlock, {});
82 
83  rewriter.eraseOp(&*Block::iterator(yieldOp));
84 
85  return success();
86  }
87 };
88 
89 /// A conversion pattern for detensoring internal (non-entry) blocks within a
90 /// function.
91 struct FunctionNonEntryBlockConversion
92  : public OpInterfaceConversionPattern<FunctionOpInterface> {
93  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
94  DenseSet<BlockArgument> blockArgsToDetensor)
95  : OpInterfaceConversionPattern(converter, ctx),
96  blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
97 
99  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
100  ConversionPatternRewriter &rewriter) const override {
101  rewriter.startRootUpdate(op);
102  Region &region = op.getBody();
104 
105  for (Block &block : llvm::drop_begin(region, 1)) {
106  conversions.emplace_back(block.getNumArguments());
107  TypeConverter::SignatureConversion &back = conversions.back();
108 
109  for (BlockArgument blockArgument : block.getArguments()) {
110  int idx = blockArgument.getArgNumber();
111 
112  if (blockArgsToDetensor.count(blockArgument))
113  back.addInputs(idx, {getTypeConverter()->convertType(
114  block.getArgumentTypes()[idx])});
115  else
116  back.addInputs(idx, {block.getArgumentTypes()[idx]});
117  }
118  }
119 
120  if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
121  conversions))) {
122  rewriter.cancelRootUpdate(op);
123  return failure();
124  }
125 
126  rewriter.finalizeRootUpdate(op);
127  return success();
128  }
129 
130 private:
131  const DenseSet<BlockArgument> blockArgsToDetensor;
132 };
133 
134 class DetensorizeTypeConverter : public TypeConverter {
135 public:
136  DetensorizeTypeConverter() {
137  addConversion([](Type type) { return type; });
138 
139  // A TensorType that can be detensored, is converted to the underlying
140  // element type.
141  addConversion([](TensorType tensorType) -> Type {
142  if (canBeDetensored(tensorType))
143  return tensorType.getElementType();
144 
145  return tensorType;
146  });
147 
148  // A tensor value is detensoried by extracting its element(s).
149  addTargetMaterialization([](OpBuilder &builder, Type type,
150  ValueRange inputs, Location loc) -> Value {
151  return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
152  });
153 
154  addSourceMaterialization(sourceMaterializationCallback);
155  addArgumentMaterialization(sourceMaterializationCallback);
156  }
157 };
158 
159 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
160 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
161  LinalgDetensorize() = default;
162 
163  class CostModel {
164  public:
165  virtual ~CostModel() = default;
166 
167  /// A cost model algorithm computes the following outputs:
168  ///
169  /// - opsToDetensor: the list of linalg ops that should be
170  /// detensored.
171  ///
172  /// - blockArgsToDetensor: since the operands and results of detensored
173  /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
174  /// from a BB argument and a linalg op's output can be passed to successor
175  /// BBs), we need to maintain the sub-set of arguments that should be
176  /// detensored (i.e. converted by typeConverter) for each affected BB.
177  ///
178  /// Example:
179  ///
180  /// For the following snippet:
181  /// ...
182  /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
183  /// %7 = linalg.init_tensor [] : tensor<i32>
184  /// %8 = linalg.generic #attrs
185  /// ins(%6, %6 : tensor<i32>, tensor<i32>)
186  /// outs(%7 : tensor<i32>) {
187  /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
188  /// %9 = arith.addi %arg0, %arg1 : i32
189  /// linalg.yield %9 : i32
190  /// } -> tensor<i32>
191  /// %10 = "some.op"(%9)
192  /// br ^bb2(%8 : tensor<i32>)
193  /// ...
194  ///
195  /// if the cost model decides that the linalg.generic op should be
196  /// detensored, then:
197  /// - opsToDetensor should be = {linalg.generic{add}}.
198  /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
199  virtual void compute(FunctionOpInterface func,
200  DetensorizeTypeConverter typeConverter,
201  DenseSet<Operation *> &opsToDetensor,
202  DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
203 
204  /// From the blockArgsToDetensor set computed by a CostModel
205  /// implementation, this method computes the corresponding branch op
206  /// detensoring. The result is a map from a branch op to a subset of indices
207  /// of its operands. The indices specify which of the branch op's operands
208  /// should be detensored.
209  ///
210  /// For the previous example, this method would compute: {bb2 -> {0}}.
211  static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
212  const DenseSet<BlockArgument> &blockArgsToDetensor) {
213  DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
214 
215  for (auto blockArgumentElem : blockArgsToDetensor) {
216  Block *block = blockArgumentElem.getOwner();
217 
218  for (PredecessorIterator pred = block->pred_begin();
219  pred != block->pred_end(); ++pred) {
220  BranchOpInterface terminator =
221  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
222  auto blockOperands =
223  terminator.getSuccessorOperands(pred.getSuccessorIndex());
224 
225  if (!blockOperands || blockOperands->empty())
226  continue;
227 
228  detensorableBranchOps[terminator].insert(
229  blockOperands->getBeginOperandIndex() +
230  blockArgumentElem.getArgNumber());
231  }
232  }
233 
234  return detensorableBranchOps;
235  }
236  };
237 
238  /// Detensorize linalg ops involved in control-flow within a function.
239  ///
240  /// This model starts from BranchOps and CondBranchOps within a function. For
241  /// each such branch, the model then walks the use-def chain for the branch's
242  /// condition backwards in order to understand where the condition's value
243  /// comes from. If the condition value is (indirectly) computed by a linalg op
244  /// that can be detensored, the model then continues walking the use-def chain
245  /// in order to understand where the linalg op's operands come from. This
246  /// leads to discovering a "detensoring component". A detensoring component is
247  /// the set of operations + block arguments that are involved in control-flow
248  /// AND can be detensored.
249  class ControlFlowDetectionModel : public CostModel {
250  public:
251  void compute(FunctionOpInterface func,
252  DetensorizeTypeConverter typeConverter,
253  DenseSet<Operation *> &opsToDetensor,
254  DenseSet<BlockArgument> &blockArgsToDetensor) override {
255  SmallVector<Value> workList;
256 
257  func->walk([&](CondBranchOp condBr) {
258  for (auto operand : condBr.getOperands()) {
259  workList.push_back(operand);
260  }
261  });
262 
263  func->walk([&](BranchOp br) {
264  for (auto operand : br.getOperands()) {
265  workList.push_back(operand);
266  }
267  });
268 
269  DenseSet<Value> visitedValues;
270  DenseSet<Operation *> visitedOps;
271 
272  // For a (to-be-detesored) value, check if it "escapes" the block by being
273  // passed to terminator. If it does, then workList is updated with the
274  // corresponding argument to the successor block.
275  auto updateWorkListWithSuccessorArguments =
276  [&](Value value, BranchOpInterface terminator) {
277  if (!terminator)
278  return;
279 
280  for (auto operandIdx :
281  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
282  Value operand = terminator->getOperand(operandIdx);
283 
284  if (operand == value) {
285  auto succBlockArg =
286  terminator.getSuccessorBlockArgument(operandIdx);
287 
288  if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
289  workList.push_back(*succBlockArg);
290  }
291  }
292  };
293 
294  while (!workList.empty()) {
295  Value currentItem = workList.pop_back_val();
296 
297  if (!visitedValues.insert(currentItem).second)
298  continue;
299 
300  // 1 - Look forward:
301  // 1.1 - If currentItem escapes to one or more successors, add
302  // the corresponding successor arguments to workList.
303  updateWorkListWithSuccessorArguments(
304  currentItem, dyn_cast<BranchOpInterface>(
305  currentItem.getParentBlock()->getTerminator()));
306 
307  // 1.2 - For each user of currentItem, add the defined values to
308  // workList. This way, the user ops can be inspected later if they are
309  // detensorable and if so, their operands will be added to workList to
310  // potentially discover other parts of the detensorable component.
311  for (auto *user : currentItem.getUsers())
312  for (Value result : user->getResults())
313  workList.push_back(result);
314 
315  // 2 - Look backward:
316  // 2.1 - The current item is defined by a block argument. If the owner
317  // block is a non-entry one, then:
318  // * Add the argument to blockArgsToDetensor.
319  // * Walk the use-def chain backwards to add each predecessor's
320  // terminator-operands corresponding to currentItem to workList.
321  if (currentItem.dyn_cast<BlockArgument>()) {
322  BlockArgument currentItemBlockArgument =
323  currentItem.cast<BlockArgument>();
324  Block *ownerBlock = currentItemBlockArgument.getOwner();
325 
326  // Function arguments are not detensored/converted.
327  if (&*ownerBlock->getParent()->begin() == ownerBlock)
328  continue;
329 
330  // This inner-block argument is involved in control-flow, it should be
331  // detensored.
332  blockArgsToDetensor.insert(currentItemBlockArgument);
333 
334  for (PredecessorIterator pred = ownerBlock->pred_begin();
335  pred != ownerBlock->pred_end(); ++pred) {
336  BranchOpInterface predTerminator =
337  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
338 
339  // TODO: For now, we give up if any of the control-flow components
340  // in a function is not detensorable. Fix that.
341  if (!predTerminator) {
342  opsToDetensor.clear();
343  blockArgsToDetensor.clear();
344  return;
345  }
346 
347  auto ownerBlockOperands =
348  predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
349 
350  if (!ownerBlockOperands || ownerBlockOperands->empty())
351  continue;
352 
353  // For each predecessor, add the value it passes to that argument to
354  // workList to find out how it's computed.
355  workList.push_back(
356  ownerBlockOperands
357  .getValue()[currentItemBlockArgument.getArgNumber()]);
358  }
359 
360  continue;
361  }
362 
363  Operation *currentItemDefiningOp = currentItem.getDefiningOp();
364 
365  if (!visitedOps.insert(currentItemDefiningOp).second)
366  continue;
367 
368  // 2.2 - The current item is computed by a GenericOp. If the op should
369  // be detensored, then:
370  // * Add it to opsToDetensor.
371  // * Add its operands to workList to discover other parts of the
372  // potentially detensorable component.
373  if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
374  // The op was encountered already, no need to inspect it again.
375  if (opsToDetensor.count(genericOp))
376  continue;
377 
378  // The op should not be detensored, give up on it but continue with
379  // discovering the rest of the control-flow component.
380  if (!shouldBeDetensored(genericOp, typeConverter)) {
381  continue;
382  }
383 
384  opsToDetensor.insert(genericOp);
385 
386  for (Value genericOpOperand : genericOp.inputs())
387  workList.push_back(genericOpOperand);
388 
389  continue;
390  }
391 
392  // 2.3 - The current item is the result of a FromElementsOp, it will be
393  // trivially detensored later as part of canonicalization patterns
394  // applied at the end of detensoring.
395  //
396  // Note: No need to check whether the result type of this op is
397  // detensorable since if it wasn't we wouldn't reach that point in the
398  // work list.
399  if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
400  continue;
401 
402  // 2.4 - The current item is the result of a scalar op, add all its
403  // operands to the work list.
404  if (llvm::all_of(
405  currentItemDefiningOp->getResultTypes(),
406  [&](Type resultType) { return resultType.isIntOrFloat(); }))
407  for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
408  workList.push_back(scalarOpOperand);
409  }
410 
411  // Since the cost model gives up on some ops (see the details of step 2.2
412  // above), block arguments that correspond to the values produced by those
413  // ops should not be detensored as well.
414 
415  DenseSet<BlockArgument> blockArgsToRemove;
416 
417  for (auto &blockArg : blockArgsToDetensor) {
418  Block *block = blockArg.getParentBlock();
419 
420  // For the potentially detensorable block argument, find the
421  // correpsonding operands in predecessor blocks.
422  for (PredecessorIterator pred = block->pred_begin();
423  pred != block->pred_end(); ++pred) {
424  BranchOpInterface terminator =
425  dyn_cast<BranchOpInterface>((*pred)->getTerminator());
426  auto blockOperands =
427  terminator.getSuccessorOperands(pred.getSuccessorIndex());
428 
429  if (!blockOperands || blockOperands->empty())
430  continue;
431 
432  Operation *definingOp =
433  terminator
434  ->getOperand(blockOperands->getBeginOperandIndex() +
435  blockArg.getArgNumber())
436  .getDefiningOp();
437 
438  // If the operand is defined by a GenericOp that will not be
439  // detensored, then do not detensor the corresponding block argument.
440  if (dyn_cast_or_null<GenericOp>(definingOp) &&
441  opsToDetensor.count(definingOp) == 0) {
442  blockArgsToRemove.insert(blockArg);
443  break;
444  }
445  }
446  }
447 
448  for (auto &blockArg : blockArgsToRemove) {
449  blockArgsToDetensor.erase(blockArg);
450  }
451  }
452  };
453 
454  /// Detensorize everything that can detensored.
455  class AggressiveDetensoringModel : public CostModel {
456  public:
457  void compute(FunctionOpInterface func,
458  DetensorizeTypeConverter typeConverter,
459  DenseSet<Operation *> &opsToDetensor,
460  DenseSet<BlockArgument> &blockArgsToDetensor) override {
461  func->walk([&](GenericOp genericOp) {
462  if (shouldBeDetensored(genericOp, typeConverter))
463  opsToDetensor.insert(genericOp);
464  });
465 
466  for (Block &block : llvm::drop_begin(func.getBody(), 1))
467  for (BlockArgument blockArgument : block.getArguments())
468  blockArgsToDetensor.insert(blockArgument);
469  }
470  };
471 
472  void runOnOperation() override {
473  MLIRContext *context = &getContext();
474  DetensorizeTypeConverter typeConverter;
475  RewritePatternSet patterns(context);
476  ConversionTarget target(*context);
477  DenseSet<Operation *> opsToDetensor;
478  DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
479  DenseSet<BlockArgument> blockArgsToDetensor;
480  FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
481 
482  if (aggressiveMode.getValue()) {
483  AggressiveDetensoringModel costModel;
484  costModel.compute(funcOp, typeConverter, opsToDetensor,
485  blockArgsToDetensor);
486  } else {
487  ControlFlowDetectionModel costModel;
488  costModel.compute(funcOp, typeConverter, opsToDetensor,
489  blockArgsToDetensor);
490  }
491 
492  detensorableBranchOps =
493  CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
494 
495  target.addDynamicallyLegalOp<GenericOp>(
496  [&](GenericOp op) { return !opsToDetensor.count(op); });
497 
498  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
499  // A function is legal if all of its non-entry blocks are legal. We
500  // don't legalize the entry block (i.e. the function's signature)
501  // since detensoring can't happen along external calling convention
502  // boundaries, which we conservatively approximate as all function
503  // signatures.
504  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
505  Region &body = funcOp.getBody();
506  return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
507  return !llvm::any_of(
508  blockArgsToDetensor, [&](BlockArgument blockArgument) {
509  return blockArgument.getOwner() == &block &&
510  !typeConverter.isLegal(blockArgument.getType());
511  });
512  });
513  }
514 
516  isLegalForReturnOpTypeConversionPattern(op, typeConverter,
517  /*returnOpAlwaysLegal*/ true))
518  return true;
519 
520  if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
521  if (!detensorableBranchOps.count(branchOp))
522  return true;
523 
524  for (auto operandIdx : detensorableBranchOps[branchOp])
525  if (!typeConverter.isLegal(
526  branchOp->getOperand(operandIdx).getType()))
527  return false;
528 
529  return true;
530  }
531 
532  return false;
533  });
534 
535  patterns.insert<DetensorizeGenericOp>(typeConverter, context);
536  patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
537  blockArgsToDetensor);
538  // Since non-entry block arguments get detensorized, we also need to
539  // update the control flow inside the function to reflect the correct
540  // types.
541  auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
542  int operandIdx) -> bool {
543  return detensorableBranchOps.count(branchOp) &&
544  detensorableBranchOps[branchOp].count(operandIdx);
545  };
546 
547  populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
548  shouldConvertBranchOperand);
549 
550  if (failed(
551  applyFullConversion(getOperation(), target, std::move(patterns))))
552  signalPassFailure();
553 
554  RewritePatternSet canonPatterns(context);
555  tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
556  if (failed(applyPatternsAndFoldGreedily(getOperation(),
557  std::move(canonPatterns))))
558  signalPassFailure();
559  }
560 };
561 } // namespace
562 
563 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
564  return std::make_unique<LinalgDetensorize>();
565 }
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:247
Block represents an ordered list of Operations.
Definition: Block.h:29
pred_iterator pred_end()
Definition: Block.h:224
Value getOperand(unsigned idx)
Definition: Operation.h:219
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
bool isLegalForReturnOpTypeConversionPattern(Operation *op, TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:310
user_range getUsers() const
Definition: Value.h:212
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry block...
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:307
LogicalResult applyFullConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
OpListType::iterator iterator
Definition: Block.h:131
void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override
PatternRewriter hook for merging a block into another.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
U dyn_cast() const
Definition: Value.h:99
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
This class represents an argument of a Block.
Definition: Value.h:298
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
auto getType() const
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
Type getType() const
Return the type of this value.
Definition: Value.h:117
Type conversion class.
Implement a predecessor iterator for blocks.
Definition: BlockSupport.h:49
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Type getElementType() const
Returns the element type of this tensor type.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class represents an operand of an operation.
Definition: Value.h:249
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
Definition: Detensorize.cpp:24
This class implements a pattern rewriter for use with ConversionPatterns.
U cast() const
Definition: Value.h:107
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
This class describes a specific conversion target.
std::unique_ptr< Pass > createLinalgDetensorizePass()
Create a pass to convert Linalg operations to equivalent operations that work on primitive types...
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
result_type_range getResultTypes()
Definition: Operation.h:297
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
pred_iterator pred_begin()
Definition: Block.h:221