MLIR  14.0.0git
Go to the documentation of this file.
1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
11 #include "PassDetail.h"
19 using namespace mlir;
23  return false;
25  // TODO: The conversion pattern can be made to work for `any_of` here, but
26  // it's more complex as it requires tracking which operands are scalars.
27  return llvm::all_of(op->getOperandTypes(),
28  [](Type type) { return type.isa<RankedTensorType>(); });
29 }
31 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
32 /// the result types and return a list of values such that, for each result type
33 /// `t` and value `v` at the same index `idx`:
34 /// 1. `v.getType() == t`
35 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
36 /// such operand. Then `v == operand_first`.
37 /// 3. Otherwise, `v` is a newly created `linalg::InitTensorOp` with:
38 /// a. Static and dynamic dims extracted from the first operand of `op`.
39 /// b. Elemental type equal to the elemental type of `t`.
40 ///
41 /// This is sufficient because ElementwiseMappable guarantees that "The static
42 /// types of all vector (resp. tensor) operands and results must have the same
43 /// shape".
47  Location loc = op->getLoc();
48  ValueRange operands = op->getOperands();
49  TypeRange rankedTensorTypes = op->getResultTypes();
51  res.reserve(rankedTensorTypes.size());
52  for (Type t : rankedTensorTypes) {
53  // Try to find an operand with type matching the result tensor.
54  bool found = false;
55  for (Value v : operands) {
56  if (v.getType() == t) {
57  found = true;
58  res.push_back(v);
59  break;
60  }
61  }
62  if (found)
63  continue;
65  // Extract static / dynamic shape mix from the first operand.
66  Value firstOperand = operands.front();
67  auto rankedTensorType = t.cast<RankedTensorType>();
68  auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
69  auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b);
71  res.push_back(b.create<linalg::InitTensorOp>(
72  loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
73  }
74  return res;
75 }
77 namespace {
/// Rewrite pattern registered with MatchAnyOpTypeTag, so it is tried on every
/// operation. A qualifying op is replaced by a linalg::GenericOp whose region
/// re-creates the same op on the scalar element types of its tensor results.
/// Non-qualifying ops are rejected with
/// "requires elementwise op on ranked tensors".
/// NOTE(review): the guarding condition is not visible in this view;
/// presumably it calls isElementwiseMappableOpOnRankedTensors(op) — confirm.
78 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
79  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
80  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
81  LogicalResult matchAndRewrite(Operation *op,
82  PatternRewriter &rewriter) const final {
84  return rewriter.notifyMatchFailure(
85  op, "requires elementwise op on ranked tensors");
// The rank of result #0 defines the iteration space. This relies on the
// ElementwiseMappable guarantee that all tensor operands and results share
// one shape. It also assumes the op has at least one result.
87  auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
// One identity indexing map per result and per operand: every tensor is
// accessed at exactly the current iteration indices.
88  SmallVector<AffineMap, 3> indexingMaps(
89  op->getNumResults() + op->getNumOperands(),
90  rewriter.getMultiDimIdentityMap(rank));
// One iterator type per loop dimension.
// NOTE(review): the second constructor argument is not visible in this view;
// presumably getParallelIteratorTypeName() — confirm.
91  SmallVector<StringRef, 6> iteratorTypes(rank,
// Output ("init") tensors: for each result, reuse an operand of the same type
// if one exists; otherwise create a fresh linalg::InitTensorOp.
93  auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
94  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
95  op, /*resultTensorTypes=*/op->getResultTypes(),
96  /*inputs=*/op->getOperands(),
97  /*outputs=*/outputs,
98  /*indexingMaps=*/indexingMaps,
99  /*iteratorTypes=*/iteratorTypes,
100  /*bodyBuilder=*/
// The body rebuilds `op` generically from its name and attributes on
// scalar block arguments, then yields the rebuilt op's results.
101  [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
102  OperationState state(loc, op->getName());
103  state.addAttributes(op->getAttrs());
// regionArgs presumably also holds block arguments for the outputs after the
// inputs; only the first getNumOperands() are the scalar inputs.
104  // Only take the input operands in the cloned elementwise op.
105  state.addOperands(regionArgs.take_front(op->getNumOperands()));
// The scalar op produces the element type of each original tensor result.
106  auto resultTypes = llvm::to_vector<6>(
107  llvm::map_range(op->getResultTypes(), [](Type type) {
108  return type.cast<TensorType>().getElementType();
109  }));
110  state.addTypes(resultTypes);
111  auto *scalarOp = builder.createOperation(state);
112  builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
113  });
114  return success();
115  }
116 };
117 } // namespace
120  RewritePatternSet &patterns) {
121  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
122  patterns.getContext());
123 }
125 namespace {
/// Pass that converts elementwise ops on ranked tensors to linalg by applying
/// rewrite patterns through applyPartialConversion on the operation the pass
/// runs on. If the conversion fails, the pass signals failure.
126 class ConvertElementwiseToLinalgPass
127  : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
129  void runOnOperation() final {
130  auto *func = getOperation();
131  auto *context = &getContext();
132  ConversionTarget target(*context);
133  RewritePatternSet patterns(context);
// NOTE(review): the statements that fill `patterns` and set up legality on
// `target` (the callback closed by the `});` below) are not visible in this
// view. Presumably they call populateElementwiseToLinalgConversionPatterns
// and target.markUnknownOpDynamicallyLegal — confirm against the real file.
138  });
// Apply the patterns with partial conversion; report pass failure if the
// conversion fails.
140  if (failed(applyPartialConversion(func, target, std::move(patterns))))
141  signalPassFailure();
142  }
143 };
144 } // namespace
147  return std::make_unique<ConvertElementwiseToLinalgPass>();
148 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:308
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:247
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1122
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:215
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
operand_type_range getOperandTypes()
Definition: Operation.h:266
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a l...
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
This represents an operation in an abstracted form, suitable for use with the builder APIs...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
SmallVector< Value, 4 > getDynOperands(Location loc, Value val, OpBuilder &b)
Given an operation, retrieves the value of each dynamic dimension through constructing the necessary ...
Definition: Utils.cpp:169
U cast() const
Definition: Value.h:107
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
This class describes a specific conversion target.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:802
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
result_type_range getResultTypes()
Definition: Operation.h:297
std::unique_ptr< Pass > createConvertElementwiseToLinalgPass()
MLIRContext * getContext() const
Definition: PatternMatch.h:906
U cast() const
Definition: Types.h:250