MLIR  22.0.0git
ElementwiseToLinalg.cpp
Go to the documentation of this file.
1 //===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
18 #include "mlir/Dialect/Linalg/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 static inline bool isScalarLike(Type t) {
24  return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
25 }
26 
// NOTE(review): source lines 27-28 are absent from this listing, so the
// function header and the condition guarding the `return false` below are not
// visible here. The symbol index at the end of this page lists the signature
// as `static bool isElementwiseMappableOpOnRankedTensors(Operation *op)`;
// confirm the guard against the original file.
29  return false;
30 
31  auto types = op->getOperandTypes();
32 
33  // We want at least one ranked tensor.
34  bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
35 
36  // No invalid operands (i.e., every operand is a ranked tensor or
37  // scalar-like).
38  bool noneInvalid = llvm::none_of(types, [](Type t) {
39  return !(isa<RankedTensorType>(t) || isScalarLike(t));
40  });
41 
// Accept only when both conditions hold: at least one ranked-tensor operand,
// and no operand that is neither a ranked tensor nor scalar-like.
42  return anyRankedTensor && noneInvalid;
43 }
44 
45 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
46 /// the result types and return a list of values such that, for each result type
47 /// `t` and value `v` at the same index `idx`:
48 /// 1. `v.getType() == t`
49 /// 2. If an operand of `op` has type `t`, let `operand_first` be the first
50 /// such operand. Then `v == operand_first`.
51 /// 3. Otherwise, `v` is a newly created `tensor::EmptyOp` with:
52 /// a. Static and dynamic dims extracted from the first operand of `op`.
53 /// b. Elemental type equal to the elemental type of `t`.
54 ///
55 /// This is sufficient because ElementwiseMappable guarantees that "The static
56 /// types of all vector (resp. tensor) operands and results must have the same
57 /// shape".
// NOTE(review): source lines 58-60 are absent from this listing, so the
// function header is not visible here. The symbol index at the end of this
// page lists it as `static SmallVector< Value, 4 >
// getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)`.
61  Location loc = op->getLoc();
62  ValueRange operands = op->getOperands();
63  TypeRange rankedTensorTypes = op->getResultTypes();
// NOTE(review): source line 64 is absent; `res` is used below without a
// visible declaration. Presumably it is the result container matching the
// return type in the symbol index -- confirm against the original file.
65  res.reserve(rankedTensorTypes.size());
66  for (Type t : rankedTensorTypes) {
67  // Try to find an operand with type matching the result tensor.
68  bool found = false;
69  for (Value v : operands) {
70  if (v.getType() == t) {
71  found = true;
72  res.push_back(v);
73  break;
74  }
75  }
// An operand of the exact result type was appended above; move on to the
// next result type.
76  if (found)
77  continue;
78 
79  // Extract static / dynamic shape mix from the first operand.
// The new empty tensor takes its sizes from `operands.front()` and its
// element type from the result type `t` (which is cast to RankedTensorType).
80  res.push_back(tensor::EmptyOp::create(
81  b, loc, tensor::getMixedSizes(b, loc, operands.front()),
82  cast<RankedTensorType>(t).getElementType()));
83  }
84  return res;
85 }
86 
87 namespace {
/// Rewrite pattern registered with `MatchAnyOpTypeTag` (benefit 1) that
/// replaces a matched op with a `linalg::GenericOp` whose body re-creates the
/// same operation (same name and attributes) on the region arguments and
/// yields its results.
88 struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
89  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
90  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
91  LogicalResult matchAndRewrite(Operation *op,
92  PatternRewriter &rewriter) const final {
// NOTE(review): source line 93 is absent from this listing, so the condition
// guarding this early return is not visible. Judging by the failure message
// it is presumably a check that `op` is an elementwise op on ranked tensors
// -- confirm against the original file.
94  return rewriter.notifyMatchFailure(
95  op, "requires elementwise op on ranked tensors");
96 
// The rank of result 0 sizes both affine maps and the iterator list below.
// The cast asserts that result 0 is a ranked tensor.
97  auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
98  auto rank = resTy.getRank();
99 
100  // Maps: identity for tensors (rank > 0), scalar map for scalars.
// `scalarMap` has `rank` dims, no symbols and no results.
101  AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
102  /*results=*/{}, rewriter.getContext());
103  AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
104 
105  // Match phase.
// Classify every operand as scalar-like or ranked tensor; any other operand
// type aborts the match.
106  SmallVector<bool> isScalarOperand;
107  isScalarOperand.reserve(op->getNumOperands());
108  for (Type ty : op->getOperandTypes()) {
109  if (isScalarLike(ty))
110  isScalarOperand.push_back(true);
// `rt` is bound only to test the type; its value is not used.
111  else if (auto rt = dyn_cast<RankedTensorType>(ty))
112  isScalarOperand.push_back(false);
113  else
114  return rewriter.notifyMatchFailure(
115  op,
116  "unsupported operand type (expected scalar-like or ranked tensor)");
117  }
118 
119  // Create indexing maps.
// One map per operand followed by one map per result.
120  SmallVector<AffineMap> indexingMaps;
121  indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
122 
123  for (bool isScalar : isScalarOperand)
124  indexingMaps.push_back(isScalar ? scalarMap : idMap);
125 
// Every result uses the identity map.
126  indexingMaps.append(op->getNumResults(), idMap);
127 
// All `rank` loop dimensions are parallel.
128  SmallVector<utils::IteratorType> iteratorTypes(
129  rank, utils::IteratorType::parallel);
130  SmallVector<Value> outputs =
// NOTE(review): source line 131 is absent from this listing; the initializer
// of `outputs` is not visible here. Confirm against the original file.
132  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
133  op, /*resultTensorTypes=*/op->getResultTypes(),
134  /*inputs=*/op->getOperands(),
135  /*outputs=*/outputs,
136  /*indexingMaps=*/indexingMaps,
137  /*iteratorTypes=*/iteratorTypes,
138  /*bodyBuilder=*/
139  [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
// The scalar op's result types are the element types of the original
// (tensor) result types.
140  SmallVector<Type> resultEltTys = llvm::to_vector<6>(
141  llvm::map_range(op->getResultTypes(), [](Type type) {
142  return cast<TensorType>(type).getElementType();
143  }));
// Re-create the op by name on the first `getNumOperands()` region arguments,
// carrying over the original attributes, then yield its results.
144  Operation *scalarOp =
145  builder.create(loc, op->getName().getIdentifier(),
146  regionArgs.take_front(op->getNumOperands()),
147  resultEltTys, op->getAttrs());
148  linalg::YieldOp::create(builder, loc, scalarOp->getResults());
149  });
150  return success();
151  }
152 };
153 } // namespace
154 
// NOTE(review): source lines 155-156 are absent from this listing, so the
// function header is not visible here. The symbol index at the end of this
// page lists it as `void populateElementwiseToLinalgConversionPatterns(
// RewritePatternSet &patterns)`.
// Adds ConvertAnyElementwiseMappableOpOnRankedTensors to `patterns`,
// constructed with the pattern set's own context.
157  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
158  patterns.getContext());
159 }
160 
161 namespace {
/// Pass deriving from the generated `ConvertElementwiseToLinalgPassBase`. It
/// runs a partial conversion over the operation the pass is applied to and
/// signals pass failure when that conversion fails.
162 class ConvertElementwiseToLinalgPass
163  : public impl::ConvertElementwiseToLinalgPassBase<
164  ConvertElementwiseToLinalgPass> {
// Inherit the base class constructors.
165  using impl::ConvertElementwiseToLinalgPassBase<
166  ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
167 
168  void runOnOperation() final {
169  auto *func = getOperation();
170  auto *context = &getContext();
171  ConversionTarget target(*context);
172  RewritePatternSet patterns(context);
173 
// NOTE(review): source line 174 is absent from this listing. Nothing visible
// here adds patterns to `patterns` before it is passed to the conversion;
// presumably the missing line populates it -- confirm against the original
// file.
175  target.markUnknownOpDynamicallyLegal([](Operation *op) {
// NOTE(review): source line 176 (the body of this legality callback) is
// absent from this listing; the legality rule is not visible here.
177  });
178 
179  if (failed(applyPartialConversion(func, target, std::move(patterns))))
180  signalPassFailure();
181  }
182 };
183 } // namespace
static bool isElementwiseMappableOpOnRankedTensors(Operation *op)
static bool isScalarLike(Type t)
static SmallVector< Value, 4 > getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op)
Given op assumed isElementwiseMappableOpOnRankedTensors, iterate over the result types and return a list of values such that, for each result type t and value v at the same index, v.getType() == t.
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.