MLIR  20.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
19 #include "mlir/Dialect/Linalg/Passes.h.inc"
20 } // namespace mlir
21 
22 using namespace mlir;
23 
26  return false;
27 
28  // TODO: The conversion pattern can be made to work for `any_of` here, but
29  // it's more complex as it requires tracking which operands are scalars.
30  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
31 }
32 
33 /// Given `op`, assumed to satisfy `isElementwiseMappableOpOnRankedTensors`, iterate over
34 /// the result types and return a list of values such that, for each result type
35 /// `t` and value `v` at the same index `idx`:
36 /// 1. `v.getType() == t`
37 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
38 /// such operand. Then `v == operand_first`.
39 /// 3. Otherwise, `v` is a newly created `tensor::EmptyOp` with:
40 /// a. Static and dynamic dims extracted from the first operand of `op`.
41 /// b. Element type equal to the element type of `t`.
42 ///
43 /// This is sufficient because ElementwiseMappable guarantees that "The static
44 /// types of all vector (resp. tensor) operands and results must have the same
45 /// shape".
49  Location loc = op->getLoc();
50  ValueRange operands = op->getOperands();
51  TypeRange rankedTensorTypes = op->getResultTypes();
53  res.reserve(rankedTensorTypes.size());
54  for (Type t : rankedTensorTypes) {
55  // Try to find an operand with type matching the result tensor.
56  bool found = false;
57  for (Value v : operands) {
58  if (v.getType() == t) {
59  found = true;
60  res.push_back(v);
61  break;
62  }
63  }
64  if (found)
65  continue;
66 
67  // Extract static / dynamic shape mix from the first operand.
68  res.push_back(b.create<tensor::EmptyOp>(
69  loc, tensor::getMixedSizes(b, loc, operands.front()),
70  cast<RankedTensorType>(t).getElementType()));
71  }
72  return res;
73 }
74 
75 namespace {
76 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
77  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
78  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
79  LogicalResult matchAndRewrite(Operation *op,
80  PatternRewriter &rewriter) const final {
82  return rewriter.notifyMatchFailure(
83  op, "requires elementwise op on ranked tensors");
84 
85  auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
86  SmallVector<AffineMap, 3> indexingMaps(
87  op->getNumResults() + op->getNumOperands(),
88  rewriter.getMultiDimIdentityMap(rank));
90  rank, utils::IteratorType::parallel);
91  auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
92  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
93  op, /*resultTensorTypes=*/op->getResultTypes(),
94  /*inputs=*/op->getOperands(),
95  /*outputs=*/outputs,
96  /*indexingMaps=*/indexingMaps,
97  /*iteratorTypes=*/iteratorTypes,
98  /*bodyBuilder=*/
99  [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
100  auto resultTypes = llvm::to_vector<6>(
101  llvm::map_range(op->getResultTypes(), [](Type type) {
102  return cast<TensorType>(type).getElementType();
103  }));
104  auto *scalarOp =
105  builder.create(loc, op->getName().getIdentifier(),
106  regionArgs.take_front(op->getNumOperands()),
107  resultTypes, op->getAttrs());
108  builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
109  });
110  return success();
111  }
112 };
113 } // namespace
114 
116  RewritePatternSet &patterns) {
117  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
118  patterns.getContext());
119 }
120 
121 namespace {
122 class ConvertElementwiseToLinalgPass
123  : public impl::ConvertElementwiseToLinalgPassBase<
124  ConvertElementwiseToLinalgPass> {
125  using impl::ConvertElementwiseToLinalgPassBase<
126  ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
127 
128  void runOnOperation() final {
129  auto *func = getOperation();
130  auto *context = &getContext();
131  ConversionTarget target(*context);
132  RewritePatternSet patterns(context);
133 
135  target.markUnknownOpDynamicallyLegal([](Operation *op) {
137  });
138 
139  if (failed(applyPartialConversion(func, target, std::move(patterns))))
140  signalPassFailure();
141  }
142 };
143 } // namespace
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
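Given op, assumed to satisfy isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a list of values such that, for each result type t and value v at the same index, v has type t.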
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.