//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use for an operand of the `producer` in the
/// fused operation, given the indexing map of the result of the producer in
/// the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoalescedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
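
// Worked example (illustrative maps, not taken from specific IR): with
//   producerResultIndexMap   = (d0, d1) -> (d1, d0)
//   argMap                   = (d0, d1) -> (d0)
//   fusedConsumerArgIndexMap = (d0, d1) -> (d0, d1)
// the inverse permutation is (d0, d1) -> (d1, d0),
// t1 = argMap.compose(inverse) = (d0, d1) -> (d1), and the returned map is
// t1.compose(fusedConsumerArgIndexMap) = (d0, d1) -> (d1), i.e. the producer
// operand is indexed by the second fused loop.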

// Checks if the given operand can be dropped, and whether the remaining
// operands of the fused producer & consumer after the fusion can still
// compute the bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
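
// Illustrative scenario (not tied to specific IR): if the producer has two
// results and the second one is also consumed by an op other than `consumer`,
// then index 1 is always preserved. An index is dropped only when the payload
// does not read the corresponding init operand, all the result's uses are in
// `consumer`, and the remaining indexing maps stay invertible.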

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check that producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap = getIndexingMapOfProducerOperandsInCoalescedOp(
          operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
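
// Illustrative IR (hypothetical example, with #id = affine_map<(d0) -> (d0)>):
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//        ins(%a : tensor<?xf32>) outs(%init0 : tensor<?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel"]}
//        ins(%0, %b : tensor<?xf32>, tensor<?xf32>)
//        outs(%init1 : tensor<?xf32>) {...}
//
// The first operand of %1 passes all the checks above: both ops are generics
// on tensors, the producer is all-parallel, and the producer result map #id
// is an invertible permutation.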

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of the consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoalescedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoalescedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of the consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producers of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Fuse the producer of the operand into this op.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the replacements.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
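
// Typical client-side usage (a minimal sketch, not part of this file): the
// pattern above is exposed through populateElementwiseOpsFusionPatterns()
// declared in Linalg's Transforms.h and is usually run under the greedy
// rewrite driver; the driver entry point name varies across MLIR revisions.
//
//   RewritePatternSet patterns(ctx);
//   linalg::ControlFusionFn control = [](OpOperand *) { return true; };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, control);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();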

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
///   %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                           affine_map<(d0, d1, d2) -> (d1, d2)>,
///                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///   %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
/// Substituting this, the structured op can be rewritten as
///
///   %d = linalg.generic
///          ins(%0, %1 : tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
///          indexing_maps =
///            [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///             affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///             affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the iteration space of the modified op is now six-dimensional,
/// reshapes are introduced to make the operands consistent:
///
///   %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///        : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///   %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///        : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of each loop in the original operation to the
  /// extents of the corresponding loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extents of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::map_to_vector<4>(
        expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        });
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
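
// Worked example (illustrative values): with the reassociation d0 -> {e0, e1}
// and d1 -> {e2} from ExpansionInfo, the original indexing map
//   (d0, d1) -> (d1, d0)
// expands to
//   (e0, e1, e2) -> (e2, e0, e1)
// i.e. each original dim expr is replaced in place by the run of expanded
// dims it folds into.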

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    OpFoldResult newIndex =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
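
// Worked example (illustrative): if the original dim d is expanded into
// (e0, e1, e2) with sizes (s0, s1, s2), the loop above rebuilds the original
// index as
//   idx = e2 + s2 * (e1 + s1 * e0)
// by repeatedly folding `idx + acc * shape` over the trailing dims, i.e. the
// usual row-major linearization.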

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
//   permutation = [2, 0, 1]
//   reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
//   inverse permutation = [1, 2, 0]
//   applied to reassociation_map and then flattened becomes
//   flattened permutation = [2, 3, 4, 5, 0, 1]
//   the final permutation is the inverse of the flattened permutation.
//
// Becomes
//
//   permutation = [4, 5, 0, 1, 2, 3]
static TransposeOp createExpandedTransposeOp(PatternRewriter &rewriter,
                                             TransposeOp transposeOp,
                                             Value expandedInput, Value output,
                                             ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Start from all-parallel iterator types, then scatter the original iterator
  // types to the corresponding expanded dimensions.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in the output shape before the
    // `linalgOp` to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps =
      llvm::map_to_vector<4>(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      });

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Carries information about a padded dimension.
struct PadDimInfo {
  // The resulting shape after padding each dimension.
  SmallVector<int64_t> paddedShape;

  // Low and high padding amounts for each dimension.
  SmallVector<OpFoldResult> lowPad;
  SmallVector<OpFoldResult> highPad;
};

/// Computes the expanded padding information for the given pad operation based
/// on the provided expanded shape and reassociation indices. Returns a
/// PadDimInfo containing the low and high padding amounts and the padded size
/// for each dimension, or failure if the expansion is not possible.
static FailureOr<PadDimInfo>
computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
                       ArrayRef<ReassociationIndices> reassociations,
                       PatternRewriter &rewriter) {
  // If the padding value depends on the index values of the pad operation,
  // then it may not be valid to expand the dimensions, since it will change
  // the index values on which the padding value depends. This is not currently
  // supported by the pad expansion patterns, but it could be implemented
  // similarly to the expansion of linalg.generic ops with linalg.index ops in
  // the body, as is done in `updateExpandedGenericOpRegion`.
  if (!padOp.getConstantPaddingValue())
    return failure();

  // Expanded dimensions cannot have padding because the resulting padding may
  // not be representable by a tensor.pad op. There are some special cases where
  // it is possible (like expanding unit dims), but supporting these cases is
  // NYI, so disallow it for now.
  ArrayRef<int64_t> low = padOp.getStaticLow();
  ArrayRef<int64_t> high = padOp.getStaticHigh();
  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
    if (reInd.size() != 1 && (l != 0 || h != 0))
      return failure();
  }

  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
  PadDimInfo padDimInfo;
  padDimInfo.paddedShape.assign(expandedShape);
  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
    if (reInd.size() == 1) {
      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
    }
  }

  return padDimInfo;
}
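
// Worked example (illustrative): for a pad with low = [0, 1], high = [0, 2]
// on tensor<8x10xf32>, expanded via reassociations [[0, 1], [2]] to
// tensor<2x4x10xf32>, dim 0 is expanded (so it must be unpadded, which holds)
// and dim 1 maps to expanded dim 2, giving
//   paddedShape = [2, 4, 13], lowPad = [0, 0, 1], highPad = [0, 0, 2].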

class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    RankedTensorType expandedType = reshapeOp.getSrcType();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
        padOp, expandedType.getShape(), reassociations, rewriter);
    if (failed(maybeExpandedPadding))
      return failure();
    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType =
        padOp.getResultType().clone(expandedPadding.paddedShape);

    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
        expandedPadding.lowPad, expandedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldReshapeWithProducerPadOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {
public:
  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(expandOp,
                                         "fusion blocked by control function");
    }

    RankedTensorType expandedType = expandOp.getResultType();
    SmallVector<ReassociationIndices> reassociations =
        expandOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
        padOp, expandedType.getShape(), reassociations, rewriter);
    if (failed(maybeExpandedPadding))
      return failure();
    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
    rewriter.setInsertionPointAfterValue(padOp.getSource());
    SmallVector<OpFoldResult> padSrcSizes =
        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      // We know that any reassociation with multiple dims is not padded because
      // of the requirements of computeExpandedPadding.
      if (reInd.size() == 1) {
        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
      }
    }
    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
    auto newExpandOp = tensor::ExpandShapeOp::create(
        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
        newExpandedSizes);
    RankedTensorType expandedPaddedType =
        padOp.getResultType().clone(expandedPadding.paddedShape);
    rewriter.setInsertionPoint(expandOp);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
        expandedPadding.lowPad, expandedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOp(expandOp, newPadOp.getResult());

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation =
      llvm::map_to_vector<4>(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      });
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
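
// Worked example (illustrative): for indexingMap = (d0, d1, d2) -> (d2, d0)
// and rangeReassociation = [0, 1] (folding both results), the returned domain
// reassociation is [2, 0].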

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}
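
// Worked examples (illustrative): for indexingMap = (d0, d1, d2) -> (d0, d1, d2)
// the sequence [1, 2] is preserved (it appears contiguously and in order).
// For indexingMap = (d0, d1, d2) -> (d1, d0, d2) the sequence [0, 1] is not:
// d1 appears before d0, so the walk hits an element of the sequence that is
// not its start and returns false.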

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}
1356
1357// Return the list of dimensions of the iteration domain that can be
1358// collapsed to allow for fusion with the a producer that is an expand_shape
1359// operation. If all dimensions created by expansion can be collapsed in the
1360// iteration space then the reshape is defunct.
1361//
1362// Example:
1363//
1364// ```mlir
1365// #map = affine_map<(d0, d1) -> (d0, d1)>
1366// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1367// %2 = tensor.empty [..] : tensor<?x4xf32>
1368// %3 = linalg.generic {
1369// indexing_maps = [#map, #map],
1370// iterator_types = ["parallel" ,"parallel"]}
1371// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1372// ```
1373//
1374// can be fused by collapsing the dimensions of the iteration space.
1375//
1376// ```mlir
1377// #map = affine_map<(d0) -> (d0)>
1378// %2 = tensor.empty [..] : tensor<?xf32>
1379// %3 = linalg.generic {
1380// indexing_maps = [#map, #map],
1381// iterator_types = ["parallel"]}
1382// ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1383// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1384// ```
1385//
1386// In the following example,
1387//
1388// ```mlir
1389// #map0 = affine_map<(d0, d1) -> (d0, d1)>
1390// #map1 = affine_map<(d0, d1) -> (d1, d0)>
1391// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1392// %2 = tensor.empty [..] : tensor<4x?xf32>
1393// %2 = linalg.generic {
1394// indexing_maps = [#map0, #map1],
1395// iterator_types = ["parallel" ,"parallel"]}
1396// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1397// ```
1398//
1399// the reshape cannot be fused with the generic op by collapsing the op
1400// dimensions since the indexing maps will have to contain mods and divs
1401 // to preserve the access pattern. When no dimensions of the iteration
1402 // space are collapsible, an empty vector is returned.
1403static SmallVector<ReassociationIndices>
1404getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1405 ArrayRef<ReassociationIndices> reassociation) {
1405 ArrayRef<ReassociationIndices> reassociation) {
1406 // Some basic checks for this fusion to be valid.
1407 if (!genericOp.hasPureTensorSemantics())
1408 return {};
1409
1410 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1411 return map.isProjectedPermutation();
1412 })) {
1413 return {};
1414 }
1415
1416 // Compute all the loops with the reduction iterator types.
1417 SmallVector<unsigned> reductionDims;
1418 genericOp.getReductionDims(reductionDims);
1419
1420 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1421 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1422 auto iteratorTypes = genericOp.getIteratorTypesArray();
1423 SmallVector<ReassociationIndices> iterationSpaceReassociation;
1424 for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1425 assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1426
1427 // Ignore dims that are not folded.
1428 if (foldedRangeDims.size() == 1)
1429 continue;
1430
1431 ReassociationIndices foldedIterationSpaceDims =
1432 getDomainReassociation(indexingMap, foldedRangeDims);
1433
1434 // Check that the folded iteration dims do not contain already processed
1435 // dims.
1436 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1437 return processedIterationDims.count(dim);
1438 }))
1439 continue;
1440
1441 // Check that the folded iterator types are either all parallel or all reduction.
1442 utils::IteratorType startIteratorType =
1443 iteratorTypes[foldedIterationSpaceDims[0]];
1444 if (!isParallelIterator(startIteratorType) &&
1445 !isReductionIterator(startIteratorType))
1446 continue;
1447 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1448 return iteratorTypes[dim] != startIteratorType;
1449 }))
1450 continue;
1451
1452 // If the folded dimensions correspond to a "reduction" iterator type,
1453 // the folded dimensions need to be "in-order". Strictly speaking this is
1454 // not necessary for reductions that are associative and commutative, but
1455 // a stricter definition of reduction is used for now.
1456 if (isReductionIterator(startIteratorType)) {
1457 bool isContiguous = false;
1458 for (const auto &startDim : llvm::enumerate(reductionDims)) {
1459 // Move window in `reductionDims` to start of the folded iteration dims.
1460 if (startDim.value() != foldedIterationSpaceDims[0])
1461 continue;
1462 // If the sizes don't match, it is trivially not contiguous. This
1463 // condition should not be hit.
1464 if (startDim.index() + foldedIterationSpaceDims.size() >
1465 reductionDims.size())
1466 break;
1467 // Check that the contiguity is maintained.
1468 isContiguous = true;
1469 for (const auto &foldedDim :
1470 llvm::enumerate(foldedIterationSpaceDims)) {
1471 if (reductionDims[foldedDim.index() + startDim.index()] !=
1472 foldedDim.value()) {
1473 isContiguous = false;
1474 break;
1475 }
1476 }
1477 break;
1478 }
1479 if (!isContiguous)
1480 continue;
1481 }
1482
1483 // Check that the sequence is preserved in all indexing maps.
1484 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1485 [&](AffineMap indexingMap) {
1486 return !isDimSequencePreserved(indexingMap,
1487 foldedIterationSpaceDims);
1488 }))
1489 continue;
1490
1491 processedIterationDims.insert_range(foldedIterationSpaceDims);
1492 iterationSpaceReassociation.emplace_back(
1493 std::move(foldedIterationSpaceDims));
1494 }
1495
1496 return iterationSpaceReassociation;
1497}
1498
1499/// Helper class to carry state while collapsing the `linalg.generic` op.
1500namespace {
1501class CollapsingInfo {
1502public:
1503 LogicalResult initialize(unsigned origNumLoops,
1504 ArrayRef<ReassociationIndices> foldedIterationDims) {
1505 llvm::SmallDenseSet<int64_t, 4> processedDims;
1506 // Find all the dims that are folded.
1507 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1508 if (foldedIterationDim.empty())
1509 continue;
1510 // If the folded dims contain dims already folded, that's an illegal
1511 // specification. Repetition within a list is also illegal.
1512 for (auto dim : foldedIterationDim) {
1513 if (dim >= origNumLoops)
1514 return failure();
1515 if (processedDims.count(dim))
1516 return failure();
1517 processedDims.insert(dim);
1518 }
1519 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1520 foldedIterationDim.end());
1521 }
1522 if (processedDims.size() > origNumLoops)
1523 return failure();
1524
1525 // Add all the preserved dims of the original op as single
1526 // elements to `collapsedOpToOrigOpIterationDim`.
1527 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1528 if (processedDims.count(dim))
1529 continue;
1530 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1531 }
1532
1533 llvm::sort(collapsedOpToOrigOpIterationDim,
1534 [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1535 return lhs[0] < rhs[0];
1536 });
1537 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1538 for (const auto &foldedDims :
1539 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1540 for (const auto &dim : enumerate(foldedDims.value()))
1541 origOpToCollapsedOpIterationDim[dim.value()] =
1542 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1543 }
1544 return success();
1545 }
1546
1547 /// Return mapping from collapsed loop domain to original loop domain.
1548 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1549 return collapsedOpToOrigOpIterationDim;
1550 }
1551
1552 /// Return mapping from original loop domain to collapsed loop domain. The
1553 /// mapping is a pair. First value is the dimension in the collapsed loop that
1554 /// the original loop is mapped to. Second is the relative position in the
1555 /// folded list of this domain. For example, if the original loop domain is
1556 /// 5D and all of it is folded, i.e.
1557 ///
1558 /// ```
1559 /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1560 /// ```
1561 ///
1562 /// then
1563 ///
1564 /// ```
1565 /// origOpToCollapsedOpMapping[0] = {0, 0};
1566 /// origOpToCollapsedOpMapping[1] = {0, 1};
1567 /// origOpToCollapsedOpMapping[2] = {0, 2};
1568 /// origOpToCollapsedOpMapping[3] = {1, 0};
1569 /// origOpToCollapsedOpMapping[4] = {1, 1};
1570 /// ```
1571 ///
1572 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1573 return origOpToCollapsedOpIterationDim;
1574 }
1575
1576 /// Return the collapsed op iteration domain rank.
1577 unsigned getCollapsedOpIterationRank() const {
1578 return collapsedOpToOrigOpIterationDim.size();
1579 }
1580
1581private:
1582 /// Map from the iteration domain index in collapsed op to the iteration
1583 /// domain indices in the original op.
1584 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1585
1586 /// Map from iteration domain index in the original op to the iteration domain
1587 /// index in the collapsed op.
1588 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1589};
1590} // namespace
1591
1592/// Get the iterator types for the collapsed operation given the original
1593/// iterator types and collapsed dimensions.
1594static SmallVector<utils::IteratorType>
1595getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1596 const CollapsingInfo &collapsingInfo) {
1597 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1598 for (ReassociationIndicesRef foldedIterDims :
1599 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1600 assert(!foldedIterDims.empty() &&
1601 "reassociation indices expected to have non-empty sets");
1602 // Just pick the iterator type of the first folded dim. Pre-condition checks
1603 // are expected to have verified that the iterator types of all folded
1604 // dimensions are the same.
1605 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1606 }
1607 return collapsedIteratorTypes;
1608}
1609
1610/// Compute the indexing map in the collapsed op that corresponds to the given
1611/// `indexingMap` of the original operation.
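/// For example, if loops (d0, d1) of a 3-loop op are collapsed into a single
/// loop, `affine_map<(d0, d1, d2) -> (d2, d0, d1)>` becomes
/// `affine_map<(d0, d1) -> (d1, d0)>` in the collapsed op.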
1612static AffineMap
1613getCollapsedOpIndexingMap(AffineMap indexingMap,
1614 const CollapsingInfo &collapsingInfo) {
1615 MLIRContext *context = indexingMap.getContext();
1616 assert(indexingMap.isProjectedPermutation() &&
1617 "expected indexing map to be projected permutation");
1618 SmallVector<AffineExpr> resultExprs;
1619 auto origOpToCollapsedOpMapping =
1620 collapsingInfo.getOrigOpToCollapsedOpMapping();
1621 for (auto expr : indexingMap.getResults()) {
1622 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1623 // If the dim is not the first dim of its collapsed group, do nothing.
1624 if (origOpToCollapsedOpMapping[dim].second != 0)
1625 continue;
1626 // The next n-dims are guaranteed to be collapsed. So just use the
1627 // iteration dimension of the collapsed op.
1628 resultExprs.push_back(
1629 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1630 }
1631 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1632 resultExprs, context);
1633}
1634
1635/// Return the `reassociation` indices to use to collapse the operand when the
1636/// iteration space of a generic op is collapsed.
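/// For example, if loops (d0, d1) are collapsed and the operand's indexing
/// map is `affine_map<(d0, d1, d2) -> (d0, d1, d2)>`, the operand
/// reassociation is `[[0, 1], [2]]`, i.e. the first two operand dimensions
/// are collapsed together.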
1637static SmallVector<ReassociationIndices>
1638getOperandReassociation(AffineMap indexingMap,
1639 const CollapsingInfo &collapsingInfo) {
1640 unsigned counter = 0;
1641 SmallVector<ReassociationIndices> operandReassociation;
1642 auto origOpToCollapsedOpMapping =
1643 collapsingInfo.getOrigOpToCollapsedOpMapping();
1644 auto collapsedOpToOrigOpMapping =
1645 collapsingInfo.getCollapsedOpToOrigOpMapping();
1646 while (counter < indexingMap.getNumResults()) {
1647 unsigned dim =
1648 cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1649 // This is the start of a collapsed group of dimensions of the iteration
1650 // space that is guaranteed to be preserved in the indexing map. The number
1651 // of folded dims is obtained from the collapsed op to original op mapping.
1652 unsigned numFoldedDims =
1653 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1654 .size();
1655 if (origOpToCollapsedOpMapping[dim].second == 0) {
1656 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1657 operandReassociation.emplace_back(range.begin(), range.end());
1658 }
1659 counter += numFoldedDims;
1660 }
1661 return operandReassociation;
1662}
1663
1664/// Get the new value to use for a given `OpOperand` in the collapsed operation.
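/// For example, an operand of type `tensor<4x8x?xf32>` with operand
/// reassociation `[[0, 1], [2]]` is rewritten through a
/// `tensor.collapse_shape` (or `memref.collapse_shape` for memrefs) to a
/// value of type `tensor<32x?xf32>`.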
1665static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1666 OpOperand *opOperand,
1667 const CollapsingInfo &collapsingInfo,
1668 OpBuilder &builder) {
1669 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1670 SmallVector<ReassociationIndices> operandReassociation =
1671 getOperandReassociation(indexingMap, collapsingInfo);
1672
1673 // If the number of entries in the reassociation for the operand is the same
1674 // as the number of results of the indexing map, there is nothing to do for
1675 // this operand.
1676 Value operand = opOperand->get();
1677 if (operandReassociation.size() == indexingMap.getNumResults())
1678 return operand;
1679
1680 // Insert a reshape to collapse the dimensions.
1681 if (isa<MemRefType>(operand.getType())) {
1682 return memref::CollapseShapeOp::create(builder, loc, operand,
1683 operandReassociation)
1684 .getResult();
1685 }
1686 return tensor::CollapseShapeOp::create(builder, loc, operand,
1687 operandReassociation)
1688 .getResult();
1689}
1690
1691/// Modify the `linalg.index` operations in the original generic op to their
1692/// values in the collapsed operation.
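/// For example, if loops (d0, d1) with sizes (%d0, %d1) are collapsed into a
/// single loop, `linalg.index 0` and `linalg.index 1` are recovered from the
/// collapsed index with a div/mod pair (sketch):
///
/// ```mlir
/// %folded = linalg.index 0 : index
/// %i1 = arith.remsi %folded, %d1 : index
/// %i0 = arith.divsi %folded, %d1 : index
/// ```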
1693static void generateCollapsedIndexingRegion(
1694 Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1695 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1696 OpBuilder::InsertionGuard g(rewriter);
1697 rewriter.setInsertionPointToStart(block);
1698
1699 // Collect all the original index ops.
1700 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1701
1702 // For each folded dimension list resolve the original induction variable
1703 // values in terms of the folded dimension induction variable.
1704 // i_{folded} = (i0 * d1 + i1) * d2 + i2
1705 // can be inverted to
1706 // i2 = i_{folded} % d2
1707 // i1 = (i_{folded} / d2) % d1
1708 // i0 = i_{folded} / (d1 * d2)
1709 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1710 for (auto foldedDims :
1711 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1712 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1713 Value newIndexVal =
1714 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1715 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1716 Value loopDim =
1717 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1718 indexReplacementVals[dim] =
1719 rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1720 newIndexVal =
1721 rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1722 }
1723 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1724 }
1725
1726 for (auto indexOp : indexOps) {
1727 auto dim = indexOp.getDim();
1728 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1729 }
1730}
1731
1732static void collapseOperandsAndResults(LinalgOp op,
1733 const CollapsingInfo &collapsingInfo,
1734 RewriterBase &rewriter,
1735 SmallVectorImpl<Value> &inputOperands,
1736 SmallVectorImpl<Value> &outputOperands,
1737 SmallVectorImpl<Type> &resultTypes) {
1738 Location loc = op->getLoc();
1739 inputOperands =
1740 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1741 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1742 rewriter);
1743 });
1744
1745 // Get the output operands and result types.
1746 resultTypes.reserve(op.getNumDpsInits());
1747 outputOperands.reserve(op.getNumDpsInits());
1748 for (OpOperand &output : op.getDpsInitsMutable()) {
1749 Value newOutput =
1750 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1751 outputOperands.push_back(newOutput);
1752 // If the op has "buffer semantics", then the init operands are ranked
1753 // memrefs and the op has no results.
1754 if (!op.hasPureBufferSemantics())
1755 resultTypes.push_back(newOutput.getType());
1756 }
1757}
1758
1759/// Clone a `LinalgOp` to a collapsed version of the same name.
1760template <typename OpTy>
1761static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1762 const CollapsingInfo &collapsingInfo) {
1763 return nullptr;
1764}
1765
1766/// Collapse any `LinalgOp` that does not require any specialization such as
1767/// indexing_maps, iterator_types, etc.
1768template <>
1769LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1770 const CollapsingInfo &collapsingInfo) {
1771 SmallVector<Value> inputOperands, outputOperands;
1772 SmallVector<Type> resultTypes;
1773 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1774 outputOperands, resultTypes);
1775
1776 return clone(
1777 rewriter, origOp, resultTypes,
1778 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1779}
1780
1781/// Collapse a `GenericOp`
1782template <>
1783GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1784 GenericOp origOp,
1785 const CollapsingInfo &collapsingInfo) {
1786 SmallVector<Value> inputOperands, outputOperands;
1787 SmallVector<Type> resultTypes;
1788 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1789 outputOperands, resultTypes);
1790 SmallVector<AffineMap> indexingMaps(
1791 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1792 return getCollapsedOpIndexingMap(map, collapsingInfo);
1793 }));
1794
1795 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1796 origOp.getIteratorTypesArray(), collapsingInfo));
1797
1798 GenericOp collapsedOp = linalg::GenericOp::create(
1799 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1800 indexingMaps, iteratorTypes,
1801 [](OpBuilder &builder, Location loc, ValueRange args) {});
1802 Block *origOpBlock = &origOp->getRegion(0).front();
1803 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1804 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1805 collapsedOpBlock->getArguments());
1806 return collapsedOp;
1807}
1808
1809static LinalgOp createCollapsedOp(LinalgOp op,
1810 const CollapsingInfo &collapsingInfo,
1811 RewriterBase &rewriter) {
1812 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1813 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1814 }
1815 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1816}
1817
1818/// Implementation of fusion with reshape operation by collapsing dimensions.
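/// A minimal usage sketch, mirroring the patterns below: fold loops (0, 1) of
/// `op` and replace the op with the re-expanded results.
///
/// ```c++
/// SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
/// FailureOr<CollapseResult> res =
///     collapseOpIterationDims(op, foldedDims, rewriter);
/// if (succeeded(res))
///   rewriter.replaceOp(op, res->results);
/// ```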
1819FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1820 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1821 RewriterBase &rewriter) {
1822 // Bail on trivial no-op cases.
1823 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1824 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1825 return foldedDims.size() <= 1;
1826 }))
1827 return failure();
1828
1829 CollapsingInfo collapsingInfo;
1830 if (failed(
1831 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1832 return rewriter.notifyMatchFailure(
1833 op, "illegal to collapse specified dimensions");
1834 }
1835
1836 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1837 if (hasPureBufferSemantics &&
1838 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
1839 MemRefType memRefToCollapse =
1840 dyn_cast<MemRefType>(opOperand.get().getType());
1841 if (!memRefToCollapse)
1842 return true;
1843
1844 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1845 SmallVector<ReassociationIndices> operandReassociation =
1846 getOperandReassociation(indexingMap, collapsingInfo);
1847 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1848 memRefToCollapse, operandReassociation);
1849 }))
1850 return rewriter.notifyMatchFailure(op,
1851 "memref is not guaranteed collapsible");
1852
1853 // Bail on non-canonical ranges.
1854 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1855 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1856 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1857 return cast<IntegerAttr>(attr).getInt() == value;
1858 llvm::APInt actual;
1859 return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1860 actual.getSExtValue() == value;
1861 };
1862 if (!llvm::all_of(loopRanges, [&](Range range) {
1863 return opFoldIsConstantValue(range.offset, 0) &&
1864 opFoldIsConstantValue(range.stride, 1);
1865 })) {
1866 return rewriter.notifyMatchFailure(
1867 op, "expected all loop ranges to have zero start and unit stride");
1868 }
1869
1870 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1871
1872 Location loc = op->getLoc();
1873 SmallVector<OpFoldResult> loopBound =
1874 llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1875
1876 if (collapsedOp.hasIndexSemantics()) {
1877 // Collect the loop range of the generic op.
1878 OpBuilder::InsertionGuard g(rewriter);
1879 rewriter.setInsertionPoint(collapsedOp);
1880 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1881 collapsingInfo, loopBound, rewriter);
1882 }
1883
1884 // Insert expanding reshape for the result to get back the original result
1885 // type.
1886 SmallVector<Value> results;
1887 for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1888 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1889 auto originalResultType =
1890 cast<ShapedType>(originalResult.value().getType());
1891 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1892 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1893 AffineMap indexingMap =
1894 op.getIndexingMapMatchingResult(originalResult.value());
1895 SmallVector<ReassociationIndices> reassociation =
1896 getOperandReassociation(indexingMap, collapsingInfo);
1897 assert(
1898 indexingMap.isProjectedPermutation() &&
1899 "Expected indexing map to be a projected permutation for collapsing");
1900 SmallVector<OpFoldResult> resultShape =
1901 applyPermutationMap(indexingMap, ArrayRef(loopBound));
1902 Value result;
1903 if (isa<MemRefType>(collapsedOpResult.getType())) {
1904 result = memref::ExpandShapeOp::create(
1905 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1906 resultShape);
1907 } else {
1908 result = tensor::ExpandShapeOp::create(
1909 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1910 resultShape);
1911 }
1912 results.push_back(result);
1913 } else {
1914 results.push_back(collapsedOpResult);
1915 }
1916 }
1917 return CollapseResult{results, collapsedOp};
1918}
1919
1920namespace {
1921
1922/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1923/// contracting dimensions of the loop.
1924class FoldWithProducerReshapeOpByCollapsing
1925 : public OpRewritePattern<GenericOp> {
1926public:
1927 // TODO : support fusion with all linalg ops, not just generic.
1928 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1929 ControlFusionFn foldReshapes,
1930 PatternBenefit benefit = 1)
1931 : OpRewritePattern<GenericOp>(context, benefit),
1932 controlFoldingReshapes(std::move(foldReshapes)) {}
1933
1934 LogicalResult matchAndRewrite(GenericOp genericOp,
1935 PatternRewriter &rewriter) const override {
1936 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1937 tensor::ExpandShapeOp reshapeOp =
1938 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1939 if (!reshapeOp)
1940 continue;
1941
1942 SmallVector<ReassociationIndices> collapsableIterationDims =
1943 getCollapsableIterationSpaceDims(genericOp, &opOperand,
1944 reshapeOp.getReassociationIndices());
1945 if (collapsableIterationDims.empty() ||
1946 !controlFoldingReshapes(&opOperand)) {
1947 continue;
1948 }
1949
1950 std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1951 genericOp, collapsableIterationDims, rewriter);
1952 if (!collapseResult) {
1953 return rewriter.notifyMatchFailure(
1954 genericOp, "failed to do the fusion by collapsing transformation");
1955 }
1956
1957 rewriter.replaceOp(genericOp, collapseResult->results);
1958 return success();
1959 }
1960 return failure();
1961 }
1962
1963private:
1964 ControlFusionFn controlFoldingReshapes;
1965};
1966
1967/// Pattern to fold a tensor.collapse_shape op with its producer generic op
1968/// by collapsing the dimensions of the loops in the producer op.
1969struct FoldReshapeWithGenericOpByCollapsing
1970 : public OpRewritePattern<tensor::CollapseShapeOp> {
1971
1972 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1973 ControlFusionFn foldReshapes,
1974 PatternBenefit benefit = 1)
1975 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1976 controlFoldingReshapes(std::move(foldReshapes)) {}
1977
1978 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1979 PatternRewriter &rewriter) const override {
1980 // Fold only if all constraints of fusing with reshape by collapsing are
1981 // met.
1982 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1983 if (!producerResult) {
1984 return rewriter.notifyMatchFailure(reshapeOp,
1985 "source not produced by an operation");
1986 }
1987
1988 // TODO : support fusion with all linalg producers, not just generic.
1989 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1990 if (!producer) {
1991 return rewriter.notifyMatchFailure(reshapeOp,
1992 "producer not a generic op");
1993 }
1994
1995 SmallVector<ReassociationIndices> collapsableIterationDims =
1996 getCollapsableIterationSpaceDims(
1997 producer,
1998 producer.getDpsInitOperand(producerResult.getResultNumber()),
1999 reshapeOp.getReassociationIndices());
2000 if (collapsableIterationDims.empty()) {
2001 return rewriter.notifyMatchFailure(
2002 reshapeOp, "failed preconditions of fusion with producer generic op");
2003 }
2004
2005 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2006 return rewriter.notifyMatchFailure(reshapeOp,
2007 "fusion blocked by control function");
2008 }
2009
2010 // Set the insertion point after `producer` because there could be uses
2011 // of `producer` between it and the `tensor.collapse_shape` op.
2012 rewriter.setInsertionPointAfter(producer);
2013 std::optional<CollapseResult> collapseResult =
2014 collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
2015 if (!collapseResult) {
2016 return rewriter.notifyMatchFailure(
2017 producer, "failed to do the fusion by collapsing transformation");
2018 }
2019
2020 rewriter.replaceOp(producer, collapseResult->results);
2021 return success();
2022 }
2023
2024private:
2025 ControlFusionFn controlFoldingReshapes;
2026};
2027
2028/// Computes the collapsed padding information for the given pad operation based
2029/// on the provided collapsed shape and reassociation indices. Returns a
2030/// PadDimInfo containing the low and high padding amounts and the collapsed
2031/// shape for each dimension, or failure if the collapse is not possible.
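/// For example, with reassociations `[[0, 1], [2]]` and a pad producing
/// `tensor<2x3x10xf32>` with `low = [0, 0, 2]` and `high = [0, 0, 3]`, the
/// result is `lowPad = [0, 2]`, `highPad = [0, 3]` and collapsed padded shape
/// `[6, 10]`; non-zero padding on a dimension inside a multi-dim group makes
/// the computation fail.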
2032static FailureOr<PadDimInfo>
2033computeCollapsedPadding(tensor::PadOp padOp,
2034 ArrayRef<ReassociationIndices> reassociations,
2035 PatternRewriter &rewriter) {
2036 // If the padding value depends on the index values of the pad operation,
2037 // then it may not be valid to collapse the dimensions, since it will change
2038 // the index values on which the padding value depends. This is not currently
2039 // supported by the pad collapsing patterns, but it could be implemented
2040 // similarly to the collapsing of linalg.generic ops with linalg.index ops in
2041 // the body, as is done in `generateCollapsedIndexingRegion`.
2042 if (!padOp.getConstantPaddingValue())
2043 return failure();
2044
2045 // Collapsed dimensions cannot have padding because this can produce strided
2046 // padding that isn't representable by a tensor.pad op. There are some special
2047 // cases where it is possible (like collapsing unit dims), but supporting
2048 // these cases is NYI, so disallow it for now.
2049 ArrayRef<int64_t> low = padOp.getStaticLow();
2050 ArrayRef<int64_t> high = padOp.getStaticHigh();
2051 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2052 for (int64_t dim : reInd) {
2053 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2054 return failure();
2055 }
2056 }
2057
2058 // Initialize padding values for collapsed tensors with zeros
2059 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2060 PadDimInfo padDimInfo;
2061 padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2062 padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2063
2064 // Update padding for dimensions that are not being collapsed, and compute
2065 // the collapsed padded shape.
2066 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2067 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2068 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2069 if (reInd.size() == 1) {
2070 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2071 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
2072 }
2073 SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
2074 for (int64_t dim : reInd) {
2075 collapsedSize =
2076 collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
2077 }
2078 padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
2079 }
2080
2081 return padDimInfo;
2082}
2083
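/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by moving the pad to the source of the reshape, e.g. (sketch, following
/// the examples above):
///
/// ```mlir
/// %1 = tensor.expand_shape %0 [[0, 1], [2]]
///     : tensor<6x10xf32> into tensor<2x3x10xf32>
/// %2 = tensor.pad %1 low[0, 0, 2] high[0, 0, 3] {.. }
/// ```
///
/// is rewritten to pad %0 directly and expand the padded result.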
2084class FoldPadWithProducerReshapeOpByCollapsing
2085 : public OpRewritePattern<tensor::PadOp> {
2086public:
2087 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
2088 ControlFusionFn foldReshapes,
2089 PatternBenefit benefit = 1)
2090 : OpRewritePattern<tensor::PadOp>(context, benefit),
2091 controlFoldingReshapes(std::move(foldReshapes)) {}
2092
2093 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2094 PatternRewriter &rewriter) const override {
2095 tensor::ExpandShapeOp reshapeOp =
2096 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
2097 if (!reshapeOp)
2098 return failure();
2099
2100 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
2101 return rewriter.notifyMatchFailure(padOp,
2102 "fusion blocked by control function");
2103 }
2104
2105 SmallVector<ReassociationIndices> reassociations =
2106 reshapeOp.getReassociationIndices();
2107 FailureOr<PadDimInfo> maybeCollapsedPadding =
2108 computeCollapsedPadding(padOp, reassociations, rewriter);
2109 if (failed(maybeCollapsedPadding))
2110 return failure();
2111 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2112
2113 SmallVector<OpFoldResult> expandedPaddedSizes =
2114 reshapeOp.getMixedOutputShape();
2115 AffineExpr d0, d1, d2;
2116 bindDims(rewriter.getContext(), d0, d1, d2);
2117 auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
2118 Location loc = reshapeOp->getLoc();
2119 for (auto [reInd, l, h] :
2120 llvm::zip_equal(reassociations, collapsedPadding.lowPad,
2121 collapsedPadding.highPad)) {
2122 if (reInd.size() == 1) {
2123 expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
2124 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
2125 }
2126 }
2127
2128 RankedTensorType collapsedPaddedType =
2129 padOp.getType().clone(collapsedPadding.paddedShape);
2130 auto newPadOp = tensor::PadOp::create(
2131 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
2132 collapsedPadding.lowPad, collapsedPadding.highPad,
2133 padOp.getConstantPaddingValue(), padOp.getNofold());
2134
2135 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
2136 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2137 expandedPaddedSizes);
2138
2139 return success();
2140 }
2141
2142private:
2143 ControlFusionFn controlFoldingReshapes;
2144};
2145
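/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
/// by collapsing the source of the pad first and padding the collapsed
/// result, subject to the same constraints as `computeCollapsedPadding`.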
2146class FoldReshapeWithProducerPadOpByCollapsing
2147 : public OpRewritePattern<tensor::CollapseShapeOp> {
2148public:
2149 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2150 ControlFusionFn foldReshapes,
2151 PatternBenefit benefit = 1)
2152 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2153 controlFoldingReshapes(std::move(foldReshapes)) {}
2154
2155 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2156 PatternRewriter &rewriter) const override {
2157 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2158 if (!padOp)
2159 return failure();
2160
2161 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2162 return rewriter.notifyMatchFailure(padOp,
2163 "fusion blocked by control function");
2164 }
2165
2166 SmallVector<ReassociationIndices> reassociations =
2167 reshapeOp.getReassociationIndices();
2168 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2169 FailureOr<PadDimInfo> maybeCollapsedPadding =
2170 computeCollapsedPadding(padOp, reassociations, rewriter);
2171 if (failed(maybeCollapsedPadding))
2172 return failure();
2173 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2174
2175 Location loc = reshapeOp->getLoc();
2176 auto newCollapseOp = tensor::CollapseShapeOp::create(
2177 rewriter, loc, padOp.getSource(), reassociations);
2178
2179 auto newPadOp = tensor::PadOp::create(
2180 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2181 collapsedPadding.lowPad, collapsedPadding.highPad,
2182 padOp.getConstantPaddingValue(), padOp.getNofold());
2183
2184 rewriter.replaceOp(reshapeOp, newPadOp.getResult());
2185 return success();
2186 }
2187
2188private:
2189 ControlFusionFn controlFoldingReshapes;
2190};
2191
2192/// Pattern to collapse dimensions.
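/// The control function returns the dimension groups to fold for a given op;
/// e.g. a hypothetical control that always folds loops (0, 1) of ops with at
/// least two loops (sketch):
///
/// ```c++
/// GetCollapsableDimensionsFn control = [](linalg::LinalgOp op) {
///   SmallVector<ReassociationIndices> dims;
///   if (op.getNumLoops() >= 2)
///     dims.push_back(ReassociationIndices{0, 1});
///   return dims;
/// };
/// ```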
2193template <typename LinalgType>
2194class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2195public:
2196 CollapseLinalgDimensions(MLIRContext *context,
2197 GetCollapsableDimensionsFn collapseDimensions,
2198 PatternBenefit benefit = 1)
2199 : OpRewritePattern<LinalgType>(context, benefit),
2200 controlCollapseDimension(std::move(collapseDimensions)) {}
2201
2202 LogicalResult matchAndRewrite(LinalgType op,
2203 PatternRewriter &rewriter) const override {
2204 SmallVector<ReassociationIndices> collapsableIterationDims =
2205 controlCollapseDimension(op);
2206 if (collapsableIterationDims.empty())
2207 return failure();
2208
2209 // Check if the specified list of dimensions to collapse is a valid list.
2210 if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2211 collapsableIterationDims)) {
2212 return rewriter.notifyMatchFailure(
2213 op, "specified dimensions cannot be collapsed");
2214 }
2215
2216 std::optional<CollapseResult> collapseResult =
2217 collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2218 if (!collapseResult) {
2219 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2220 }
2221 rewriter.replaceOp(op, collapseResult->results);
2222 return success();
2223 }
2224
2225private:
2226 GetCollapsableDimensionsFn controlCollapseDimension;
2227};
2228
2229} // namespace
2230
2231//===---------------------------------------------------------------------===//
2232// Methods and patterns that fuse constants with linalg.generic operations.
2233//===---------------------------------------------------------------------===//
2234
2235namespace {
2236/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2237/// handle cases where the constant is not single-valued.
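/// For example (sketch), an input fed by
/// `arith.constant dense<2.0> : tensor<4xf32>` is dropped from the operand
/// list, and the corresponding block argument is replaced inside the region
/// by the scalar `arith.constant 2.0 : f32`.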
2238class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2239public:
2240 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2241 : OpRewritePattern<GenericOp>(context, benefit) {}
2242
2243 LogicalResult matchAndRewrite(GenericOp genericOp,
2244 PatternRewriter &rewriter) const override {
2245 if (!genericOp.hasPureTensorSemantics())
2246 return failure();
2247 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2248 Operation *def = opOperand->get().getDefiningOp();
2249 TypedAttr constantAttr;
2250 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2251 {
2252 DenseElementsAttr splatAttr;
2253 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2254 splatAttr.isSplat() &&
2255 splatAttr.getType().getElementType().isIntOrFloat()) {
2256 constantAttr = splatAttr.getSplatValue<TypedAttr>();
2257 return true;
2258 }
2259 }
2260 {
2261 IntegerAttr intAttr;
2262 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2263 constantAttr = intAttr;
2264 return true;
2265 }
2266 }
2267 {
2268 FloatAttr floatAttr;
2269 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2270 constantAttr = floatAttr;
2271 return true;
2272 }
2273 }
2274 return false;
2275 };
2276
2277 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2278 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2279 continue;
2280
2281 // The operands and the indexing_maps of the fused operation are the same
2282 // as the operands and indexing_maps of the generic operation with the
2283 // values at the constant index dropped.
2284 SmallVector<AffineMap> fusedIndexMaps;
2285 SmallVector<Value> fusedOperands;
2286 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2287 fusedIndexMaps.reserve(genericOp->getNumOperands());
2288 fusedOperands.reserve(genericOp.getNumDpsInputs());
2289 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2290 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2291 if (inputOperand == opOperand)
2292 continue;
2293 Value inputValue = inputOperand->get();
2294 fusedIndexMaps.push_back(
2295 genericOp.getMatchingIndexingMap(inputOperand));
2296 fusedOperands.push_back(inputValue);
2297 fusedLocs.push_back(inputValue.getLoc());
2298 }
2299 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2300 fusedIndexMaps.push_back(
2301 genericOp.getMatchingIndexingMap(&outputOperand));
2302
2303 // Check if the operation shapes to loops map is computable.
2304 if (!inversePermutation(
2305 concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2306 return rewriter.notifyMatchFailure(
2307 genericOp, "fused op loop bound computation failed");
2308 }
2309
2310 // Create a constant scalar value from the splat constant.
2311 Value scalarConstant =
2312 arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
2313
2314 SmallVector<Value> outputOperands = genericOp.getOutputs();
2315 auto fusedOp =
2316 GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
2317 genericOp->getResultTypes(),
2318 /*inputs=*/fusedOperands,
2319 /*outputs=*/outputOperands,
2320 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2321 genericOp.getIteratorTypes(),
2322 /*doc=*/nullptr,
2323 /*library_call=*/nullptr);
2324
2325 // Map the block argument corresponding to the replaced argument with the
2326 // scalar constant.
2327 Region &region = genericOp->getRegion(0);
2328 Block &entryBlock = *region.begin();
2329 IRMapping mapping;
2330 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2331 scalarConstant);
2332 Region &fusedRegion = fusedOp->getRegion(0);
2333 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2334 mapping);
2335 rewriter.replaceOp(genericOp, fusedOp->getResults());
2336 return success();
2337 }
2338 return failure();
2339 }
2340};
2341
2342} // namespace
2343
2344//===---------------------------------------------------------------------===//
2345// Miscellaneous patterns that help fusion.
2346//===---------------------------------------------------------------------===//
2347
2348namespace {
2349/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2350/// value of the `outs` operand is not used within the op. This is only
2351/// implemented for `linalg.generic` operations for now, but should hold for all
2352/// linalg structured ops.
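/// For example (sketch), `outs(%arg0 : tensor<?x?xf32>)` whose value is never
/// read by the payload is rewritten to `outs(%empty : tensor<?x?xf32>)`,
/// where `%empty = tensor.empty(...)` has the same mixed sizes as `%arg0`.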
2353struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2354 using OpRewritePattern<GenericOp>::OpRewritePattern;
2355
2356 LogicalResult matchAndRewrite(GenericOp op,
2357 PatternRewriter &rewriter) const override {
2358 rewriter.startOpModification(op);
2359 bool modifiedOutput = false;
2360 Location loc = op.getLoc();
2361 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2362 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2363 Value operandVal = opOperand.get();
2364 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2365 if (!operandType)
2366 continue;
2367
2368 // If outs is sparse, leave it to the sparsifier.
2369 if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2370 continue;
2371
2372 // If outs is already an `empty` operation, nothing to do.
2373 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2374 if (definingOp)
2375 continue;
2376 modifiedOutput = true;
2377 SmallVector<OpFoldResult> mixedSizes =
2378 tensor::getMixedSizes(rewriter, loc, operandVal);
2379 Value emptyTensor = tensor::EmptyOp::create(
2380 rewriter, loc, mixedSizes, operandType.getElementType());
2381 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2382 }
2383 }
2384 if (!modifiedOutput) {
2385 rewriter.cancelOpModification(op);
2386 return failure();
2387 }
2388 rewriter.finalizeOpModification(op);
2389 return success();
2390 }
2391};
2392
2393/// Fold linalg.fill into linalg.generic
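/// For example (sketch), if an input operand is produced by
/// `linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>)` and its value
/// is read by the payload, the matching block argument is replaced by %cst
/// (converted to the element type if necessary), leaving the fill operand
/// dead.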
2394struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2395 using OpRewritePattern<GenericOp>::OpRewritePattern;
2396
2397 LogicalResult matchAndRewrite(GenericOp genericOp,
2398 PatternRewriter &rewriter) const override {
2399 if (!genericOp.hasPureTensorSemantics())
2400 return failure();
2401 bool fillFound = false;
2402 Block &payload = genericOp.getRegion().front();
2403 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2404 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2405 continue;
2406 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2407 if (!fillOp)
2408 continue;
2409 fillFound = true;
2410 Value fillVal = fillOp.value();
2411 auto resultType =
2412 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2413 Value convertedVal =
2414 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2415 /*isUnsignedCast =*/false);
2416 rewriter.replaceAllUsesWith(
2417 payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2418 }
2419 return success(fillFound);
2420 }
2421};
2422} // namespace
2423
2424void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2425 RewritePatternSet &patterns,
2426 const ControlFusionFn &controlFoldingReshapes) {
2427 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2428 controlFoldingReshapes);
2429 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2430 controlFoldingReshapes);
2431 patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
2432 controlFoldingReshapes);
2433 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2434 controlFoldingReshapes);
2435}
2436
2437void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2438 RewritePatternSet &patterns,
2439 const ControlFusionFn &controlFoldingReshapes) {
2440 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2441 controlFoldingReshapes);
2442 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2443 patterns.getContext(), controlFoldingReshapes);
2444 patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2445 patterns.getContext(), controlFoldingReshapes);
2446 patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2447 controlFoldingReshapes);
2448}
2449
2450void mlir::linalg::populateElementwiseOpsFusionPatterns(
2451 RewritePatternSet &patterns,
2452 const ControlFusionFn &controlElementwiseOpsFusion) {
2453 auto *context = patterns.getContext();
2454 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2455 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2456 RemoveOutsDependency>(context);
2457 // Add the patterns that clean up dead operands and results.
2458 populateEraseUnusedOperandsAndResultsPatterns(patterns);
2459}
2460
2461void mlir::linalg::populateCollapseDimensions(
2462 RewritePatternSet &patterns,
2463 const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2464 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2465 CollapseLinalgDimensions<linalg::CopyOp>>(
2466 patterns.getContext(), controlCollapseDimensions);
2467}
2468
2469//===---------------------------------------------------------------------===//
2470// Passes
2471//===---------------------------------------------------------------------===//
2472
2473namespace {
2474
2475/// Pass that fuses generic ops on tensors. Used only for testing.
2476// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2477// patterns added here heavily depends on the cost function used. Having an
2478// opinionated pass of this form is not recommended. Deprecate this pass in
2479// favor of test passes that check the functionality of each of the patterns
2480// added here individually.
2481struct LinalgElementwiseOpFusionPass
2482 : public impl::LinalgElementwiseOpFusionPassBase<
2483 LinalgElementwiseOpFusionPass> {
2484 using impl::LinalgElementwiseOpFusionPassBase<
2485 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2486 void runOnOperation() override {
2487 Operation *op = getOperation();
2488 MLIRContext *context = op->getContext();
2489 RewritePatternSet patterns(context);
2490
2491 // Add folding with reshape by expansion patterns.
2492 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2493 Operation *producer = fusedOperand->get().getDefiningOp();
2494 return producer && producer->hasOneUse();
2495 };
2496
2497 // Add elementwise op fusion patterns.
2498 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2499 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2500 tensor::populateBubbleUpExpandShapePatterns(patterns);
2501
2502 // General canonicalization patterns.
2503 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2504 GenericOp::getCanonicalizationPatterns(patterns, context);
2505 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2506 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2507 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2508 patterns);
2509
2510 // Add constant folding patterns.
2511 populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2512
2513 // Use TopDownTraversal for compile time reasons.
2514 (void)applyPatternsGreedily(op, std::move(patterns),
2515 GreedyRewriteConfig().setUseTopDownTraversal());
2516 }
2517};
2518
2519} // namespace
return success()
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
ArrayRef< ReassociationIndices > getCollapsedOpToOrigOpMapping() const
Return mapping from collapsed loop domain to original loop domain.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
lhs
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:593
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.
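A minimal sketch, assuming `sizeA` and `sizeB` are hypothetical OpFoldResults whose product is needed as a (possibly constant-folded) value:

  // Build the expression s0 * s1 and let the helper compose and fold it.
  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  AffineMap mulMap = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, s0 * s1);
  OpFoldResult product =
      affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap, {sizeA, sizeB});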
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges of the maps.
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:232
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions is contiguous in the range of the specified indexing map.
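A short sketch, assuming `map` is a hypothetical indexing map of the op whose loops are being collapsed, inside a rewrite returning LogicalResult:

  // Collapsing d1 and d2 together is only legal if they stay contiguous
  // in the range of `map`.
  ReassociationIndices candidate = {1, 2};
  if (!linalg::isDimSequencePreserved(map, candidate))
    return failure();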
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by collapsing the dimensions of the generic op.
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
This transformation is intended to be used with a top-down traversal (from producer to consumer).
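A hedged usage sketch; `consumerOp` and `operandIdx` are hypothetical and stand for a linalg.generic and the operand through which its producer is fused:

  OpOperand *fusedOperand = &consumerOp->getOpOperand(operandIdx);
  if (linalg::areElementwiseOpsFusable(fusedOperand)) {
    FailureOr<linalg::ElementwiseOpFusionResult> fused =
        linalg::fuseElementwiseOps(rewriter, fusedOperand);
    if (succeeded(fused)) {
      // `replacements` maps old result values to results of the fused op.
      for (auto [oldVal, newVal] : fused->replacements)
        rewriter.replaceAllUsesWith(oldVal, newVal);
    }
  }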
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:236
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
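A hedged sketch of a collapse-control callback; always proposing to collapse loops 0 and 1 is a blanket policy chosen only for this example:

  linalg::GetCollapsableDimensionsFn collapseCtrl =
      [](linalg::LinalgOp op) -> SmallVector<ReassociationIndices> {
    // Propose folding iteration dims 0 and 1 into one (example policy).
    return {ReassociationIndices{0, 1}};
  };
  linalg::populateCollapseDimensions(patterns, collapseCtrl);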
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand can be fused using elementwise op fusion.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) generic operation by expanding the dimensionality of the loop in the generic op.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
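A minimal sketch of calling the transform directly, assuming `linalgOp` and `rewriter` are in scope in a rewrite returning LogicalResult:

  // Fold iteration dimensions 0 and 1 of `linalgOp` into a single loop.
  SmallVector<ReassociationIndices> foldedDims = {ReassociationIndices{0, 1}};
  FailureOr<linalg::CollapseResult> collapsed =
      linalg::collapseOpIterationDims(linalgOp, foldedDims, rewriter);
  if (failed(collapsed))
    return failure();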
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
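A hedged sketch of wiring these patterns into a pass; it assumes it runs inside a pass body where getContext(), getOperation(), and signalPassFailure() are available:

  RewritePatternSet patterns(&getContext());
  linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
    // Fuse only when the producer result has a single use; this conservative
    // policy is an assumption made for the example.
    return fusedOperand->get().hasOneUse();
  };
  linalg::populateElementwiseOpsFusionPatterns(patterns, control);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();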
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
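A short sketch, assuming `source` is a hypothetical tensor-typed Value:

  // Query the mixed (static or dynamic) sizes of `source`, then split them
  // into the static-shape array and the dynamic-size values.
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(rewriter, loc, source);
  auto [staticSizes, dynamicSizes] = decomposeMixedValues(sizes);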
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
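A minimal sketch, assuming `v` is a hypothetical Value that may be defined by an integer constant:

  // Bind the constant (splat) integer value of `v`, if any, and test for zero.
  APInt cst;
  if (matchPattern(v, m_ConstantInt(&cst)) && cst.isZero()) {
    // `v` is statically known to be zero.
  }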
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:239
ArrayRef< int64_t > ReassociationIndicesRef
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
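For example (a sketch; `ctx` is a hypothetical MLIRContext *):

  // (d0, d1, d2) -> (d2, d0, d1) inverts to (d0, d1, d2) -> (d1, d2, d0),
  // so composing the two maps yields the identity.
  AffineMap perm = AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
  AffineMap inverse = inversePermutation(perm);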
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values (and their transitive dependencies) before insertionPoint.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
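A one-line example: inverting the permutation [2, 0, 1] yields [1, 2, 0].

  SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});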
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedOperand.
Definition Transforms.h:656
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.