MLIR  18.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALG
19 #include "mlir/Dialect/Linalg/Passes.h.inc"
20 } // namespace mlir
21 
22 using namespace mlir;
23 
26  return false;
27 
28  // TODO: The conversion pattern can be made to work for `any_of` here, but
29  // it's more complex as it requires tracking which operands are scalars.
30  return llvm::all_of(op->getOperandTypes(),
31  [](Type type) { return isa<RankedTensorType>(type); });
32 }
33 
34 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
35 /// the result types and return a list of values such that, for each result type
36 /// `t` and value `v` at the same index `idx`:
37 /// 1. `v.getType() == t`
38 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
39 /// such operand. Then `v == operand_first`.
40 /// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
41 /// a. Static and dynamic dims extracted from the first operand of `op`.
42 /// b. Elemental type equal to the elemental type of `t`.
43 ///
44 /// This is sufficient because ElementwiseMappable guarantees that "The static
45 /// types of all vector (resp. tensor) operands and results must have the same
46 /// shape".
50  Location loc = op->getLoc();
51  ValueRange operands = op->getOperands();
52  TypeRange rankedTensorTypes = op->getResultTypes();
54  res.reserve(rankedTensorTypes.size());
55  for (Type t : rankedTensorTypes) {
56  // Try to find an operand with type matching the result tensor.
57  bool found = false;
58  for (Value v : operands) {
59  if (v.getType() == t) {
60  found = true;
61  res.push_back(v);
62  break;
63  }
64  }
65  if (found)
66  continue;
67 
68  // Extract static / dynamic shape mix from the first operand.
69  res.push_back(b.create<tensor::EmptyOp>(
70  loc, tensor::getMixedSizes(b, loc, operands.front()),
71  cast<RankedTensorType>(t).getElementType()));
72  }
73  return res;
74 }
75 
76 namespace {
77 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
78  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
79  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
80  LogicalResult matchAndRewrite(Operation *op,
81  PatternRewriter &rewriter) const final {
83  return rewriter.notifyMatchFailure(
84  op, "requires elementwise op on ranked tensors");
85 
86  auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
87  SmallVector<AffineMap, 3> indexingMaps(
88  op->getNumResults() + op->getNumOperands(),
89  rewriter.getMultiDimIdentityMap(rank));
91  rank, utils::IteratorType::parallel);
92  auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
93  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
94  op, /*resultTensorTypes=*/op->getResultTypes(),
95  /*inputs=*/op->getOperands(),
96  /*outputs=*/outputs,
97  /*indexingMaps=*/indexingMaps,
98  /*iteratorTypes=*/iteratorTypes,
99  /*bodyBuilder=*/
100  [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
101  auto resultTypes = llvm::to_vector<6>(
102  llvm::map_range(op->getResultTypes(), [](Type type) {
103  return cast<TensorType>(type).getElementType();
104  }));
105  auto *scalarOp =
106  builder.create(loc, op->getName().getIdentifier(),
107  regionArgs.take_front(op->getNumOperands()),
108  resultTypes, op->getAttrs());
109  builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
110  });
111  return success();
112  }
113 };
114 } // namespace
115 
117  RewritePatternSet &patterns) {
118  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
119  patterns.getContext());
120 }
121 
122 namespace {
123 class ConvertElementwiseToLinalgPass
124  : public impl::ConvertElementwiseToLinalgBase<
125  ConvertElementwiseToLinalgPass> {
126 
127  void runOnOperation() final {
128  auto *func = getOperation();
129  auto *context = &getContext();
130  ConversionTarget target(*context);
131  RewritePatternSet patterns(context);
132 
134  target.markUnknownOpDynamicallyLegal([](Operation *op) {
136  });
137 
138  if (failed(applyPartialConversion(func, target, std::move(patterns))))
139  signalPassFailure();
140  }
141 };
142 } // namespace
143 
145  return std::make_unique<ConvertElementwiseToLinalgPass>();
146 }
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a l...
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1344
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createConvertElementwiseToLinalgPass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26