MLIR  16.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "PassDetail.h"
17 
18 using namespace mlir;
19 
22  return false;
23 
// NOTE(review): source lines 20-21 (the start of this definition) are missing
// from this rendering. Per the symbol index the signature is
// `static bool isElementwiseMappableOpOnRankedTensors(Operation *op)`, and the
// early `return false` above presumably follows a failed
// `hasElementwiseMappableTraits(op)` check -- confirm against upstream.
//
// Remaining requirement: every operand must be a ranked tensor.
24  // TODO: The conversion pattern can be made to work for `any_of` here, but
25  // it's more complex as it requires tracking which operands are scalars.
// NOTE(review): `llvm::all_of` is vacuously true for an empty range, so an op
// with zero operands is accepted here. The rewrite later calls
// `op->getOperands().front()` and casts `op->getResult(0)`'s type to
// RankedTensorType, which would likely assert for such an op. Consider also
// requiring `op->getNumOperands() > 0`.
26  return llvm::all_of(op->getOperandTypes(),
27  [](Type type) { return type.isa<RankedTensorType>(); });
28 }
29 
30 /// Given `op` (assumed to satisfy `isElementwiseMappableOpOnRankedTensors`), iterate over
31 /// the result types and return a list of values such that, for each result type
32 /// `t` and value `v` at the same index `idx`:
33 /// 1. `v.getType() == t`
34 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
35 /// such operand. Then `v == operand_first`.
36 /// 3. Otherwise, `v` is a newly created `linalg::InitTensorOp` with:
37 /// a. Static sizes taken from `t`, and dynamic sizes read from the first operand of `op`.
38 /// b. Element type equal to the element type of `t`.
39 ///
40 /// This is sufficient because ElementwiseMappable guarantees that "The static
41 /// types of all vector (resp. tensor) operands and results must have the same
42 /// shape".
///
/// NOTE(review): case 3 reads `op->getOperands().front()`, so it requires `op`
/// to have at least one operand.
/// NOTE(review): source lines 44-45 are missing from this rendering. Per the
/// symbol index the full signature is
/// `getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)`.
43 static SmallVector<Value, 4>
46  Location loc = op->getLoc();
47  ValueRange operands = op->getOperands();
48  TypeRange rankedTensorTypes = op->getResultTypes();
49  SmallVector<Value, 4> res;
50  res.reserve(rankedTensorTypes.size());
// Produce exactly one value per result type, in result order.
51  for (Type t : rankedTensorTypes) {
52  // Case 2: reuse the first operand whose type is exactly `t`.
53  bool found = false;
54  for (Value v : operands) {
55  if (v.getType() == t) {
56  found = true;
57  res.push_back(v);
58  break;
59  }
60  }
61  if (found)
62  continue;
63 
64  // Case 3: static sizes come from `t`; SSA values for the dynamic dims are read off the first operand, which has the same shape per ElementwiseMappable.
65  Value firstOperand = operands.front();
66  auto rankedTensorType = t.cast<RankedTensorType>();
67  auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
68  auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b);
69 
70  res.push_back(b.create<linalg::InitTensorOp>(
71  loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
72  }
73  return res;
74 }
75 
76 namespace {
/// Rewrites an elementwise-mappable op on ranked tensors into a
/// `linalg::GenericOp` with one identity indexing map per input and output.
/// The generic body re-creates the original op (same name and attributes) on
/// the scalar element values and yields its results.
77 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
78  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
// Matches any op type; the actual filtering happens in matchAndRewrite.
79  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
80  LogicalResult matchAndRewrite(Operation *op,
81  PatternRewriter &rewriter) const final {
// NOTE(review): source line 82 (the guard for this early return) is missing from
// this rendering. The message suggests it rejects ops failing
// `isElementwiseMappableOpOnRankedTensors` -- confirm.
83  return rewriter.notifyMatchFailure(
84  op, "requires elementwise op on ranked tensors");
85 
// All operands and results share one shape (ElementwiseMappable), so the rank
// of result 0 is used for every indexing map.
// NOTE(review): this assumes `op` has at least one result.
86  auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
// One identity map for each output and each input.
87  SmallVector<AffineMap, 3> indexingMaps(
88  op->getNumResults() + op->getNumOperands(),
89  rewriter.getMultiDimIdentityMap(rank));
// One iterator type per loop dimension. The fill value (source line 91) is
// missing here; the symbol index lists `getParallelIteratorTypeName()`, so
// presumably every loop is parallel.
90  SmallVector<StringRef, 6> iteratorTypes(rank,
92  auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
93  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
94  op, /*resultTensorTypes=*/op->getResultTypes(),
95  /*inputs=*/op->getOperands(),
96  /*outputs=*/outputs,
97  /*indexingMaps=*/indexingMaps,
98  /*iteratorTypes=*/iteratorTypes,
99  /*bodyBuilder=*/
100  [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
// Scalar result types: the element type of each tensor result.
101  auto resultTypes = llvm::to_vector<6>(
102  llvm::map_range(op->getResultTypes(), [](Type type) {
103  return type.cast<TensorType>().getElementType();
104  }));
// Re-create the original op generically (by name, with its attributes).
// Only the first getNumOperands() block arguments, i.e. the input elements,
// are passed as operands.
105  auto *scalarOp =
106  builder.create(loc, op->getName().getIdentifier(),
107  regionArgs.take_front(op->getNumOperands()),
108  resultTypes, op->getAttrs());
109  builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
110  });
111  return success();
112  }
113 };
114 } // namespace
115 
117  RewritePatternSet &patterns) {
// NOTE(review): source line 116 (the start of this definition) is missing from
// this rendering. Per the symbol index it is
// `void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)`.
// Registers this file's single rewrite pattern, using the set's own context.
118  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
119  patterns.getContext());
120 }
121 
122 namespace {
/// Pass that applies the elementwise-to-linalg rewrite to the ops under its
/// root operation via partial dialect conversion.
/// `ConvertElementwiseToLinalgBase` is presumably the generated pass base
/// (from PassDetail.h) -- confirm.
123 class ConvertElementwiseToLinalgPass
124  : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
125 
126  void runOnOperation() final {
127  auto *func = getOperation();
128  auto *context = &getContext();
129  ConversionTarget target(*context);
130  RewritePatternSet patterns(context);
131 
// NOTE(review): source lines 132-134 are missing from this rendering (line 135
// below closes a lambda). The symbol index references
// `populateElementwiseToLinalgConversionPatterns` and
// `ConversionTarget::markUnknownOpDynamicallyLegal`, so presumably these lines
// fill `patterns` and register a legality callback for unknown ops -- confirm.
135  });
136 
// Run the conversion; if it fails, mark the whole pass as failed.
137  if (failed(applyPartialConversion(func, target, std::move(patterns))))
138  signalPassFailure();
139  }
140 };
141 } // namespace
142 
// NOTE(review): the signature (source line 143) is missing from this rendering.
// Per the symbol index it is
// `std::unique_ptr<Pass> createConvertElementwiseToLinalgPass()`.
// Each call returns a new, uniquely owned pass instance.
144  return std::make_unique<ConvertElementwiseToLinalgPass>();
145 }
Include the generated interface declarations.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:332
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:295
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1122
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:263
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
operand_type_range getOperandTypes()
Definition: Operation.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a l...
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:118
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
SmallVector< Value, 4 > getDynOperands(Location loc, Value val, OpBuilder &b)
Given an operation, retrieves the value of each dynamic dimension through constructing the necessary ...
Definition: Utils.cpp:220
U cast() const
Definition: Value.h:108
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:512
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class describes a specific conversion target.
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
result_type_range getResultTypes()
Definition: Operation.h:345
std::unique_ptr< Pass > createConvertElementwiseToLinalgPass()
MLIRContext * getContext() const
U cast() const
Definition: Types.h:278