MLIR  22.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
18 #include "mlir/Dialect/Linalg/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
25  return false;
26 
27  // TODO: The conversion pattern can be made to work for `any_of` here, but
28  // it's more complex as it requires tracking which operands are scalars.
29  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
30 }
31 
32 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
33 /// the result types and return a list of values such that, for each result type
34 /// `t` and value `v` at the same index `idx`:
35 /// 1. `v.getType() == t`
36 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
37 /// such operand. Then`v == operand_first`.
38 /// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
39 /// a. Static and dynamic dims extracted from the first operand of `op`.
40 /// b. Elemental type equal to the elemental type of `t`.
41 ///
42 /// This is sufficient because ElementwiseMappable guarantees that "The static
43 /// types of all vector (resp. tensor) operands and results must have the same
44 /// shape".
48  Location loc = op->getLoc();
49  ValueRange operands = op->getOperands();
50  TypeRange rankedTensorTypes = op->getResultTypes();
52  res.reserve(rankedTensorTypes.size());
53  for (Type t : rankedTensorTypes) {
54  // Try to find an operand with type matching the result tensor.
55  bool found = false;
56  for (Value v : operands) {
57  if (v.getType() == t) {
58  found = true;
59  res.push_back(v);
60  break;
61  }
62  }
63  if (found)
64  continue;
65 
66  // Extract static / dynamic shape mix from the first operand.
67  res.push_back(b.create<tensor::EmptyOp>(
68  loc, tensor::getMixedSizes(b, loc, operands.front()),
69  cast<RankedTensorType>(t).getElementType()));
70  }
71  return res;
72 }
73 
74 namespace {
75 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
76  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
77  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
78  LogicalResult matchAndRewrite(Operation *op,
79  PatternRewriter &rewriter) const final {
81  return rewriter.notifyMatchFailure(
82  op, "requires elementwise op on ranked tensors");
83 
84  auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
85  SmallVector<AffineMap, 3> indexingMaps(
86  op->getNumResults() + op->getNumOperands(),
87  rewriter.getMultiDimIdentityMap(rank));
89  rank, utils::IteratorType::parallel);
90  auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
91  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
92  op, /*resultTensorTypes=*/op->getResultTypes(),
93  /*inputs=*/op->getOperands(),
94  /*outputs=*/outputs,
95  /*indexingMaps=*/indexingMaps,
96  /*iteratorTypes=*/iteratorTypes,
97  /*bodyBuilder=*/
98  [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
99  auto resultTypes = llvm::to_vector<6>(
100  llvm::map_range(op->getResultTypes(), [](Type type) {
101  return cast<TensorType>(type).getElementType();
102  }));
103  auto *scalarOp =
104  builder.create(loc, op->getName().getIdentifier(),
105  regionArgs.take_front(op->getNumOperands()),
106  resultTypes, op->getAttrs());
107  builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
108  });
109  return success();
110  }
111 };
112 } // namespace
113 
116  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
117  patterns.getContext());
118 }
119 
120 namespace {
121 class ConvertElementwiseToLinalgPass
122  : public impl::ConvertElementwiseToLinalgPassBase<
123  ConvertElementwiseToLinalgPass> {
124  using impl::ConvertElementwiseToLinalgPassBase<
125  ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
126 
127  void runOnOperation() final {
128  auto *func = getOperation();
129  auto *context = &getContext();
130  ConversionTarget target(*context);
131  RewritePatternSet patterns(context);
132 
134  target.markUnknownOpDynamicallyLegal([](Operation *op) {
136  });
137 
138  if (failed(applyPartialConversion(func, target, std::move(patterns))))
139  signalPassFailure();
140  }
141 };
142 } // namespace
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
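Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a list of values such that, for each result type t and value v at the same index, v has type t and is either the first operand of type t or a newly created empty tensor.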
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:767
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.