//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

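// For illustration of the composition above: suppose the producer writes its
// result with producerResultIndexMap = (d0, d1) -> (d1, d0) (a transpose) and
// reads an argument with argMap = (d0, d1) -> (d0, d1), and the consumer reads
// the producer result with fusedConsumerArgIndexMap = (d0, d1) -> (d0, d1).
// Then:
//   invProducerResultIndexMap = (d0, d1) -> (d1, d0)
//   t1 = argMap.compose(inv)  = (d0, d1) -> (d1, d0)
//   t1.compose(consumer map)  = (d0, d1) -> (d1, d0)
// i.e. in the fused op the producer's argument is accessed transposed
// relative to the consumer loops, as expected.
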
// Checks if the given operand can be dropped, and if the remaining operands
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // the bounds of the op can still be computed after dropping the selected
  // operand. inversePermutation returns an empty AffineMap in case the
  // concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

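// Put differently: a producer result must be kept alive if its init is read by
// the producer's payload (e.g. a reduction-style update), if dropping its init
// would lose the size information needed to compute loop bounds, or if the
// result has a user other than the consumer being fused.
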
/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if ((consumer.getNumReductionLoops())) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

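// For illustration, with #id = affine_map<(d0, d1) -> (d0, d1)>, the pair
// below satisfies all of the conditions above (a minimal sketch, payloads
// elided):
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//       outs(%init1 : tensor<?x?xf32>) {...}
//
// The producer %0 is all-parallel, feeds an input operand of %1, and its
// result indexing map is a permutation.
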
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of `generateFusedTensorOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      // Find the producer of the operand.
      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the original values with the fused op's results.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

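// A transformation can drive the pattern above with the greedy rewriter, e.g.
// (a minimal sketch; `populateElementwiseOpsFusionPatterns` is the public
// entry point that registers FuseElementwiseOps, and the lambda is a control
// function that accepts every candidate):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsGreedily(op, std::move(patterns));
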
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map helps compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the modified op now iterates over a 6D domain, reshapes can be
///  introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

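// For illustration: take a fused operand with indexing map (d0, d1) ->
// (d1, d0), reassociation maps [[0, 1], [2]] and expandedShape [s0, s1, s2].
// Tensor dims {0, 1} (extents s0, s1) fold into loop d1 and tensor dim {2}
// (extent s2) into loop d0, so the computation above yields
//   numExpandedDims   = [1, 2]
//   expandedShapeMap  = [[s2], [s0, s1]]
//   reassociation     = [[0], [1, 2]]    (d0 -> e0; d1 -> e1, e2)
//   expandedOpNumDims = 3
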
/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

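// For illustration, reusing the running example from the comment block above
// `isFusableWithReshapeByDimExpansion`: with d0 -> (e0, e1),
// d1 -> (e2, e3, e4) and d2 -> (e5), the original map
// (d0, d1, d2) -> (d1, d0, d2) becomes
// (e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5).
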
/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

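// For illustration, with the same expansion (d0 -> 2 dims, d1 -> 3 dims,
// d2 -> 1 dim), an operand with indexing map (d0, d1, d2) -> (d1, d0, d2) gets
// the reassociation [[0, 1, 2], [3, 4], [5]]: its first tensor dimension
// expands into three dims, the second into two, and the last stays a single
// dim.
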
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    OpFoldResult newIndex =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}

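// For illustration: if an original dimension whose index value is `i` is
// expanded into dims (e0, e1, e2) with (possibly dynamic) extents
// (s0, s1, s2), the loop above reconstructs
//   i = (idx(e0) * s1 + idx(e1)) * s2 + idx(e2)
// by left-folding the affine expression `idx + acc * shape` over the trailing
// expanded dimensions.
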
// Create an expanded transpose op.
// The reassociation map is already permuted, hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation=[4, 5, 0, 1, 2, 3]

static TransposeOp createExpandedTransposeOp(PatternRewriter &rewriter,
                                             TransposeOp transposeOp,
                                             Value expandedInput, Value output,
                                             ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for
// the rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in output shape before the `linalgOp`
    // to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Return the values that replace the results of the original op.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Carries information about a padded dimension.
struct PadDimInfo {
  // The resulting shape after padding each dimension.
  SmallVector<int64_t> paddedShape;

  // Low and high padding amounts for each dimension.
  SmallVector<OpFoldResult> lowPad;
  SmallVector<OpFoldResult> highPad;
};

/// Computes the expanded padding information for the given pad operation based
/// on the provided expanded shape and reassociation indices. Returns a
/// PadDimInfo containing the low and high padding amounts and the padded
/// size for each dimension, or failure if the expansion is not possible.
static FailureOr<PadDimInfo>
computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
                       ArrayRef<ReassociationIndices> reassociations,
                       PatternRewriter &rewriter) {
  // If the padding value depends on the index values of the pad operation,
  // then it may not be valid to expand the dimensions, since it will change
  // the index values on which the padding value depends. This is not currently
  // supported by the pad expansion patterns, but it could be implemented
  // similarly to the expansion of linalg.generic ops with linalg.index ops in
  // the body, as is done in `updateExpandedGenericOpRegion`.
  if (!padOp.getConstantPaddingValue())
    return failure();

  // Expanded dimensions cannot have padding because the resulting padding may
  // not be representable by a tensor.pad op. There are some special cases where
  // it is possible (like expanding unit dims), but supporting these cases is
  // NYI, so disallow it for now.
  ArrayRef<int64_t> low = padOp.getStaticLow();
  ArrayRef<int64_t> high = padOp.getStaticHigh();
  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
    if (reInd.size() != 1 && (l != 0 || h != 0))
      return failure();
  }

  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
  PadDimInfo padDimInfo;
  padDimInfo.paddedShape.assign(expandedShape);
  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
    if (reInd.size() == 1) {
      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
    }
  }

  return padDimInfo;
}

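// For illustration: take a tensor.pad of tensor<10x20xf32> with low = [0, 2]
// and high = [0, 3] (result type tensor<10x25xf32>), reassociations
// [[0, 1], [2]] and expandedShape = [2, 5, 20]. The collapsed group [0, 1]
// carries no padding, so only the unit group is rewritten, giving
//   paddedShape = [2, 5, 25], lowPad = [0, 0, 2], highPad = [0, 0, 3].
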
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    RankedTensorType expandedType = reshapeOp.getSrcType();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
        padOp, expandedType.getShape(), reassociations, rewriter);
    if (failed(maybeExpandedPadding))
      return failure();
    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType =
        padOp.getResultType().clone(expandedPadding.paddedShape);

    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
        expandedPadding.lowPad, expandedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldReshapeWithProducerPadOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {
public:
  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(expandOp,
                                         "fusion blocked by control function");
    }

    RankedTensorType expandedType = expandOp.getResultType();
    SmallVector<ReassociationIndices> reassociations =
        expandOp.getReassociationIndices();
    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
        padOp, expandedType.getShape(), reassociations, rewriter);
    if (failed(maybeExpandedPadding))
      return failure();
    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
    rewriter.setInsertionPointAfterValue(padOp.getSource());
    SmallVector<OpFoldResult> padSrcSizes =
        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      // We know that any reassociation with multiple dims is not padded because
      // of the requirements of computeExpandedPadding.
      if (reInd.size() == 1) {
        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
      }
    }
    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
    auto newExpandOp = tensor::ExpandShapeOp::create(
        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
        newExpandedSizes);
    RankedTensorType expandedPaddedType =
        padOp.getResultType().clone(expandedPadding.paddedShape);
    rewriter.setInsertionPoint(expandOp);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
        expandedPadding.lowPad, expandedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOp(expandOp, newPadOp.getResult());

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Ensures that all
/// the elements of the returned reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

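// For illustration: for indexingMap = (d0, d1, d2) -> (d2, d0, d1) and
// rangeReassociation = [1, 2] (the last two results), the returned domain
// reassociation is [0, 1], i.e. the loops d0 and d1.
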
/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

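// For illustration: for indexingMap = (d0, d1, d2) -> (d2, d0, d1), the
// sequence [0, 1] is preserved (d0 is immediately followed by d1 in the
// results), while [1, 2] is not: d2 appears in the results before the
// sequence start d1, so the walk above returns false.
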
static bool
areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                         ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

1356// Return the list of dimensions of the iteration domain that can be
1357// collapsed to allow for fusion with the a producer that is an expand_shape
1358// operation. If all dimensions created by expansion can be collapsed in the
1359// iteration space then the reshape is defunct.
1360//
1361// Example:
1362//
1363// ```mlir
1364// #map = affine_map<(d0, d1) -> (d0, d1)>
1365// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1366// %2 = tensor.empty [..] : tensor<?x4xf32>
1367// %3 = linalg.generic {
1368// indexing_maps = [#map, #map],
1369// iterator_types = ["parallel" ,"parallel"]}
1370// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1371// ```
1372//
1373// can be fused by collapsing the dimensions of the iteration space.
1374//
1375// ```mlir
1376// #map = affine_map<(d0) -> (d0)>
1377// %2 = tensor.empty [..] : tensor<?xf32>
1378// %3 = linalg.generic {
1379// indexing_maps = [#map, #map],
1380// iterator_types = ["parallel"]}
1381// ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1382// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1383// ```
1384//
1385// In the following example,
1386//
1387// ```mlir
1388// #map0 = affine_map<(d0, d1) -> (d0, d1)>
1389// #map1 = affine_map<(d0, d1) -> (d1, d0)>
1390// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1391// %2 = tensor.empty [..] : tensor<4x?xf32>
1392// %2 = linalg.generic {
1393// indexing_maps = [#map0, #map1],
1394// iterator_types = ["parallel" ,"parallel"]}
1395// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1396// ```
1397//
1398// the reshape cannot be fused with the generic op by collapsing the op
1399// dimensions, since the indexing maps would have to contain mods and divs
1400// to preserve the access pattern. When no dimensions of the iteration
1401// space are collapsible, an empty vector is returned.
1402static SmallVector<ReassociationIndices>
1403getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1404 ArrayRef<ReassociationIndices> reassociation) {
1405 // Some basic checks for this fusion to be valid.
1406 if (!genericOp.hasPureTensorSemantics())
1407 return {};
1408
1409 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1410 return map.isProjectedPermutation();
1411 })) {
1412 return {};
1413 }
1414
1415 // Compute all the loops with the reduction iterator types.
1416 SmallVector<unsigned> reductionDims;
1417 genericOp.getReductionDims(reductionDims);
1418
1419 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1420 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1421 auto iteratorTypes = genericOp.getIteratorTypesArray();
1422 SmallVector<ReassociationIndices> iterationSpaceReassociation;
1423 for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1424 assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1425
1426 // Ignore dims that are not folded.
1427 if (foldedRangeDims.size() == 1)
1428 continue;
1429
1430 ReassociationIndices foldedIterationSpaceDims =
1431 getDomainReassociation(indexingMap, foldedRangeDims);
1432
1433 // Check that the folded iteration dims do not contain already processed
1434 // dims.
1435 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1436 return processedIterationDims.count(dim);
1437 }))
1438 continue;
1439
1440    // Check that the folded iterator types are either all parallel or all reduction.
1441 utils::IteratorType startIteratorType =
1442 iteratorTypes[foldedIterationSpaceDims[0]];
1443 if (!isParallelIterator(startIteratorType) &&
1444 !isReductionIterator(startIteratorType))
1445 continue;
1446 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1447 return iteratorTypes[dim] != startIteratorType;
1448 }))
1449 continue;
1450
1451    // If the folded dimensions correspond to a "reduction" iterator type,
1452    // the folded dimensions need to be "in-order". Strictly speaking this is
1453    // not necessary for reductions that are associative and commutative, but
1454    // a stricter definition of reduction is used for now.
1455 if (isReductionIterator(startIteratorType)) {
1456 bool isContiguous = false;
1457 for (const auto &startDim : llvm::enumerate(reductionDims)) {
1458 // Move window in `reductionDims` to start of the folded iteration dims.
1459 if (startDim.value() != foldedIterationSpaceDims[0])
1460 continue;
1461        // If the sizes don't match, it is trivially not contiguous. This
1462        // condition should not be hit.
1463 if (startDim.index() + foldedIterationSpaceDims.size() >
1464 reductionDims.size())
1465 break;
1466 // Check that the contiguity is maintained.
1467 isContiguous = true;
1468 for (const auto &foldedDim :
1469 llvm::enumerate(foldedIterationSpaceDims)) {
1470 if (reductionDims[foldedDim.index() + startDim.index()] !=
1471 foldedDim.value()) {
1472 isContiguous = false;
1473 break;
1474 }
1475 }
1476 break;
1477 }
1478 if (!isContiguous)
1479 continue;
1480 }
1481
1482 // Check that the sequence is preserved in all indexing maps.
1483 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1484 [&](AffineMap indexingMap) {
1485 return !isDimSequencePreserved(indexingMap,
1486 foldedIterationSpaceDims);
1487 }))
1488 continue;
1489
1490 processedIterationDims.insert_range(foldedIterationSpaceDims);
1491 iterationSpaceReassociation.emplace_back(
1492 std::move(foldedIterationSpaceDims));
1493 }
1494
1495 return iterationSpaceReassociation;
1496}
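
// Illustrative example (hypothetical values): for a generic op with
// iterator_types = ["parallel", "parallel", "reduction"] whose operands all
// use the identity map affine_map<(d0, d1, d2) -> (d0, d1, d2)>, a reshape
// reassociation of [[0, 1], [2]] on the fusable operand yields the
// iteration-space reassociation [[0, 1]]: the folded range dims [0, 1] map
// back to the parallel loops {0, 1}, while the singleton group [2] is
// skipped as not folded.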
1497
1498/// Helper class to carry state while collapsing the `linalg.generic` op.
1499namespace {
1500class CollapsingInfo {
1501public:
1502 LogicalResult initialize(unsigned origNumLoops,
1503 ArrayRef<ReassociationIndices> foldedIterationDims) {
1504 llvm::SmallDenseSet<int64_t, 4> processedDims;
1505 // Find all the dims that are folded.
1506 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1507 if (foldedIterationDim.empty())
1508 continue;
1509      // If the folded dims contain dims that were already folded, that is an
1510      // illegal specification. Repetition within a list is also illegal.
1511 for (auto dim : foldedIterationDim) {
1512 if (dim >= origNumLoops)
1513 return failure();
1514 if (processedDims.count(dim))
1515 return failure();
1516 processedDims.insert(dim);
1517 }
1518 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1519 foldedIterationDim.end());
1520 }
1521 if (processedDims.size() > origNumLoops)
1522 return failure();
1523
1524 // Add all the preserved dims of the original op as single
1525 // elements to `collapsedOpToOrigOpIterationDim`.
1526 for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1527 if (processedDims.count(dim))
1528 continue;
1529 collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1530 }
1531
1532 llvm::sort(collapsedOpToOrigOpIterationDim,
1533               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1534                 return lhs[0] < rhs[0];
1535 });
1536 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1537 for (const auto &foldedDims :
1538 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1539 for (const auto &dim : enumerate(foldedDims.value()))
1540 origOpToCollapsedOpIterationDim[dim.value()] =
1541 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1542 }
1543 return success();
1544 }
1545
1546 /// Return mapping from collapsed loop domain to original loop domain.
1547 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1548 return collapsedOpToOrigOpIterationDim;
1549 }
1550
1551 /// Return mapping from original loop domain to collapsed loop domain. The
1552 /// mapping is a pair. First value is the dimension in the collapsed loop that
1553 /// the original loop is mapped to. Second is the relative position in folded
1554  /// list of this domain. For example, if the original loop domain is 5D and
1555  /// the collapsed loop domain folds it into two groups, i.e.
1556 ///
1557 /// ```
1558  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1559 /// ```
1560 ///
1561 /// then
1562 ///
1563 /// ```
1564 /// origOpToCollapsedOpMapping[0] = {0, 0};
1565 /// origOpToCollapsedOpMapping[1] = {0, 1};
1566 /// origOpToCollapsedOpMapping[2] = {0, 2};
1567 /// origOpToCollapsedOpMapping[3] = {1, 0};
1568 /// origOpToCollapsedOpMapping[4] = {1, 1};
1569 /// ```
1570 ///
1571 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1572 return origOpToCollapsedOpIterationDim;
1573 }
1574
1575 /// Return the collapsed op iteration domain rank.
1576 unsigned getCollapsedOpIterationRank() const {
1577 return collapsedOpToOrigOpIterationDim.size();
1578 }
1579
1580private:
1581 /// Map from the iteration domain index in collapsed op to the iteration
1582 /// domain indices in the original op.
1583 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1584
1585 /// Map from iteration domain index in the original op to the iteration domain
1586 /// index in the collapsed op.
1587 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1588};
1589} // namespace
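
// A minimal sketch of how CollapsingInfo is driven (hypothetical, mirroring
// the flow in collapseOpIterationDims below): fold loops {0, 1} of a 3-loop
// op into one loop.
//
//   SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
//   CollapsingInfo info;
//   if (failed(info.initialize(/*origNumLoops=*/3, foldedDims)))
//     return failure();
//   // info.getCollapsedOpToOrigOpMapping() == [[0, 1], [2]]
//   // info.getOrigOpToCollapsedOpMapping() == [{0, 0}, {0, 1}, {1, 0}]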
1590
1591/// Get the iterator types for the collapsed operation given the original
1592/// iterator types and collapsed dimensions.
1593static SmallVector<utils::IteratorType>
1594getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1595 const CollapsingInfo &collapsingInfo) {
1596 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1597 for (ReassociationIndicesRef foldedIterDims :
1598 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1599 assert(!foldedIterDims.empty() &&
1600 "reassociation indices expected to have non-empty sets");
1601    // Just pick the iterator type of the first folded dim. Pre-condition
1602    // checks are expected to have verified that the iterator types of all
1603    // folded dimensions are the same.
1604 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1605 }
1606 return collapsedIteratorTypes;
1607}
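
// Illustrative example (hypothetical values): with original iterator types
// ["parallel", "parallel", "reduction"] and a collapsed-to-original mapping
// of [[0, 1], [2]], the collapsed op gets ["parallel", "reduction"].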
1608
1609/// Compute the indexing map in the collapsed op that corresponds to the given
1610/// `indexingMap` of the original operation.
1611static AffineMap
1612getCollapsedOpIndexingMap(AffineMap indexingMap,
1613 const CollapsingInfo &collapsingInfo) {
1614 MLIRContext *context = indexingMap.getContext();
1615 assert(indexingMap.isProjectedPermutation() &&
1616 "expected indexing map to be projected permutation");
1617 SmallVector<AffineExpr> resultExprs;
1618 auto origOpToCollapsedOpMapping =
1619 collapsingInfo.getOrigOpToCollapsedOpMapping();
1620 for (auto expr : indexingMap.getResults()) {
1621 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1622 // If the dim is not the first of the collapsed dim, do nothing.
1623 if (origOpToCollapsedOpMapping[dim].second != 0)
1624 continue;
1625 // The next n-dims are guaranteed to be collapsed. So just use the
1626 // iteration dimension of the collapsed op.
1627 resultExprs.push_back(
1628 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1629 }
1630 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1631 resultExprs, context);
1632}
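
// Illustrative example (hypothetical values): with
// collapsedOpToOrigOpMapping = [[0, 1], [2]],
// affine_map<(d0, d1, d2) -> (d0, d1, d2)> collapses to
// affine_map<(d0, d1) -> (d0, d1)>, and affine_map<(d0, d1, d2) -> (d2)>
// collapses to affine_map<(d0, d1) -> (d1)>.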
1633
1634/// Return the `reassociation` indices to use to collapse the operand when the
1635/// iteration space of a generic op is collapsed.
1636static SmallVector<ReassociationIndices>
1637getOperandReassociation(AffineMap indexingMap,
1638 const CollapsingInfo &collapsingInfo) {
1639 unsigned counter = 0;
1640 SmallVector<ReassociationIndices> operandReassociation;
1641 auto origOpToCollapsedOpMapping =
1642 collapsingInfo.getOrigOpToCollapsedOpMapping();
1643 auto collapsedOpToOrigOpMapping =
1644 collapsingInfo.getCollapsedOpToOrigOpMapping();
1645 while (counter < indexingMap.getNumResults()) {
1646 unsigned dim =
1647 cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1648    // This is the start of a group of collapsed iteration dimensions that
1649    // is guaranteed to be preserved in the indexing map. The number of folded
1650    // dims is obtained from the collapsed op to original op mapping.
1651 unsigned numFoldedDims =
1652 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1653 .size();
1654 if (origOpToCollapsedOpMapping[dim].second == 0) {
1655 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1656 operandReassociation.emplace_back(range.begin(), range.end());
1657 }
1658 counter += numFoldedDims;
1659 }
1660 return operandReassociation;
1661}
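
// Illustrative example (hypothetical values): with
// collapsedOpToOrigOpMapping = [[0, 1], [2]], an operand indexed by
// affine_map<(d0, d1, d2) -> (d0, d1, d2)> gets the reassociation
// [[0, 1], [2]], while an operand indexed by
// affine_map<(d0, d1, d2) -> (d2)> gets [[0]].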
1662
1663/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1664static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1665 OpOperand *opOperand,
1666 const CollapsingInfo &collapsingInfo,
1667 OpBuilder &builder) {
1668 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1669 SmallVector<ReassociationIndices> operandReassociation =
1670 getOperandReassociation(indexingMap, collapsingInfo);
1671
1672  // If the number of entries in the reassociation for the operand is the
1673  // same as the number of results of the indexing map, then there is nothing
1674  // to do for this operand.
1675 Value operand = opOperand->get();
1676 if (operandReassociation.size() == indexingMap.getNumResults())
1677 return operand;
1678
1679 // Insert a reshape to collapse the dimensions.
1680 if (isa<MemRefType>(operand.getType())) {
1681 return memref::CollapseShapeOp::create(builder, loc, operand,
1682 operandReassociation)
1683 .getResult();
1684 }
1685 return tensor::CollapseShapeOp::create(builder, loc, operand,
1686 operandReassociation)
1687 .getResult();
1688}
1689
1690/// Rewrite the `linalg.index` operations in the original generic op to their
1691/// values in the collapsed operation.
1692static void generateCollapsedIndexingRegion(
1693 Location loc, Block *block, const CollapsingInfo &collapsingInfo,
1694 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1695 OpBuilder::InsertionGuard g(rewriter);
1696 rewriter.setInsertionPointToStart(block);
1697
1698 // Collect all the original index ops.
1699 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1700
1701 // For each folded dimension list resolve the original induction variable
1702 // values in terms of the folded dimension induction variable.
1703 // i_{folded} = (i_0 * d1 + i1) * d2 + i2.
1704 // can be inverted to
1705 // i2 = i_{folded} % d2
1706 // i1 = (i_{folded} / d2) % d1
1707 // i0 = i_{folded} / (d1 * d2)
1708 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1709 for (auto foldedDims :
1710 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1711 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1712 Value newIndexVal =
1713 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1714 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1715 Value loopDim =
1716 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
1717 indexReplacementVals[dim] =
1718 rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1719 newIndexVal =
1720 rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1721 }
1722 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1723 }
1724
1725 for (auto indexOp : indexOps) {
1726 auto dim = indexOp.getDim();
1727 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1728 }
1729}
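
// Illustrative expansion (hypothetical IR): if loops {0, 1} with loop bounds
// %d0 and %d1 are folded, then in the collapsed body the delinearization is
//
//   %folded = linalg.index 0 : index
//   %i1 = arith.remsi %folded, %d1 : index  // replaces linalg.index 1
//   %i0 = arith.divsi %folded, %d1 : index  // replaces linalg.index 0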
1730
1731static void collapseOperandsAndResults(LinalgOp op,
1732 const CollapsingInfo &collapsingInfo,
1733 RewriterBase &rewriter,
1734 SmallVectorImpl<Value> &inputOperands,
1735 SmallVectorImpl<Value> &outputOperands,
1736 SmallVectorImpl<Type> &resultTypes) {
1737 Location loc = op->getLoc();
1738 inputOperands =
1739 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1740 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1741 rewriter);
1742 });
1743
1744 // Get the output operands and result types.
1745 resultTypes.reserve(op.getNumDpsInits());
1746 outputOperands.reserve(op.getNumDpsInits());
1747 for (OpOperand &output : op.getDpsInitsMutable()) {
1748 Value newOutput =
1749 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1750 outputOperands.push_back(newOutput);
1751 // If the op has "buffer semantics", then the init operands are ranked
1752 // memrefs and the op has no results.
1753 if (!op.hasPureBufferSemantics())
1754 resultTypes.push_back(newOutput.getType());
1755 }
1756}
1757
1758/// Clone a `LinalgOp` to a collapsed version of the same name.
1759template <typename OpTy>
1760static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1761 const CollapsingInfo &collapsingInfo) {
1762 return nullptr;
1763}
1764
1765/// Collapse any `LinalgOp` that does not require any specialization such as
1766/// indexing_maps, iterator_types, etc.
1767template <>
1768LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1769 const CollapsingInfo &collapsingInfo) {
1770 SmallVector<Value> inputOperands, outputOperands;
1771 SmallVector<Type> resultTypes;
1772 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1773 outputOperands, resultTypes);
1774
1775 return clone(
1776 rewriter, origOp, resultTypes,
1777 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1778}
1779
1780/// Collapse a `GenericOp`
1781template <>
1782GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1783 GenericOp origOp,
1784 const CollapsingInfo &collapsingInfo) {
1785 SmallVector<Value> inputOperands, outputOperands;
1786 SmallVector<Type> resultTypes;
1787 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1788 outputOperands, resultTypes);
1789 SmallVector<AffineMap> indexingMaps(
1790 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1791 return getCollapsedOpIndexingMap(map, collapsingInfo);
1792 }));
1793
1794 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1795 origOp.getIteratorTypesArray(), collapsingInfo));
1796
1797 GenericOp collapsedOp = linalg::GenericOp::create(
1798 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1799 indexingMaps, iteratorTypes,
1800 [](OpBuilder &builder, Location loc, ValueRange args) {});
1801 Block *origOpBlock = &origOp->getRegion(0).front();
1802 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1803 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1804 collapsedOpBlock->getArguments());
1805 return collapsedOp;
1806}
1807
1808static LinalgOp createCollapsedOp(LinalgOp op,
1809 const CollapsingInfo &collapsingInfo,
1810 RewriterBase &rewriter) {
1811 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1812 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1813 }
1814 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1815}
1816
1817/// Implementation of fusion with reshape operation by collapsing dimensions.
1818FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1819 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1820 RewriterBase &rewriter) {
1821 // Bail on trivial no-op cases.
1822 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1823 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1824 return foldedDims.size() <= 1;
1825 }))
1826 return failure();
1827
1828 CollapsingInfo collapsingInfo;
1829 if (failed(
1830 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1831 return rewriter.notifyMatchFailure(
1832 op, "illegal to collapse specified dimensions");
1833 }
1834
1835 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1836 if (hasPureBufferSemantics &&
1837 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
1838 MemRefType memRefToCollapse =
1839 dyn_cast<MemRefType>(opOperand.get().getType());
1840 if (!memRefToCollapse)
1841 return true;
1842
1843 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1844 SmallVector<ReassociationIndices> operandReassociation =
1845 getOperandReassociation(indexingMap, collapsingInfo);
1846 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1847 memRefToCollapse, operandReassociation);
1848 }))
1849 return rewriter.notifyMatchFailure(op,
1850 "memref is not guaranteed collapsible");
1851
1852 // Bail on non-canonical ranges.
1853 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1854 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1855 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1856 return cast<IntegerAttr>(attr).getInt() == value;
1857 llvm::APInt actual;
1858 return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1859 actual.getSExtValue() == value;
1860 };
1861 if (!llvm::all_of(loopRanges, [&](Range range) {
1862 return opFoldIsConstantValue(range.offset, 0) &&
1863 opFoldIsConstantValue(range.stride, 1);
1864 })) {
1865 return rewriter.notifyMatchFailure(
1866 op, "expected all loop ranges to have zero start and unit stride");
1867 }
1868
1869 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1870
1871 Location loc = op->getLoc();
1872 SmallVector<OpFoldResult> loopBound =
1873 llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
1874
1875 if (collapsedOp.hasIndexSemantics()) {
1876 // Collect the loop range of the generic op.
1877 OpBuilder::InsertionGuard g(rewriter);
1878 rewriter.setInsertionPoint(collapsedOp);
1879 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1880 collapsingInfo, loopBound, rewriter);
1881 }
1882
1883 // Insert expanding reshape for the result to get back the original result
1884 // type.
1885 SmallVector<Value> results;
1886 for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1887 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1888 auto originalResultType =
1889 cast<ShapedType>(originalResult.value().getType());
1890 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1891 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1892 AffineMap indexingMap =
1893 op.getIndexingMapMatchingResult(originalResult.value());
1894 SmallVector<ReassociationIndices> reassociation =
1895 getOperandReassociation(indexingMap, collapsingInfo);
1896 assert(
1897 indexingMap.isProjectedPermutation() &&
1898 "Expected indexing map to be a projected permutation for collapsing");
1899 SmallVector<OpFoldResult> resultShape =
1900 applyPermutationMap(indexingMap, ArrayRef(loopBound));
1901 Value result;
1902 if (isa<MemRefType>(collapsedOpResult.getType())) {
1903 MemRefType expandShapeResultType = MemRefType::get(
1904 originalResultType.getShape(), originalResultType.getElementType());
1905 result = memref::ExpandShapeOp::create(
1906 rewriter, loc, expandShapeResultType, collapsedOpResult,
1907 reassociation, resultShape);
1908 } else {
1909 result = tensor::ExpandShapeOp::create(
1910 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1911 resultShape);
1912 }
1913 results.push_back(result);
1914 } else {
1915 results.push_back(collapsedOpResult);
1916 }
1917 }
1918 return CollapseResult{results, collapsedOp};
1919}
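
// Illustrative before/after (hypothetical IR): collapsing iteration dims
// [[0, 1]] of an elementwise generic op on tensor<4x8xf32> operands inserts
// tensor.collapse_shape ops on the operands, runs the op on tensor<32xf32>,
// and expands the result back:
//
//   %in_c = tensor.collapse_shape %in [[0, 1]]
//       : tensor<4x8xf32> into tensor<32xf32>
//   %out_c = tensor.collapse_shape %out [[0, 1]]
//       : tensor<4x8xf32> into tensor<32xf32>
//   %res_c = linalg.generic ... ins(%in_c ...) outs(%out_c ...) {...}
//   %res = tensor.expand_shape %res_c [[0, 1]] output_shape [4, 8]
//       : tensor<32xf32> into tensor<4x8xf32>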
1920
1921namespace {
1922
1923/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1924/// collapsing dimensions of the iteration space.
1925class FoldWithProducerReshapeOpByCollapsing
1926 : public OpRewritePattern<GenericOp> {
1927public:
1928  // TODO: support fusion with all linalg ops, not just generic.
1929 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1930 ControlFusionFn foldReshapes,
1931 PatternBenefit benefit = 1)
1932 : OpRewritePattern<GenericOp>(context, benefit),
1933 controlFoldingReshapes(std::move(foldReshapes)) {}
1934
1935 LogicalResult matchAndRewrite(GenericOp genericOp,
1936 PatternRewriter &rewriter) const override {
1937 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1938 tensor::ExpandShapeOp reshapeOp =
1939 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1940 if (!reshapeOp)
1941 continue;
1942
1943 SmallVector<ReassociationIndices> collapsableIterationDims =
1944 getCollapsableIterationSpaceDims(genericOp, &opOperand,
1945 reshapeOp.getReassociationIndices());
1946 if (collapsableIterationDims.empty() ||
1947 !controlFoldingReshapes(&opOperand)) {
1948 continue;
1949 }
1950
1951 std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1952 genericOp, collapsableIterationDims, rewriter);
1953 if (!collapseResult) {
1954 return rewriter.notifyMatchFailure(
1955 genericOp, "failed to do the fusion by collapsing transformation");
1956 }
1957
1958 rewriter.replaceOp(genericOp, collapseResult->results);
1959 return success();
1960 }
1961 return failure();
1962 }
1963
1964private:
1965 ControlFusionFn controlFoldingReshapes;
1966};
1967
1968/// Pattern to fold a tensor.collapse_shape op with its producer generic op
1969/// by collapsing the dimensions of the producer's iteration space.
1970struct FoldReshapeWithGenericOpByCollapsing
1971 : public OpRewritePattern<tensor::CollapseShapeOp> {
1972
1973 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1974 ControlFusionFn foldReshapes,
1975 PatternBenefit benefit = 1)
1976 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1977 controlFoldingReshapes(std::move(foldReshapes)) {}
1978
1979 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1980 PatternRewriter &rewriter) const override {
1981 // Fold only if all constraints of fusing with reshape by collapsing are
1982 // met.
1983 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1984 if (!producerResult) {
1985 return rewriter.notifyMatchFailure(reshapeOp,
1986 "source not produced by an operation");
1987 }
1988
1989    // TODO: support fusion with all linalg producers, not just generic.
1990 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1991 if (!producer) {
1992 return rewriter.notifyMatchFailure(reshapeOp,
1993 "producer not a generic op");
1994 }
1995
1996 SmallVector<ReassociationIndices> collapsableIterationDims =
1997        getCollapsableIterationSpaceDims(
1998            producer,
1999 producer.getDpsInitOperand(producerResult.getResultNumber()),
2000 reshapeOp.getReassociationIndices());
2001 if (collapsableIterationDims.empty()) {
2002 return rewriter.notifyMatchFailure(
2003 reshapeOp, "failed preconditions of fusion with producer generic op");
2004 }
2005
2006 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2007 return rewriter.notifyMatchFailure(reshapeOp,
2008 "fusion blocked by control function");
2009 }
2010
2011 // Set the insertion point after `producer` because there could be uses
2012 // of `producer` between it and the `tensor.collapse_shape` op.
2013 rewriter.setInsertionPointAfter(producer);
2014 std::optional<CollapseResult> collapseResult =
2015 collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
2016 if (!collapseResult) {
2017 return rewriter.notifyMatchFailure(
2018 producer, "failed to do the fusion by collapsing transformation");
2019 }
2020
2021 rewriter.replaceOp(producer, collapseResult->results);
2022 return success();
2023 }
2024
2025private:
2026 ControlFusionFn controlFoldingReshapes;
2027};
2028
2029/// Computes the collapsed padding information for the given pad operation based
2030/// on the provided collapsed shape and reassociation indices. Returns a
2031/// PadDimInfo containing the low and high padding amounts and the collapsed
2032/// shape for each dimension, or failure if the collapse is not possible.
2033static FailureOr<PadDimInfo>
2034computeCollapsedPadding(tensor::PadOp padOp,
2035 ArrayRef<ReassociationIndices> reassociations,
2036 PatternRewriter &rewriter) {
2037 // If the padding value depends on the index values of the pad operation,
2038 // then it may not be valid to collapse the dimensions, since it will change
2039 // the index values on which the padding value depends. This is not currently
2040 // supported by the pad collapsing patterns, but it could be implemented
2041 // similarly to the collapsing of linalg.generic ops with linalg.index ops in
2042 // the body, as is done in `generateCollapsedIndexingRegion`.
2043 if (!padOp.getConstantPaddingValue())
2044 return failure();
2045
2046 // Collapsed dimensions cannot have padding because this can produce strided
2047 // padding that isn't representable by a tensor.pad op. There are some special
2048 // cases where it is possible (like collapsing unit dims), but supporting
2049 // these cases is NYI, so disallow it for now.
2050 ArrayRef<int64_t> low = padOp.getStaticLow();
2051 ArrayRef<int64_t> high = padOp.getStaticHigh();
2052 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2053 for (int64_t dim : reInd) {
2054 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2055 return failure();
2056 }
2057 }
2058
2059 // Initialize padding values for collapsed tensors with zeros
2060 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2061 PadDimInfo padDimInfo;
2062 padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2063 padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
2064
2065 // Update padding for dimensions that are not being collapsed, and compute
2066 // the collapsed padded shape.
2067 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2068 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2069 for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
2070 if (reInd.size() == 1) {
2071 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2072 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
2073 }
2074 SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
2075 for (int64_t dim : reInd) {
2076 collapsedSize =
2077 collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
2078 }
2079 padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
2080 }
2081
2082 return padDimInfo;
2083}
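
// Illustrative example (hypothetical values): with reassociations
// [[0, 1], [2]], static low padding [0, 0, 2], and static high padding
// [0, 0, 3] on a pad producing tensor<4x8x16xf32>, the computation succeeds
// with lowPad = [0, 2], highPad = [0, 3], and paddedShape = [32, 16]. Any
// non-zero padding on dims 0 or 1 would make it fail, since those dims are
// collapsed together.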
2084
2085class FoldPadWithProducerReshapeOpByCollapsing
2086 : public OpRewritePattern<tensor::PadOp> {
2087public:
2088 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
2089 ControlFusionFn foldReshapes,
2090 PatternBenefit benefit = 1)
2091 : OpRewritePattern<tensor::PadOp>(context, benefit),
2092 controlFoldingReshapes(std::move(foldReshapes)) {}
2093
2094 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2095 PatternRewriter &rewriter) const override {
2096 tensor::ExpandShapeOp reshapeOp =
2097 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
2098 if (!reshapeOp)
2099 return failure();
2100
2101 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
2102 return rewriter.notifyMatchFailure(padOp,
2103 "fusion blocked by control function");
2104 }
2105
2106 SmallVector<ReassociationIndices> reassociations =
2107 reshapeOp.getReassociationIndices();
2108 FailureOr<PadDimInfo> maybeCollapsedPadding =
2109 computeCollapsedPadding(padOp, reassociations, rewriter);
2110 if (failed(maybeCollapsedPadding))
2111 return failure();
2112 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2113
2114 SmallVector<OpFoldResult> expandedPaddedSizes =
2115 reshapeOp.getMixedOutputShape();
2116 AffineExpr d0, d1, d2;
2117 bindDims(rewriter.getContext(), d0, d1, d2);
2118 auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
2119 Location loc = reshapeOp->getLoc();
2120 for (auto [reInd, l, h] :
2121 llvm::zip_equal(reassociations, collapsedPadding.lowPad,
2122 collapsedPadding.highPad)) {
2123 if (reInd.size() == 1) {
2124 expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
2125 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
2126 }
2127 }
2128
2129 RankedTensorType collapsedPaddedType =
2130 padOp.getType().clone(collapsedPadding.paddedShape);
2131 auto newPadOp = tensor::PadOp::create(
2132 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
2133 collapsedPadding.lowPad, collapsedPadding.highPad,
2134 padOp.getConstantPaddingValue(), padOp.getNofold());
2135
2136 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
2137 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2138 expandedPaddedSizes);
2139
2140 return success();
2141 }
2142
2143private:
2144 ControlFusionFn controlFoldingReshapes;
2145};
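
// Illustrative rewrite (hypothetical IR):
//
//   %e = tensor.expand_shape %src [[0, 1], [2]]
//       : tensor<32x11xf32> into tensor<4x8x11xf32>
//   %p = tensor.pad %e low[0, 0, 2] high[0, 0, 3] {...}
//       : tensor<4x8x11xf32> to tensor<4x8x16xf32>
//
// becomes
//
//   %p = tensor.pad %src low[0, 2] high[0, 3] {...}
//       : tensor<32x11xf32> to tensor<32x16xf32>
//   %e = tensor.expand_shape %p [[0, 1], [2]] output_shape [4, 8, 16]
//       : tensor<32x16xf32> into tensor<4x8x16xf32>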
2146
2147class FoldReshapeWithProducerPadOpByCollapsing
2148 : public OpRewritePattern<tensor::CollapseShapeOp> {
2149public:
2150 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2151 ControlFusionFn foldReshapes,
2152 PatternBenefit benefit = 1)
2153 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2154 controlFoldingReshapes(std::move(foldReshapes)) {}
2155
2156 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2157 PatternRewriter &rewriter) const override {
2158 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2159 if (!padOp)
2160 return failure();
2161
2162 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2163 return rewriter.notifyMatchFailure(padOp,
2164 "fusion blocked by control function");
2165 }
2166
2167 SmallVector<ReassociationIndices> reassociations =
2168 reshapeOp.getReassociationIndices();
2169 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2170 FailureOr<PadDimInfo> maybeCollapsedPadding =
2171 computeCollapsedPadding(padOp, reassociations, rewriter);
2172 if (failed(maybeCollapsedPadding))
2173 return failure();
2174 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2175
2176 Location loc = reshapeOp->getLoc();
2177 auto newCollapseOp = tensor::CollapseShapeOp::create(
2178 rewriter, loc, padOp.getSource(), reassociations);
2179
2180 auto newPadOp = tensor::PadOp::create(
2181 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2182 collapsedPadding.lowPad, collapsedPadding.highPad,
2183 padOp.getConstantPaddingValue(), padOp.getNofold());
2184
2185 rewriter.replaceOp(reshapeOp, newPadOp.getResult());
2186 return success();
2187 }
2188
2189private:
2190 ControlFusionFn controlFoldingReshapes;
2191};
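
// Illustrative rewrite (hypothetical IR), the dual of the pattern above:
//
//   %p = tensor.pad %src low[0, 0, 2] high[0, 0, 3] {...}
//       : tensor<4x8x11xf32> to tensor<4x8x16xf32>
//   %c = tensor.collapse_shape %p [[0, 1], [2]]
//       : tensor<4x8x16xf32> into tensor<32x16xf32>
//
// becomes
//
//   %c = tensor.collapse_shape %src [[0, 1], [2]]
//       : tensor<4x8x11xf32> into tensor<32x11xf32>
//   %p = tensor.pad %c low[0, 2] high[0, 3] {...}
//       : tensor<32x11xf32> to tensor<32x16xf32>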
2192
2193/// Pattern to collapse dimensions selected by a control function.
2194template <typename LinalgType>
2195class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
2196public:
2197 CollapseLinalgDimensions(MLIRContext *context,
2198 GetCollapsableDimensionsFn collapseDimensions,
2199 PatternBenefit benefit = 1)
2200 : OpRewritePattern<LinalgType>(context, benefit),
2201 controlCollapseDimension(std::move(collapseDimensions)) {}
2202
2203 LogicalResult matchAndRewrite(LinalgType op,
2204 PatternRewriter &rewriter) const override {
2205 SmallVector<ReassociationIndices> collapsableIterationDims =
2206 controlCollapseDimension(op);
2207 if (collapsableIterationDims.empty())
2208 return failure();
2209
2210 // Check if the specified list of dimensions to collapse is a valid list.
2211 if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
2212 collapsableIterationDims)) {
2213 return rewriter.notifyMatchFailure(
2214 op, "specified dimensions cannot be collapsed");
2215 }
2216
2217 std::optional<CollapseResult> collapseResult =
2218 collapseOpIterationDims(op, collapsableIterationDims, rewriter);
2219 if (!collapseResult) {
2220 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
2221 }
2222 rewriter.replaceOp(op, collapseResult->results);
2223 return success();
2224 }
2225
2226private:
2227 GetCollapsableDimensionsFn controlCollapseDimension;
2228};
2229
2230} // namespace
2231
2232//===---------------------------------------------------------------------===//
2233// Methods and patterns that fuse constants with linalg.generic operations.
2234//===---------------------------------------------------------------------===//
2235
2236namespace {
2237/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2238/// handle cases where the constant is not single-valued.
2239class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
2240public:
2241 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2242 : OpRewritePattern<GenericOp>(context, benefit) {}
2243
2244 LogicalResult matchAndRewrite(GenericOp genericOp,
2245 PatternRewriter &rewriter) const override {
2246 if (!genericOp.hasPureTensorSemantics())
2247 return failure();
2248 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2249 Operation *def = opOperand->get().getDefiningOp();
2250 TypedAttr constantAttr;
2251 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
2252 {
2253 DenseElementsAttr splatAttr;
2254 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2255 splatAttr.isSplat() &&
2256 splatAttr.getType().getElementType().isIntOrFloat()) {
2257 constantAttr = splatAttr.getSplatValue<TypedAttr>();
2258 return true;
2259 }
2260 }
2261 {
2262 IntegerAttr intAttr;
2263 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2264 constantAttr = intAttr;
2265 return true;
2266 }
2267 }
2268 {
2269 FloatAttr floatAttr;
2270 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2271 constantAttr = floatAttr;
2272 return true;
2273 }
2274 }
2275 return false;
2276 };
2277
2278 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2279 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2280 continue;
2281
2282      // The operands and the indexing_maps of the fused operation are the
2283      // same as the operands and indexing_maps of the generic operation, with
2284      // the values at the constant index dropped.
2285 SmallVector<AffineMap> fusedIndexMaps;
2286 SmallVector<Value> fusedOperands;
2287 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2288 fusedIndexMaps.reserve(genericOp->getNumOperands());
2289 fusedOperands.reserve(genericOp.getNumDpsInputs());
2290 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2291 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2292 if (inputOperand == opOperand)
2293 continue;
2294 Value inputValue = inputOperand->get();
2295 fusedIndexMaps.push_back(
2296 genericOp.getMatchingIndexingMap(inputOperand));
2297 fusedOperands.push_back(inputValue);
2298 fusedLocs.push_back(inputValue.getLoc());
2299 }
2300 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2301 fusedIndexMaps.push_back(
2302 genericOp.getMatchingIndexingMap(&outputOperand));
2303
2304 // Check if the operation shapes to loops map is computable.
2305 if (!inversePermutation(
2306 concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2307 return rewriter.notifyMatchFailure(
2308 genericOp, "fused op loop bound computation failed");
2309 }
2310
2311 // Create a constant scalar value from the splat constant.
2312 Value scalarConstant =
2313 arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
2314
2315 SmallVector<Value> outputOperands = genericOp.getOutputs();
2316 auto fusedOp =
2317 GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
2318 genericOp->getResultTypes(),
2319 /*inputs=*/fusedOperands,
2320 /*outputs=*/outputOperands,
2321 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2322 genericOp.getIteratorTypes(),
2323 /*doc=*/nullptr,
2324 /*library_call=*/nullptr);
2325
2326 // Map the block argument corresponding to the replaced argument with the
2327 // scalar constant.
2328 Region &region = genericOp->getRegion(0);
2329 Block &entryBlock = *region.begin();
2330 IRMapping mapping;
2331 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2332 scalarConstant);
2333 Region &fusedRegion = fusedOp->getRegion(0);
2334 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2335 mapping);
2336 rewriter.replaceOp(genericOp, fusedOp->getResults());
2337 return success();
2338 }
2339 return failure();
2340 }
2341};
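
// Illustrative rewrite (hypothetical IR): an input fed by
//
//   %cst = arith.constant dense<2.0> : tensor<4xf32>
//
// is dropped from the generic op; the corresponding block argument is
// remapped to a scalar
//
//   %s = arith.constant 2.000000e+00 : f32
//
// inside the fused op's region, so the fused op no longer depends on the
// constant tensor operand.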
2342
2343} // namespace
2344
2345//===---------------------------------------------------------------------===//
2346// Miscellaneous patterns that help fusion.
2347//===---------------------------------------------------------------------===//
2348
2349namespace {
2350/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2351/// value of the `outs` operand is not used within the op. This is only
2352/// implemented for `linalg.generic` operations for now, but should hold for all
2353/// linalg structured ops.
2354struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2355 using OpRewritePattern<GenericOp>::OpRewritePattern;
2356
2357 LogicalResult matchAndRewrite(GenericOp op,
2358 PatternRewriter &rewriter) const override {
2359 rewriter.startOpModification(op);
2360 bool modifiedOutput = false;
2361 Location loc = op.getLoc();
2362 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2363 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2364 Value operandVal = opOperand.get();
2365 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2366 if (!operandType)
2367 continue;
2368
2369 // If outs is sparse, leave it to the sparsifier.
2370        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2371          continue;
2372
2373 // If outs is already an `empty` operation, nothing to do.
2374 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2375 if (definingOp)
2376 continue;
2377 modifiedOutput = true;
2378 SmallVector<OpFoldResult> mixedSizes =
2379 tensor::getMixedSizes(rewriter, loc, operandVal);
2380 Value emptyTensor = tensor::EmptyOp::create(
2381 rewriter, loc, mixedSizes, operandType.getElementType());
2382 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2383 }
2384 }
2385 if (!modifiedOutput) {
2386 rewriter.cancelOpModification(op);
2387 return failure();
2388 }
2389 rewriter.finalizeOpModification(op);
2390 return success();
2391 }
2392};
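
// Illustrative rewrite (hypothetical IR): if the payload never reads the
// `outs` block argument, then
//
//   %0 = linalg.generic ... outs(%arg1 : tensor<?xf32>) {...}
//
// is rewritten to initialize from a fresh tensor.empty of the same size:
//
//   %d = tensor.dim %arg1, %c0 : tensor<?xf32>
//   %empty = tensor.empty(%d) : tensor<?xf32>
//   %0 = linalg.generic ... outs(%empty : tensor<?xf32>) {...}
//
// which removes the use-def dependence on %arg1 and unblocks fusion.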
2393
2394/// Fold linalg.fill into linalg.generic
2395struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2396 using OpRewritePattern<GenericOp>::OpRewritePattern;
2397
2398 LogicalResult matchAndRewrite(GenericOp genericOp,
2399 PatternRewriter &rewriter) const override {
2400 if (!genericOp.hasPureTensorSemantics())
2401 return failure();
2402 bool fillFound = false;
2403 Block &payload = genericOp.getRegion().front();
2404 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2405 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2406 continue;
2407 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2408 if (!fillOp)
2409 continue;
2410 fillFound = true;
2411 Value fillVal = fillOp.value();
2412 auto resultType =
2413 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2414 Value convertedVal =
2415 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2416 /*isUnsignedCast =*/false);
2417 rewriter.replaceAllUsesWith(
2418 payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2419 }
2420 return success(fillFound);
2421 }
2422};
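
// Illustrative rewrite (hypothetical IR): an input produced by
//
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>)
//       -> tensor<4xf32>
//
// and read inside the payload has its block argument replaced directly by
// %cst (converted to the element type if necessary); the now-unused input is
// cleaned up by the erase-unused-operands-and-results patterns.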
2423} // namespace
2424
2425void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2426    RewritePatternSet &patterns,
2427    const ControlFusionFn &controlFoldingReshapes) {
2428 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2429 controlFoldingReshapes);
2430 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2431 controlFoldingReshapes);
2432 patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
2433 controlFoldingReshapes);
2434 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2435 controlFoldingReshapes);
2436}
2437
2438void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2439    RewritePatternSet &patterns,
2440    const ControlFusionFn &controlFoldingReshapes) {
2441 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2442 controlFoldingReshapes);
2443 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2444 patterns.getContext(), controlFoldingReshapes);
2445 patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2446 patterns.getContext(), controlFoldingReshapes);
2447 patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
2448 controlFoldingReshapes);
2449}
2450
2451void mlir::linalg::populateElementwiseOpsFusionPatterns(
2452    RewritePatternSet &patterns,
2453    const ControlFusionFn &controlElementwiseOpsFusion) {
2454 auto *context = patterns.getContext();
2455 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2456 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2457 RemoveOutsDependency>(context);
2458 // Add the patterns that clean up dead operands and results.
2459  populateEraseUnusedOperandsAndResultsPatterns(patterns);
2460}
2461
2462void mlir::linalg::populateCollapseDimensions(
2463    RewritePatternSet &patterns,
2464    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2465 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2466 CollapseLinalgDimensions<linalg::CopyOp>>(
2467 patterns.getContext(), controlCollapseDimensions);
2468}
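
// A minimal usage sketch (hypothetical client code): populate the
// fusion-by-collapsing patterns with a simple single-use heuristic and apply
// them greedily.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateFoldReshapeOpsByCollapsingPatterns(
//       patterns, [](OpOperand *fusedOperand) {
//         Operation *producer = fusedOperand->get().getDefiningOp();
//         return producer && producer->hasOneUse();
//       });
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();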
2469
2470//===---------------------------------------------------------------------===//
2471// Passes
2472//===---------------------------------------------------------------------===//
2473
2474namespace {
2475
2476/// Pass that fuses generic ops on tensors. Used only for testing.
2477// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2478// patterns added here heavily depends on the cost function used. Having an
2479// opinionated pass of this form is not recommended. Deprecate this pass in
2480// favor of test passes that check the functionality of each of the patterns
2481// added here individually.
2482struct LinalgElementwiseOpFusionPass
2483 : public impl::LinalgElementwiseOpFusionPassBase<
2484 LinalgElementwiseOpFusionPass> {
2485 using impl::LinalgElementwiseOpFusionPassBase<
2486 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2487 void runOnOperation() override {
2488 Operation *op = getOperation();
2489 MLIRContext *context = op->getContext();
2490 RewritePatternSet patterns(context);
2491
2492 // Add folding with reshape by expansion patterns.
2493 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2494 Operation *producer = fusedOperand->get().getDefiningOp();
2495 return producer && producer->hasOneUse();
2496 };
2497
2498 // Add elementwise op fusion patterns.
2499    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2500    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2501    tensor::populateBubbleUpExpandShapePatterns(patterns);
2502
2503 // General canonicalization patterns.
2504 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2505 GenericOp::getCanonicalizationPatterns(patterns, context);
2506 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2507 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2508 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2509 patterns);
2510
2511 // Add constant folding patterns.
2512    populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2513
2514 // Use TopDownTraversal for compile time reasons.
2515 (void)applyPatternsGreedily(op, std::move(patterns),
2516 GreedyRewriteConfig().setUseTopDownTraversal());
2517 }
2518};
2519
2520} // namespace
return success()
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
ArrayRef< ReassociationIndices > getCollapsedOpToOrigOpMapping() const
Return mapping from collapsed loop domain to original loop domain.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
lhs
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:193
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition Builders.cpp:27
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:421
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
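A minimal sketch of how this helper is typically used, assuming a hypothetical OpBuilder `b`, Location `loc`, and two index-typed OpFoldResults `lhs`/`rhs` (none of these names come from this file):

// Compute lhs + rhs; folds to a constant attribute when both inputs are
// static, otherwise materializes an affine.apply op.
OpFoldResult addIndices(OpBuilder &b, Location loc, OpFoldResult lhs,
                        OpFoldResult rhs) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineMap sum = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
  return affine::makeComposedFoldedAffineApply(b, loc, sum, {lhs, rhs});
}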
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:230
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:234
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
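Taken together with fuseElementwiseOps above, a hedged sketch of the check-then-fuse flow; `rewriter` and `fusedOperand` are assumed to come from a surrounding rewrite pattern:

LogicalResult fuseIfPossible(RewriterBase &rewriter, OpOperand *fusedOperand) {
  if (!linalg::areElementwiseOpsFusable(fusedOperand))
    return failure();
  FailureOr<linalg::ElementwiseOpFusionResult> fused =
      linalg::fuseElementwiseOps(rewriter, fusedOperand);
  if (failed(fused))
    return failure();
  // Reroute uses of the original results to the fused op's values.
  for (auto [origVal, replacement] : fused->replacements)
    rewriter.replaceAllUsesWith(origVal, replacement);
  return success();
}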
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operations.
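As a hedged sketch, collapsing the first two iteration dimensions of an assumed LinalgOp `op` with an assumed `rewriter`; {0, 1} must be a contiguous sequence that every indexing map of `op` preserves:

SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
FailureOr<linalg::CollapseResult> collapsed =
    linalg::collapseOpIterationDims(op, foldedDims, rewriter);
if (succeeded(collapsed))
  rewriter.replaceOp(op, collapsed->results);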
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operations on tensors.
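A minimal driver sketch tying a ControlFusionFn to the greedy rewriter; `funcOp` is an assumed function-like op whose body region is isolated from above, and the single-use policy is an illustrative choice, not the pass's only option:

void runElementwiseFusion(Operation *funcOp) {
  RewritePatternSet patterns(funcOp->getContext());
  // Fuse only when the producer has exactly one use (assumed policy).
  linalg::ControlFusionFn control = [](OpOperand *fusedOperand) {
    Operation *producer = fusedOperand->get().getDefiningOp();
    return producer && producer->hasOneUse();
  };
  linalg::populateElementwiseOpsFusionPatterns(patterns, control);
  if (failed(applyPatternsGreedily(funcOp->getRegion(0), std::move(patterns))))
    funcOp->emitWarning("elementwise fusion did not converge");
}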
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
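For instance (a sketch; `b`, `loc`, and the tensor-typed Value `v` are assumed), static dimensions come back as IntegerAttrs and dynamic ones as tensor.dim results:

SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, v);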
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
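A short sketch combining matchPattern with m_ConstantInt; `v` is an assumed Value:

APInt constValue;
if (matchPattern(v, m_ConstantInt(&constValue))) {
  // `v` is defined by a constant-foldable op holding a (splat) integer.
  int64_t raw = constValue.getSExtValue();
  (void)raw;
}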
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:238
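For example, a hedged one-liner that sign-extends an assumed i8 scalar `operand` to i32 (builder `b` and location `loc` also assumed):

Value widened = convertScalarToDtype(b, loc, operand, b.getI32Type(),
                                     /*isUnsignedCast=*/false);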
ArrayRef< int64_t > ReassociationIndicesRef
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 ...
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
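A small sketch with a concrete permutation (the MLIRContext pointer `ctx` is assumed); composing a permutation map with its inverse yields the identity:

SmallVector<unsigned> permutation = {2, 0, 1};
AffineMap perm = AffineMap::getPermutationMap(permutation, ctx);
// perm: (d0, d1, d2) -> (d2, d0, d1)
AffineMap inv = inversePermutation(perm);
// inv: (d0, d1, d2) -> (d1, d2, d0)
assert(inv.compose(perm).isIdentity());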
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 ...
Definition AffineExpr.h:325
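Together with bindDims above, a hedged sketch that builds the map (d0, d1)[s0] -> (d0 * s0 + d1); `ctx` is an assumed MLIRContext pointer:

AffineExpr d0, d1, s0;
bindDims(ctx, d0, d1);
bindSymbols(ctx, s0);
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 + d1);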
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from 'map' to 'source' and return the result.
Definition AffineMap.h:675
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using classes from the detail namespace.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
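A short sketch, assuming a SmallVector<OpFoldResult> `sizes` such as the one produced by getMixedSizes above; dynamic entries show up as ShapedType::kDynamic in the static array:

auto [staticSizes, dynamicValues] = decomposeMixedValues(sizes);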
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
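A concrete sketch: for [2, 0, 1] the result satisfies inverse[permutation[i]] == i for every position i:

SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});
// inverse == [1, 2, 0]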
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
Definition Transforms.h:560
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.