//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Returns the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the result of the producer in
/// the consumer.
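///
/// As a worked example (illustrative only, not taken from the upstream
/// tests): assume the producer result map is (d0, d1) -> (d1, d0), a
/// producer argument map is (d0, d1) -> (d0, d1), and the consumer accesses
/// the producer result with map (e0, e1) -> (e0, e1). Inverting the result
/// map gives (t0, t1) -> (t1, t0); composing the argument map with that
/// inverse gives (t0, t1) -> (t1, t0); composing the result with the
/// consumer map yields the fused argument map (e0, e1) -> (e1, e0).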
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop / fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, i.e. whether the remaining
// operands of the fused producer & consumer can still compute the bounds
// of the op after the fusion.
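//
// For example (an illustrative sketch, not an upstream test case): if after
// dropping the operand the only remaining indexing maps are
// (d0, d1) -> (d0) and (d0, d1) -> (d0), then d1 is no longer bound by any
// map, the concatenated map is not invertible, and the operand cannot be
// dropped.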
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction
/// based on an optimized fusion.
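/// * In particular (summarizing the checks below): a producer result is
/// preserved if its value is read in the producer payload, if dropping its
/// init operand would leave the remaining indexing maps non-invertible, or
/// if the result has a user other than the consumer.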
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
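///
/// As an illustrative sketch (assumed IR, not an upstream test): with
/// #id = affine_map<(d0) -> (d0)>, the use of %0 in %1 below is fusable,
/// since both ops are `linalg.generic` on tensors, the producer is
/// all-parallel, and its result map #id is a permutation:
///
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///          iterator_types = ["parallel"]}
///          ins(%a : tensor<?xf32>) outs(%init0 : tensor<?xf32>) {...}
///   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
///          iterator_types = ["parallel"]}
///          ins(%0, %b : tensor<?xf32>, tensor<?xf32>)
///          outs(%init1 : tensor<?xf32>) {...}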
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that
  // defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yield operands).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      linalg::getPreservedProducerResults(producer, consumer, fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producer of one of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion with the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the results of the original ops.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///   %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the modified op is now 6-D, reshapes are
///  introduced on the operands to make their ranks consistent with the new
///  indexing maps
///
///   %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///   %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///        : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from the original dimensions of the op to the
  // dimensions of the expanded op, given the `indexingMap` of the fused
  // operand/result of the generic op, the `reassociationMaps` of the reshape
  // op, and the shape of the expanded op.
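  //
  // For example (illustrative): if the fused operand's indexing map is
  // (d0, d1) -> (d1, d0) and the reshape reassociation is [[0, 1], [2]] with
  // expanded shape (2, 3, 4), then d1 expands to extents {2, 3}, d0 expands
  // to extent {4}, and the original-to-expanded dim reassociation is
  // [[0], [1, 2]].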
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
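///
/// For example (illustrative): with the original map (d0, d1) -> (d1, d0)
/// and the expansion d0 -> (e0, e1), d1 -> (e2), the expanded map is
/// (e0, e1, e2) -> (e2, e0, e1).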
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
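///
/// For example (illustrative): for an operand of type tensor<?x4xf32> with
/// indexing map (d0, d1) -> (d0, d1) and the expansion d0 -> (e0, e1) with
/// extents {%s, 8}, the expanded shape is (%s, 8, 4) and the expanded type
/// is tensor<?x8x4xf32>.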
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
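///
/// For example (illustrative): with indexing map (d0, d1) -> (d1, d0) and
/// the expansion d0 -> (e0, e1), d1 -> (e2), the reassociation is
/// [[0], [1, 2]]: the first result d1 maps to the single expanded dim 0, and
/// the second result d0 maps to the expanded dims 1 and 2.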
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
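///
/// For example (illustrative): if an original dimension is expanded into
/// (e0, e1) and the static extent of e1 is 4, the original index is
/// recovered as `index(e1) + index(e0) * 4`, with `linalg.index` ops
/// supplying index(e0) and index(e1).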
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    OpFoldResult newIndex =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation = [4, 5, 0, 1, 2, 3]

static TransposeOp createExpandedTransposeOp(PatternRewriter &rewriter,
                                             TransposeOp transposeOp,
                                             Value expandedInput, Value output,
                                             ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Start with all-parallel iterator types for the expanded op, then copy
  // each original iterator type over to the dimensions it expands into.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in the output shape before the
    // `linalgOp` to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.pad op with a producer tensor.collapse_shape op
/// by expanding the pad to the source shape of the collapse.
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
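///
/// For example (illustrative): for the projected permutation
/// (d0, d1, d2) -> (d2, d0) and range reassociation [0, 1], the returned
/// domain reassociation is [2, 0].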
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// If the sequence does not occur in the map at all, this returns true as
/// well.
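///
/// For example (illustrative): for the map (d0, d1, d2) -> (d2, d0, d1),
/// the sequence [0, 1] is preserved since d0 and d1 appear adjacently and
/// in order, whereas [1, 2] is not, since d2 appears before d1.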
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

static bool
areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                         ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are either all parallel or all
    // reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move the window in `reductionDims` to the start of the folded
        // iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, the dims are trivially not contiguous.
        // This condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert_range(foldedIterationSpaceDims);
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
1394class CollapsingInfo {
1395public:
1396 LogicalResult initialize(unsigned origNumLoops,
1397 ArrayRef<ReassociationIndices> foldedIterationDims) {
1398 llvm::SmallDenseSet<int64_t, 4> processedDims;
1399 // Find all the dims that are folded.
1400 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1401 if (foldedIterationDim.empty())
1402 continue;
1403 // If the folded dims contain dims already folded, that's illegal
1404 // specification. Repetition within a list is also illegal.
1405 for (auto dim : foldedIterationDim) {
1406 if (dim >= origNumLoops)
1407 return failure();
1408 if (processedDims.count(dim))
1409 return failure();
1410 processedDims.insert(dim);
1411 }
1412 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1413 foldedIterationDim.end());
1414 }
1415 if (processedDims.size() > origNumLoops)
1416 return failure();
1417
1418 // Add all the preserved dims of the original op as single
1419 // elements to `collapsedOpToOrigOpIterationDim`.
1420 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1421 if (processedDims.count(dim))
1422 continue;
1423 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1424 }
1425
1426 llvm::sort(collapsedOpToOrigOpIterationDim,
1427 [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1428 return lhs[0] < rhs[0];
1429 });
1430 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1431 for (const auto &foldedDims :
1432 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1433 for (const auto &dim : enumerate(foldedDims.value()))
1434 origOpToCollapsedOpIterationDim[dim.value()] =
1435 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1436 }
1437 return success();
1438 }
1439
1440 /// Return mapping from collapsed loop domain to original loop domain.
1441 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1442 return collapsedOpToOrigOpIterationDim;
1443 }
1444
1445 /// Return mapping from original loop domain to collapsed loop domain. The
1446 /// mapping is a pair. First value is the dimension in the collapsed loop that
1447 /// the original loop is mapped to. Second is the relative position in folded
1448 /// list of this domain. For example, if the original loop domain is 5D and
1449 /// is collapsed to 2D by folding dims [0, 1, 2] and [3, 4], i.e.
1450 ///
1451 /// ```
1452 /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1453 /// ```
1454 ///
1455 /// then
1456 ///
1457 /// ```
1458 /// origOpToCollapsedOpMapping[0] = {0, 0};
1459 /// origOpToCollapsedOpMapping[1] = {0, 1};
1460 /// origOpToCollapsedOpMapping[2] = {0, 2};
1461 /// origOpToCollapsedOpMapping[3] = {1, 0};
1462 /// origOpToCollapsedOpMapping[4] = {1, 1};
1463 /// ```
1464 ///
1465 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1466 return origOpToCollapsedOpIterationDim;
1467 }
1468
1469 /// Return the collapsed op iteration domain rank.
1470 unsigned getCollapsedOpIterationRank() const {
1471 return collapsedOpToOrigOpIterationDim.size();
1472 }
1473
1474private:
1475 /// Map from the iteration domain index in collapsed op to the iteration
1476 /// domain indices in the original op.
1477 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1478
1479 /// Map from iteration domain index in the original op to the iteration domain
1480 /// index in the collapsed op.
1481 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1482};
1483} // namespace
1484
1485/// Get the iterator types for the collapsed operation given the original
1486/// iterator types and collapsed dimensions.
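/// For example (illustrative), folding iterator types
/// [parallel, parallel, reduction] with reassociation [[0, 1], [2]] yields
/// [parallel, reduction].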
1487static SmallVector<utils::IteratorType>
1488getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1489 const CollapsingInfo &collapsingInfo) {
1490 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1491 for (ReassociationIndicesRef foldedIterDims :
1492 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1493 assert(!foldedIterDims.empty() &&
1494 "reassociation indices expected to have non-empty sets");
1495 // Just pick the iterator type of the first folded dim. Precondition checks
1496 // are expected to have verified that the iterator types of all folded
1497 // dimensions are the same.
1498 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1499 }
1500 return collapsedIteratorTypes;
1501}
1502
1503/// Compute the indexing map in the collapsed op that corresponds to the given
1504/// `indexingMap` of the original operation.
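/// For example (illustrative), with collapsing [[0, 1], [2]] the map
/// (d0, d1, d2) -> (d0, d1, d2) of the original op becomes
/// (d0, d1) -> (d0, d1) in the collapsed op.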
1505static AffineMap
1506getCollapsedOpIndexingMap(AffineMap indexingMap,
1507 const CollapsingInfo &collapsingInfo) {
1508 MLIRContext *context = indexingMap.getContext();
1509 assert(indexingMap.isProjectedPermutation() &&
1510 "expected indexing map to be projected permutation");
1511 SmallVector<AffineExpr> resultExprs;
1512 auto origOpToCollapsedOpMapping =
1513 collapsingInfo.getOrigOpToCollapsedOpMapping();
1514 for (auto expr : indexingMap.getResults()) {
1515 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1516 // If the dim is not the first of its collapsed group, do nothing.
1517 if (origOpToCollapsedOpMapping[dim].second != 0)
1518 continue;
1519 // The next n-dims are guaranteed to be collapsed. So just use the
1520 // iteration dimension of the collapsed op.
1521 resultExprs.push_back(
1522 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1523 }
1524 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1525 resultExprs, context);
1526}
1527
1528/// Return the `reassociation` indices to use to collapse the operand when the
1529/// iteration space of a generic op is collapsed.
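/// For example (illustrative), with collapsing [[0, 1], [2]] an operand with
/// indexing map (d0, d1, d2) -> (d2, d0, d1) gets the reassociation
/// [[0], [1, 2]]: result 0 covers the singleton group [2], and results 1-2
/// cover the folded group [0, 1].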
1530static SmallVector<ReassociationIndices>
1531getOperandReassociation(AffineMap indexingMap,
1532 const CollapsingInfo &collapsingInfo) {
1533 unsigned counter = 0;
1534 SmallVector<ReassociationIndices> operandReassociation;
1535 auto origOpToCollapsedOpMapping =
1536 collapsingInfo.getOrigOpToCollapsedOpMapping();
1537 auto collapsedOpToOrigOpMapping =
1538 collapsingInfo.getCollapsedOpToOrigOpMapping();
1539 while (counter < indexingMap.getNumResults()) {
1540 unsigned dim =
1541 cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1542 // This is the start of a group of collapsed dimensions of the iteration
1543 // space that is guaranteed to be preserved in the indexing map. The number
1544 // of folded dims is obtained from the collapsed op to original op mapping.
1545 unsigned numFoldedDims =
1546 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1547 .size();
1548 if (origOpToCollapsedOpMapping[dim].second == 0) {
1549 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1550 operandReassociation.emplace_back(range.begin(), range.end());
1551 }
1552 counter += numFoldedDims;
1553 }
1554 return operandReassociation;
1555}
1556
1557/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1558static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1559 OpOperand *opOperand,
1560 const CollapsingInfo &collapsingInfo,
1561 OpBuilder &builder) {
1562 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1563 SmallVector<ReassociationIndices> operandReassociation =
1564 getOperandReassociation(indexingMap, collapsingInfo);
1565
1566 // If the number of entries in the reassociation for the operand is the same
1567 // as the number of results of the indexing map, there is nothing to do for
1568 // this operand.
1569 Value operand = opOperand->get();
1570 if (operandReassociation.size() == indexingMap.getNumResults())
1571 return operand;
1572
1573 // Insert a reshape to collapse the dimensions.
1574 if (isa<MemRefType>(operand.getType())) {
1575 return memref::CollapseShapeOp::create(builder, loc, operand,
1576 operandReassociation)
1577 .getResult();
1578 }
1579 return tensor::CollapseShapeOp::create(builder, loc, operand,
1580 operandReassociation)
1581 .getResult();
1582}
1583
1584 /// Rewrite the `linalg.index` operations in the original generic op to
1585 /// compute their values in terms of the collapsed iteration space.
1586static void generateCollapsedIndexingRegion(
1587 Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1588 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1589 OpBuilder::InsertionGuard g(rewriter);
1590 rewriter.setInsertionPointToStart(block);
1591
1592 // Collect all the original index ops.
1593 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1594
1595 // For each folded dimension list, resolve the original induction variable
1596 // values in terms of the folded dimension induction variable:
1597 // i_{folded} = (i0 * d1 + i1) * d2 + i2
1598 // can be inverted to
1599 // i2 = i_{folded} % d2
1600 // i1 = (i_{folded} / d2) % d1
1601 // i0 = i_{folded} / (d1 * d2)
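 // As a concrete check (illustrative, not from the source): with d1 = 4,
 // d2 = 5 and i_{folded} = 37, this gives i2 = 37 % 5 = 2,
 // i1 = (37 / 5) % 4 = 3, i0 = 37 / 20 = 1, and indeed
 // (1 * 4 + 3) * 5 + 2 = 37.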
1602 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1603 for (auto foldedDims :
1604 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1605 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1606 Value newIndexVal =
1607 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1608 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1609 Value loopDim =
1610 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1611 indexReplacementVals[dim] =
1612 rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1613 newIndexVal =
1614 rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1615 }
1616 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1617 }
1618
1619 for (auto indexOp : indexOps) {
1620 auto dim = indexOp.getDim();
1621 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1622 }
1623}
1624
1625static void collapseOperandsAndResults(LinalgOp op,
1626 const CollapsingInfo &collapsingInfo,
1627 RewriterBase &rewriter,
1628 SmallVectorImpl<Value> &inputOperands,
1629 SmallVectorImpl<Value> &outputOperands,
1630 SmallVectorImpl<Type> &resultTypes) {
1631 Location loc = op->getLoc();
1632 inputOperands =
1633 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1634 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1635 rewriter);
1636 });
1637
1638 // Get the output operands and result types.
1639 resultTypes.reserve(op.getNumDpsInits());
1640 outputOperands.reserve(op.getNumDpsInits());
1641 for (OpOperand &output : op.getDpsInitsMutable()) {
1642 Value newOutput =
1643 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1644 outputOperands.push_back(newOutput);
1645 // If the op has "buffer semantics", then the init operands are ranked
1646 // memrefs and the op has no results.
1647 if (!op.hasPureBufferSemantics())
1648 resultTypes.push_back(newOutput.getType());
1649 }
1650}
1651
1652 /// Clone a `LinalgOp` to a collapsed version with the same name.
1653template <typename OpTy>
1654static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1655 const CollapsingInfo &collapsingInfo) {
1656 return nullptr;
1657}
1658
1659 /// Collapse any `LinalgOp` that does not require special handling of its
1660 /// indexing_maps, iterator_types, etc.
1661template <>
1662LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1663 const CollapsingInfo &collapsingInfo) {
1664 SmallVector<Value> inputOperands, outputOperands;
1665 SmallVector<Type> resultTypes;
1666 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1667 outputOperands, resultTypes);
1668
1669 return clone(
1670 rewriter, origOp, resultTypes,
1671 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1672}
1673
1674/// Collapse a `GenericOp`
1675template <>
1676GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1677 GenericOp origOp,
1678 const CollapsingInfo &collapsingInfo) {
1679 SmallVector<Value> inputOperands, outputOperands;
1680 SmallVector<Type> resultTypes;
1681 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1682 outputOperands, resultTypes);
1683 SmallVector<AffineMap> indexingMaps(
1684 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1685 return getCollapsedOpIndexingMap(map, collapsingInfo);
1686 }));
1687
1688 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1689 origOp.getIteratorTypesArray(), collapsingInfo));
1690
1691 GenericOp collapsedOp = linalg::GenericOp::create(
1692 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1693 indexingMaps, iteratorTypes,
1694 [](OpBuilder &builder, Location loc, ValueRange args) {});
1695 Block *origOpBlock = &origOp->getRegion(0).front();
1696 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1697 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1698 collapsedOpBlock->getArguments());
1699 return collapsedOp;
1700}
1701
1702static LinalgOp createCollapsedOp(LinalgOp op,
1703 const CollapsingInfo &collapsingInfo,
1704 RewriterBase &rewriter) {
1705 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1706 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1707 } else {
1708 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1709 }
1710}
1711
1712/// Implementation of fusion with reshape operation by collapsing dimensions.
1713FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1714 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1715 RewriterBase &rewriter) {
1716 // Bail on trivial no-op cases.
1717 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1718 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1719 return foldedDims.size() <= 1;
1720 }))
1721 return failure();
1722
1723 CollapsingInfo collapsingInfo;
1724 if (failed(
1725 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1726 return rewriter.notifyMatchFailure(
1727 op, "illegal to collapse specified dimensions");
1728 }
1729
1730 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1731 if (hasPureBufferSemantics &&
1732 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
1733 MemRefType memRefToCollapse =
1734 dyn_cast<MemRefType>(opOperand.get().getType());
1735 if (!memRefToCollapse)
1736 return true;
1737
1738 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1739 SmallVector<ReassociationIndices> operandReassociation =
1740 getOperandReassociation(indexingMap, collapsingInfo);
1741 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1742 memRefToCollapse, operandReassociation);
1743 }))
1744 return rewriter.notifyMatchFailure(op,
1745 "memref is not guaranteed collapsible");
1746
1747 // Bail on non-canonical ranges.
1748 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1749 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1750 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1751 return cast<IntegerAttr>(attr).getInt() == value;
1752 llvm::APInt actual;
1753 return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1754 actual.getSExtValue() == value;
1755 };
1756 if (!llvm::all_of(loopRanges, [&](Range range) {
1757 return opFoldIsConstantValue(range.offset, 0) &&
1758 opFoldIsConstantValue(range.stride, 1);
1759 })) {
1760 return rewriter.notifyMatchFailure(
1761 op, "expected all loop ranges to have zero start and unit stride");
1762 }
1763
1764 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1765
1766 Location loc = op->getLoc();
1767 SmallVector<OpFoldResult> loopBound =
1768 llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1769
1770 if (collapsedOp.hasIndexSemantics()) {
1771 // Collect the loop range of the generic op.
1772 OpBuilder::InsertionGuard g(rewriter);
1773 rewriter.setInsertionPoint(collapsedOp);
1774 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1775 collapsingInfo, loopBound, rewriter);
1776 }
1777
1778 // Insert expanding reshape for the result to get back the original result
1779 // type.
1780 SmallVector<Value> results;
1781 for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1782 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1783 auto originalResultType =
1784 cast<ShapedType>(originalResult.value().getType());
1785 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1786 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1787 AffineMap indexingMap =
1788 op.getIndexingMapMatchingResult(originalResult.value());
1789 SmallVector<ReassociationIndices> reassociation =
1790 getOperandReassociation(indexingMap, collapsingInfo);
1791 assert(
1792 indexingMap.isProjectedPermutation() &&
1793 "Expected indexing map to be a projected permutation for collapsing");
1794 SmallVector<OpFoldResult> resultShape =
1795 applyPermutationMap(indexingMap, ArrayRef(loopBound));
1796 Value result;
1797 if (isa<MemRefType>(collapsedOpResult.getType())) {
1798 MemRefType expandShapeResultType = MemRefType::get(
1799 originalResultType.getShape(), originalResultType.getElementType());
1800 result = memref::ExpandShapeOp::create(
1801 rewriter, loc, expandShapeResultType, collapsedOpResult,
1802 reassociation, resultShape);
1803 } else {
1804 result = tensor::ExpandShapeOp::create(
1805 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1806 resultShape);
1807 }
1808 results.push_back(result);
1809 } else {
1810 results.push_back(collapsedOpResult);
1811 }
1812 }
1813 return CollapseResult{results, collapsedOp};
1814}
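// Illustrative sketch of the overall rewrite (types and maps are made up for
// the example): collapsing dims [[0, 1]] of an elementwise op on
// tensor<?x4xf32> conceptually produces
//   %c = tensor.collapse_shape %in [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
//   %g = linalg.generic { ... affine_map<(d0) -> (d0)> ... } ins(%c ...) ...
//   %e = tensor.expand_shape %g [[0, 1]] ... : tensor<?xf32> into tensor<?x4xf32>
// where %e replaces the original result.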
1815
1816namespace {
1817
1818 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1819 /// collapsing dimensions of the iteration space.
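/// Illustrative sketch (types made up for the example):
///
///   %e = tensor.expand_shape %a [[0, 1], [2]]
///       : tensor<?x4xf32> into tensor<?x8x4xf32>
///   %r = linalg.generic ... ins(%e : tensor<?x8x4xf32>) ...
///
/// is rewritten, when the preconditions hold, into a generic op that consumes
/// %a directly, with its iteration dims [0, 1] collapsed.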
1820class FoldWithProducerReshapeOpByCollapsing
1821 : public OpRewritePattern<GenericOp> {
1822public:
1823 // TODO: support fusion with all linalg ops, not just generic.
1824 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1825 ControlFusionFn foldReshapes,
1826 PatternBenefit benefit = 1)
1827 : OpRewritePattern<GenericOp>(context, benefit),
1828 controlFoldingReshapes(std::move(foldReshapes)) {}
1829
1830 LogicalResult matchAndRewrite(GenericOp genericOp,
1831 PatternRewriter &rewriter) const override {
1832 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1833 tensor::ExpandShapeOp reshapeOp =
1834 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1835 if (!reshapeOp)
1836 continue;
1837
1838 SmallVector<ReassociationIndices> collapsableIterationDims =
1839 getCollapsableIterationSpaceDims(genericOp, &opOperand,
1840 reshapeOp.getReassociationIndices());
1841 if (collapsableIterationDims.empty() ||
1842 !controlFoldingReshapes(&opOperand)) {
1843 continue;
1844 }
1845
1846 std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1847 genericOp, collapsableIterationDims, rewriter);
1848 if (!collapseResult) {
1849 return rewriter.notifyMatchFailure(
1850 genericOp, "failed to do the fusion by collapsing transformation");
1851 }
1852
1853 rewriter.replaceOp(genericOp, collapseResult->results);
1854 return success();
1855 }
1856 return failure();
1857 }
1858
1859private:
1860 ControlFusionFn controlFoldingReshapes;
1861};
1862
1863 /// Pattern to fold a tensor.collapse_shape op with its producer generic op
1864 /// by collapsing the iteration-space dimensions of the producer op.
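/// Illustrative sketch (types made up for the example):
///
///   %r = linalg.generic ... -> tensor<?x8x4xf32>
///   %c = tensor.collapse_shape %r [[0, 1], [2]]
///       : tensor<?x8x4xf32> into tensor<?x4xf32>
///
/// The producer is collapsed in place; the expand_shape inserted to restore
/// its original result type then folds away with the trailing
/// tensor.collapse_shape.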
1865struct FoldReshapeWithGenericOpByCollapsing
1866 : public OpRewritePattern<tensor::CollapseShapeOp> {
1867
1868 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1869 ControlFusionFn foldReshapes,
1870 PatternBenefit benefit = 1)
1871 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1872 controlFoldingReshapes(std::move(foldReshapes)) {}
1873
1874 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1875 PatternRewriter &rewriter) const override {
1876 // Fold only if all constraints of fusing with reshape by collapsing are
1877 // met.
1878 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1879 if (!producerResult) {
1880 return rewriter.notifyMatchFailure(reshapeOp,
1881 "source not produced by an operation");
1882 }
1883
1884 // TODO: support fusion with all linalg producers, not just generic.
1885 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1886 if (!producer) {
1887 return rewriter.notifyMatchFailure(reshapeOp,
1888 "producer not a generic op");
1889 }
1890
1891 SmallVector<ReassociationIndices> collapsableIterationDims =
1892 getCollapsableIterationSpaceDims(
1893 producer,
1894 producer.getDpsInitOperand(producerResult.getResultNumber()),
1895 reshapeOp.getReassociationIndices());
1896 if (collapsableIterationDims.empty()) {
1897 return rewriter.notifyMatchFailure(
1898 reshapeOp, "failed preconditions of fusion with producer generic op");
1899 }
1900
1901 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1902 return rewriter.notifyMatchFailure(reshapeOp,
1903 "fusion blocked by control function");
1904 }
1905
1906 // Set the insertion point after `producer` because there could be uses
1907 // of `producer` between it and the `tensor.collapse_shape` op.
1908 rewriter.setInsertionPointAfter(producer);
1909 std::optional<CollapseResult> collapseResult =
1910 collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
1911 if (!collapseResult) {
1912 return rewriter.notifyMatchFailure(
1913 producer, "failed to do the fusion by collapsing transformation");
1914 }
1915
1916 rewriter.replaceOp(producer, collapseResult->results);
1917 return success();
1918 }
1919
1920private:
1921 ControlFusionFn controlFoldingReshapes;
1922};
1923
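/// Pattern to fold a tensor.pad op with its tensor.expand_shape producer by
/// padding the source of the reshape instead. Illustrative sketch (types made
/// up for the example), assuming only dimensions from singleton reassociation
/// groups are padded:
///
///   %e = tensor.expand_shape %s [[0, 1], [2]] : tensor<?x8xf32> into tensor<?x4x8xf32>
///   %p = tensor.pad %e low[0, 0, 2] high[0, 0, 3] ...
///
/// becomes
///
///   %p = tensor.pad %s low[0, 2] high[0, 3] ...
///   %e = tensor.expand_shape %p [[0, 1], [2]] ...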
1924class FoldPadWithProducerReshapeOpByCollapsing
1925 : public OpRewritePattern<tensor::PadOp> {
1926public:
1927 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1928 ControlFusionFn foldReshapes,
1929 PatternBenefit benefit = 1)
1930 : OpRewritePattern<tensor::PadOp>(context, benefit),
1931 controlFoldingReshapes(std::move(foldReshapes)) {}
1932
1933 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1934 PatternRewriter &rewriter) const override {
1935 tensor::ExpandShapeOp reshapeOp =
1936 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1937 if (!reshapeOp)
1938 return failure();
1939 if (!reshapeOp->hasOneUse())
1940 return failure();
1941
1942 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1943 return rewriter.notifyMatchFailure(padOp,
1944 "fusion blocked by control function");
1945 }
1946
1947 ArrayRef<int64_t> low = padOp.getStaticLow();
1948 ArrayRef<int64_t> high = padOp.getStaticHigh();
1949 SmallVector<ReassociationIndices> reassociations =
1950 reshapeOp.getReassociationIndices();
1951
1952 for (auto reInd : reassociations) {
1953 if (reInd.size() == 1)
1954 continue;
1955 if (llvm::any_of(reInd, [&](int64_t ind) {
1956 return low[ind] != 0 || high[ind] != 0;
1957 })) {
1958 return failure();
1959 }
1960 }
1961
1962 SmallVector<OpFoldResult> newLow, newHigh;
1963 RankedTensorType collapsedType = reshapeOp.getSrcType();
1964 RankedTensorType paddedType = padOp.getResultType();
1965 SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1966 SmallVector<OpFoldResult> expandedPaddedSizes(
1967 getMixedValues(reshapeOp.getStaticOutputShape(),
1968 reshapeOp.getOutputShape(), rewriter));
1969 AffineExpr d0, d1, d2;
1970 bindDims(rewriter.getContext(), d0, d1, d2);
1971 auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1972 Location loc = reshapeOp->getLoc();
1973 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1974 OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1975 OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1976 if (reInd.size() == 1) {
1977 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1978 OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1979 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1980 expandedPaddedSizes[reInd[0]] = paddedSize;
1981 }
1982 newLow.push_back(l);
1983 newHigh.push_back(h);
1984 }
1985
1986 RankedTensorType collapsedPaddedType =
1987 paddedType.clone(collapsedPaddedShape);
1988 auto newPadOp = tensor::PadOp::create(
1989 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1990 padOp.getConstantPaddingValue(), padOp.getNofold());
1991
1992 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1993 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1994 expandedPaddedSizes);
1995
1996 return success();
1997 }
1998
1999private:
2000 ControlFusionFn controlFoldingReshapes;
2001};
2002
2003 /// Pattern to collapse dimensions, as specified by a control function.
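/// For example (illustrative), a control function returning [[0, 1]] for a
/// 2-D linalg.copy collapses it into a 1-D copy between collapsed operands.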
2004template <typename LinalgType>
2005class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2006public:
2007 CollapseLinalgDimensions(MLIRContext *context,
2008 GetCollapsableDimensionsFn collapseDimensions,
2009 PatternBenefit benefit = 1)
2010 : OpRewritePattern<LinalgType>(context, benefit),
2011 controlCollapseDimension(std::move(collapseDimensions)) {}
2012
2013 LogicalResult matchAndRewrite(LinalgType op,
2014 PatternRewriter &rewriter) const override {
2015 SmallVector<ReassociationIndices> collapsableIterationDims =
2016 controlCollapseDimension(op);
2017 if (collapsableIterationDims.empty())
2018 return failure();
2019
2020 // Check if the specified list of dimensions to collapse is a valid list.
2021 if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2022 collapsableIterationDims)) {
2023 return rewriter.notifyMatchFailure(
2024 op, "specified dimensions cannot be collapsed");
2025 }
2026
2027 std::optional<CollapseResult> collapseResult =
2028 collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2029 if (!collapseResult) {
2030 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2031 }
2032 rewriter.replaceOp(op, collapseResult->results);
2033 return success();
2034 }
2035
2036private:
2037 GetCollapsableDimensionsFn controlCollapseDimension;
2038};
2039
2040} // namespace
2041
2042//===---------------------------------------------------------------------===//
2043// Methods and patterns that fuse constants with linalg.generic operations.
2044//===---------------------------------------------------------------------===//
2045
2046namespace {
2047/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2048/// handle cases where the constant is not single-valued.
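/// For example (illustrative), an input produced by
///   arith.constant dense<1.0> : tensor<4xf32>
/// is replaced by a scalar arith.constant 1.0 : f32 that is used directly in
/// the fused region, and the corresponding operand and indexing map are
/// dropped.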
2049class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2050public:
2051 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2052 : OpRewritePattern<GenericOp>(context, benefit) {}
2053
2054 LogicalResult matchAndRewrite(GenericOp genericOp,
2055 PatternRewriter &rewriter) const override {
2056 if (!genericOp.hasPureTensorSemantics())
2057 return failure();
2058 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2059 Operation *def = opOperand->get().getDefiningOp();
2060 TypedAttr constantAttr;
2061 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2062 {
2063 DenseElementsAttr splatAttr;
2064 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2065 splatAttr.isSplat() &&
2066 splatAttr.getType().getElementType().isIntOrFloat()) {
2067 constantAttr = splatAttr.getSplatValue<TypedAttr>();
2068 return true;
2069 }
2070 }
2071 {
2072 IntegerAttr intAttr;
2073 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2074 constantAttr = intAttr;
2075 return true;
2076 }
2077 }
2078 {
2079 FloatAttr floatAttr;
2080 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2081 constantAttr = floatAttr;
2082 return true;
2083 }
2084 }
2085 return false;
2086 };
2087
2088 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2089 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2090 continue;
2091
2092 // The operands and the indexing_maps of the fused operation are the same
2093 // as the operands and indexing_maps of the generic operation, with the
2094 // entry at the constant operand's index dropped.
2095 SmallVector<AffineMap> fusedIndexMaps;
2096 SmallVector<Value> fusedOperands;
2097 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2098 fusedIndexMaps.reserve(genericOp->getNumOperands());
2099 fusedOperands.reserve(genericOp.getNumDpsInputs());
2100 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2101 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2102 if (inputOperand == opOperand)
2103 continue;
2104 Value inputValue = inputOperand->get();
2105 fusedIndexMaps.push_back(
2106 genericOp.getMatchingIndexingMap(inputOperand));
2107 fusedOperands.push_back(inputValue);
2108 fusedLocs.push_back(inputValue.getLoc());
2109 }
2110 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2111 fusedIndexMaps.push_back(
2112 genericOp.getMatchingIndexingMap(&outputOperand));
2113
2114 // Check if the operation shapes to loops map is computable.
2115 if (!inversePermutation(
2116 concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2117 return rewriter.notifyMatchFailure(
2118 genericOp, "fused op loop bound computation failed");
2119 }
2120
2121 // Create a constant scalar value from the splat constant.
2122 Value scalarConstant =
2123 arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
2124
2125 SmallVector<Value> outputOperands = genericOp.getOutputs();
2126 auto fusedOp =
2127 GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
2128 genericOp->getResultTypes(),
2129 /*inputs=*/fusedOperands,
2130 /*outputs=*/outputOperands,
2131 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2132 genericOp.getIteratorTypes(),
2133 /*doc=*/nullptr,
2134 /*library_call=*/nullptr);
2135
2136 // Map the block argument corresponding to the replaced argument with the
2137 // scalar constant.
2138 Region &region = genericOp->getRegion(0);
2139 Block &entryBlock = *region.begin();
2140 IRMapping mapping;
2141 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2142 scalarConstant);
2143 Region &fusedRegion = fusedOp->getRegion(0);
2144 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2145 mapping);
2146 rewriter.replaceOp(genericOp, fusedOp->getResults());
2147 return success();
2148 }
2149 return failure();
2150 }
2151};
2152
2153} // namespace
2154
2155//===---------------------------------------------------------------------===//
2156// Miscellaneous patterns that help fusion.
2157//===---------------------------------------------------------------------===//
2158
2159namespace {
2160/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2161/// value of the `outs` operand is not used within the op. This is only
2162/// implemented for `linalg.generic` operations for now, but should hold for all
2163/// linalg structured ops.
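/// For example (illustrative), if the payload never reads the block argument
/// tied to an outs operand %init, the operand is replaced by
///   %empty = tensor.empty(...) : tensor<?xf32>
/// which breaks the use-def dependence on %init and unblocks further fusion.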
2164struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2165 using OpRewritePattern<GenericOp>::OpRewritePattern;
2166
2167 LogicalResult matchAndRewrite(GenericOp op,
2168 PatternRewriter &rewriter) const override {
2169 rewriter.startOpModification(op);
2170 bool modifiedOutput = false;
2171 Location loc = op.getLoc();
2172 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2173 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2174 Value operandVal = opOperand.get();
2175 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2176 if (!operandType)
2177 continue;
2178
2179 // If outs is sparse, leave it to the sparsifier.
2180 if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2181 continue;
2182
2183 // If outs is already an `empty` operation, nothing to do.
2184 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2185 if (definingOp)
2186 continue;
2187 modifiedOutput = true;
2188 SmallVector<OpFoldResult> mixedSizes =
2189 tensor::getMixedSizes(rewriter, loc, operandVal);
2190 Value emptyTensor = tensor::EmptyOp::create(
2191 rewriter, loc, mixedSizes, operandType.getElementType());
2192 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2193 }
2194 }
2195 if (!modifiedOutput) {
2196 rewriter.cancelOpModification(op);
2197 return failure();
2198 }
2199 rewriter.finalizeOpModification(op);
2200 return success();
2201 }
2202};
2203
2204/// Fold linalg.fill into linalg.generic
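/// For example (illustrative), an input defined by
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
/// is folded by replacing uses of the matching payload block argument with
/// %cst, converted to the required element type if necessary.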
2205struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2206 using OpRewritePattern<GenericOp>::OpRewritePattern;
2207
2208 LogicalResult matchAndRewrite(GenericOp genericOp,
2209 PatternRewriter &rewriter) const override {
2210 if (!genericOp.hasPureTensorSemantics())
2211 return failure();
2212 bool fillFound = false;
2213 Block &payload = genericOp.getRegion().front();
2214 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2215 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2216 continue;
2217 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2218 if (!fillOp)
2219 continue;
2220 fillFound = true;
2221 Value fillVal = fillOp.value();
2222 auto resultType =
2223 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2224 Value convertedVal =
2225 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2226 /*isUnsignedCast =*/false);
2227 rewriter.replaceAllUsesWith(
2228 payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2229 }
2230 return success(fillFound);
2231 }
2232};
2233} // namespace
2234
2235 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2236 RewritePatternSet &patterns,
2237 const ControlFusionFn &controlFoldingReshapes) {
2238 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2239 controlFoldingReshapes);
2240 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2241 controlFoldingReshapes);
2242 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2243 controlFoldingReshapes);
2244}
2245
2246 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2247 RewritePatternSet &patterns,
2248 const ControlFusionFn &controlFoldingReshapes) {
2249 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2250 controlFoldingReshapes);
2251 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2252 patterns.getContext(), controlFoldingReshapes);
2253 patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2254 controlFoldingReshapes);
2255}
2256
2257 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2258 RewritePatternSet &patterns,
2259 const ControlFusionFn &controlElementwiseOpsFusion) {
2260 auto *context = patterns.getContext();
2261 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2262 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2263 RemoveOutsDependency>(context);
2264 // Add the patterns that clean up dead operands and results.
2265 populateEraseUnusedOperandsAndResultsPatterns(patterns);
2266}
2267
2268 void mlir::linalg::populateCollapseDimensions(
2269 RewritePatternSet &patterns,
2270 const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2271 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2272 CollapseLinalgDimensions<linalg::CopyOp>>(
2273 patterns.getContext(), controlCollapseDimensions);
2274}
2275
2276//===---------------------------------------------------------------------===//
2277// Passes
2278//===---------------------------------------------------------------------===//
2279
2280namespace {
2281
2282/// Pass that fuses generic ops on tensors. Used only for testing.
2283// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2284// patterns added here heavily depends on the cost function used. Having an
2285// opinionated pass of this form is not recommended. Deprecate this pass in
2286// favor of test passes that check the functionality of each of the patterns
2287// added here individually.
2288 struct LinalgElementwiseOpFusionPass
2289 : public impl::LinalgElementwiseOpFusionPassBase<
2290 LinalgElementwiseOpFusionPass> {
2291 using impl::LinalgElementwiseOpFusionPassBase<
2292 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2293 void runOnOperation() override {
2294 Operation *op = getOperation();
2295 MLIRContext *context = op->getContext();
2296 RewritePatternSet patterns(context);
2297
2298 // Add folding with reshape by expansion patterns.
2299 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2300 Operation *producer = fusedOperand->get().getDefiningOp();
2301 return producer && producer->hasOneUse();
2302 };
2303
2304 // Add elementwise op fusion patterns.
2305 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2306 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2307 tensor::populateBubbleUpExpandShapePatterns(patterns);
2308
2309 // General canonicalization patterns.
2310 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2311 GenericOp::getCanonicalizationPatterns(patterns, context);
2312 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2313 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2314 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2315 patterns);
2316
2317 // Add constant folding patterns.
2318 populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2319
2320 // Use TopDownTraversal for compile time reasons.
2321 (void)applyPatternsGreedily(op, std::move(patterns),
2322 GreedyRewriteConfig().setUseTopDownTraversal());
2323 }
2324};
2325
2326} // namespace