//===- TosaReduceTransposes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// ----------
// Motivation:
// ----------

// Some legalization pathways introduce redundant tosa.TRANSPOSE
// operations that result in avoidable data movement. For example,
// PyTorch -> TOSA contains a lot of unnecessary transposes due
// to conversions between NCHW and NHWC.
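//
// For example (a hypothetical sketch of such a layout round-trip, not the
// output of any particular frontend):
//
//   %0 = tosa.transpose %arg0 {perms = [0, 2, 3, 1]} // NCHW -> NHWC
//   %1 = tosa.clamp %0 ...
//   %2 = tosa.transpose %1 {perms = [0, 3, 1, 2]}    // NHWC -> NCHW
//
// Here the pass can replace %2 with a tosa.clamp applied directly to %arg0,
// since the two permutations compose to the identity.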

// We wish to remove all the ones that we can, since in general
// it is possible to remove the overwhelming majority.

// --------------------
// High-Level Overview:
// --------------------

// The pass works through the transpose operators in the program. It begins at
// some transpose operator with an associated permutation tensor. It traverses
// upwards through the dependencies of this transpose and verifies that we
// encounter only operators with the TosaElementwiseOperator trait, and that
// the traversal terminates in either constants, reshapes, or transposes.

// We then evaluate whether there are any additional restrictions (the
// transposes it terminates in must invert the one we began at, and the
// reshapes must be ones into which we can fold the transpose), and then we
// hoist the transpose through the intervening operators, folding it at the
// constants, reshapes, and transposes.

// Finally, we ensure that we do not need both the transposed form (the form
// that had the transpose hoisted through it) and the untransposed form (which
// it was prior), by analyzing the usages of the dependent operators of a
// given transpose we are attempting to hoist and replace.

// If the usages are such that both forms would be required, then we do
// not replace the hoisted transpose, causing the new chain to be dead.
// Otherwise, we do, and the old chain (untransposed form) becomes dead. Only
// one chain will ever then be live, resulting in no duplication.
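//
// For example (a hypothetical scenario), if an op inside the chain also
// feeds a user outside the fan-in cone of the transpose, the untransposed
// form must stay live for that user, so we skip the replacement rather than
// keep both forms alive.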

// We then perform a simple one-pass DCE, so no canonicalization is necessary.

// -----------
// Future Work:
// -----------

// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
// hoisted transposes with different permutation tensors.

// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
// N -> 1x1x...x1xNx1x...x1x1.

// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
// those that form the identity.

// (4) Add support for more instructions besides TosaElementwiseOperator as
// the intervening ones (for example, the reduce_* operators).

// (5) Support hoisting transposes up to an input parameter.

//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstring>
#include <set>
#include <stack>

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//

namespace {

struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This will collect all the data dependencies for the given Operation
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one.
  //   %0 = tosa.transpose %arg0 <- applies to this
  //   %1 = tosa.transpose %0    <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if a ReshapeOp can have the hoisted TransposeOp folded into it.
  // If so, creates a new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  //   %0 = tosa.const
  //   %1 = tosa.transpose
  //   %2 = tosa.add %0, %1
  //   %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOp we should "replace", turning their converted
  // chains of ops, through which they were propagated, "live", and the old
  // code "dead." Attempts to avoid doing so when doing so would result in the
  // old code staying "live," resulting in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since the
  // verifier guarantees from TOSA are not in place (and otherwise, if used
  // elsewhere, it should fail).
  // This is a basic API and may benefit from refactoring into the core MLIR
  // APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};

std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {
  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
  ArrayRef<int64_t> oldShape = oldType.getShape();
  int64_t rank = oldType.getRank();

  // Asserted by the TransposeOp verifier and by TOSA disallowing tensors with
  // a dimension of 0. If not in place, something is very wrong.
  if (rank <= 0 || oldType.getNumElements() <= 0) {
    signalPassFailure();
    return std::nullopt;
  }

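  // Note: applyTOSAPermutation computes newShape[i] = oldShape[perms[i]];
  // e.g., shape [2, 3, 4] with perms [2, 0, 1] becomes [4, 2, 3], matching
  // the worked example further below.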
  auto newShape = applyTOSAPermutation(oldShape, perms);
  RankedTensorType newType =
      RankedTensorType::get(newShape, oldType.getElementType());

  if (input.isSplat()) {
    return input.reshape(newType);
  }

  auto rawData = input.getRawData();
  if (!rawData.data()) {
    return std::nullopt;
  }

  // The algorithm is approximately as follows:
  // 1. Determine the strides of both input and output tensors in row-major
  //    order.
  // 2. Iterate through the output tensor linearly.
  // 3. For each output position, decompose the linear index into
  //    multi-dimensional coordinates using output strides.
  // 4. Use the permutation to map output coordinates to input coordinates
  //    and calculate the source linear index.

  // Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3.
  // For output linear index 11: decompose to output[1][1][2]
  // using output strides [6, 3, 1]. Map to input coordinates using
  // perms: dim 0 → 2, dim 1 → 0, dim 2 → 1, giving source position
  // calculated as 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1]
  // = 1*1 + 1*12 + 2*4 = 21.

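  // Note: this computation assumes byte-aligned element types; a sub-byte
  // element type (e.g., i1) would make elementSize zero here.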
  size_t elementSize = oldType.getElementTypeBitWidth() / 8;
  int64_t numElements = oldType.getNumElements();

  SmallVector<char> outputBuffer(numElements * elementSize);
  const char *inputPtr = rawData.data();
  char *outputPtr = outputBuffer.data();

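  // For example, shape [2, 3, 4] yields row-major strides [12, 4, 1].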
  auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> {
    int64_t rank = shape.size();
    SmallVector<int64_t> strides(rank);
    strides[rank - 1] = 1;
    for (int64_t i = rank - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * shape[i + 1];
    }
    return strides;
  };

  // Calculate strides for both input and output tensors.
  SmallVector<int64_t> inputStrides = calculateStrides(oldShape);
  SmallVector<int64_t> outputStrides = calculateStrides(newShape);

  auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
    int64_t tempDestIndex = destLinearIndex;
    int64_t sourceLinearIndex = 0;

    // Decompose the linear destination index into multi-dimensional
    // coordinates by dividing by the output strides.
    // Simultaneously map these coordinates through the permutation
    // to calculate the corresponding source linear index.
    for (auto j : llvm::seq<int64_t>(rank)) {
      int64_t destCoord = tempDestIndex / outputStrides[j];
      tempDestIndex %= outputStrides[j];
      sourceLinearIndex += destCoord * inputStrides[perms[j]];
    }

    return sourceLinearIndex;
  };

  for (auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
    int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);

    // Copy the element from source to destination using type-agnostic byte
    // copying.
    std::memcpy(outputPtr + destLinearIndex * elementSize,
                inputPtr + sourceLinearIndex * elementSize, elementSize);
  }

  return DenseElementsAttr::getFromRawBuffer(newType, outputBuffer);
}

// The SetVector should only contain ConstOp, ReshapeOp, and TransposeOp
// as the sources of the data dependencies, and TosaElementwiseOperator
// ops after that, if the function returns true.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {
  // Can occur if defined through the parameter to a func.func.
  if (!op)
    return false;

  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
    return false;

  // Prevent extra work if already seen.
  if (collected.contains(op))
    return true;

  // Throw it out so we don't have to deal with it later.
  if (op->getNumResults() != 1 ||
      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
    return false;

  // We don't wish to traverse up a ReshapeOp, since generally we can't
  // propagate a TransposeOp through it. TransposeOp, ReshapeOp, and ConstOp
  // will have no in-edges in the data dependency graph we construct for
  // the downstream TransposeOp.
  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
      !llvm::isa<tosa::ConstOp>(op)) {

    if (!llvm::isa<tosa::MulOp>(op) &&
        !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
      return false;

    for (Value operand : op->getOperands()) {
      // If this is a problem in the future, think about alternatives to
      // recursion.
      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
        // Do not recurse into MulOp's shift operand.
        continue;
      }
      if (!collectFanIn(operand.getDefiningOp(), collected))
        return false;
    }
  }

  // Insert in topological order.
  collected.insert(op);

  return true;
}

// Assumes that, due to TransposeOp verification, perms arrays are
// permutations of 0 .. perms.size() - 1.
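// For example, perms1 = [1, 2, 0] and perms2 = [2, 0, 1] satisfy
// perms2[perms1[i]] == i for every i, so they compose to the identity.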
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;
  int32_t n = perms1.size();
  for (int32_t i = 0; i < n; i++)
    if (perms2[perms1[i]] != i)
      return false;
  return true;
}

// Primary overload for ops with the TosaElementwiseOperator trait.
// The other ones handle the case of the operations that occur at the
// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    Operation *op, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (op->getNumResults() != 1 ||
      (!llvm::isa<tosa::MulOp>(op) &&
       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
    return std::nullopt;

  auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
  SmallVector<Value, 3> operands;
  for (Value v : op->getOperands()) {
    if (valuesMap.contains(v)) {
      operands.push_back(valuesMap.at(v));
    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
      // Special case for MulOp's shift operand.
      operands.push_back(v);
    } else {
      return std::nullopt;
    }
  }

  // Conceptually, we propagate the hoisted TransposeOp through
  // these intervening operations. For example,

  // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
  // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
  // tensor<3x2xi32>

  // becomes:
  // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
  // tensor<3x2xi32>
  // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>

  // We construct this new tosa.clamp here, but it doesn't
  // turn "live" until the transpose being hoisted through this chain
  // is replaced with the proper value from the new chain.

  return rewriter
      .create(op->getLoc(), op->getName().getIdentifier(), operands,
              RankedTensorType::get(
                  applyTOSAPermutation(resultType.getShape(), hoistedPerms),
                  resultType.getElementType()),
              op->getAttrs())
      ->getResult(0);
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
    return std::nullopt;
  return transposeOp.getInput1();
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto reshapeOutput = reshapeOp.getOutput();
  auto reshapeInputType =
      llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
  // Want reshape N -> 1x1x...x1xNx1x...x1x1. Check for a null type before
  // querying the shape, since the dyn_cast may fail.
  if (!reshapeInputType)
    return std::nullopt;
  auto reshapeInputShape = reshapeInputType.getShape();
  if (reshapeInputShape.size() != 1)
    return std::nullopt;
  auto reshapeOutputType =
      llvm::cast<RankedTensorType>(reshapeOutput.getType());

  // Instead of inserting a TransposeOp here, we check if we can fold it into
  // the ReshapeOp. There are more complex cases where this is possible, and
  // this check can be extended.
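  //
  // For example, with hoistedPerms = [2, 0, 1, 3], a reshape
  //   tensor<8xf32> -> tensor<1x8x1x1xf32>
  // followed by that transpose is equivalent to the single reshape
  //   tensor<8xf32> -> tensor<1x1x8x1xf32>,
  // since permuting unit dimensions moves no data.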

  // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1.
  auto shape = reshapeOutputType.getShape();
  size_t ones = llvm::count(shape, 1);
  // Handles both N != 1 (exactly one non-unit dimension) and N == 1 (all
  // dimensions are unit).
  if (ones != shape.size() - 1 &&
      !(ones == shape.size() && reshapeInputShape[0] == 1))
    return std::nullopt;

  // Do not insert a TransposeOp; instead we fold the reshape and its
  // attribute.
  llvm::SmallVector<int64_t> newShape;
  if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
                                 newShape)) {
    // This means the shape is not constant.
    return std::nullopt;
  }
  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
  auto foldedReshape = ReshapeOp::create(
      rewriter, reshapeOp.getLoc(),
      RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                            reshapeOutputType.getElementType()),
      reshapeOp.getInput1(),
      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
                                                      hoistedPerms)));
  return foldedReshape->getResult(0);
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
  if (!denseAttr)
    return std::nullopt;
  auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
  if (!maybeNewDenseAttr.has_value())
    return std::nullopt;
  auto newDenseAttr = maybeNewDenseAttr.value();
  auto newConstOp = ConstOp::create(rewriter, constOp.getLoc(),
                                    newDenseAttr.getType(), newDenseAttr);
  return newConstOp->getResult(0);
}

bool TosaReduceTransposes::convertDependentOps(
    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {

  for (Operation *op : dependentOps) {
    if (!op || op->getNumResults() != 1)
      return false;

    Value priorValue = op->getResult(0);

    // It's possible that on a prior transposeOp we had the same dependency
    // and already resolved it.
    if (valuesMap.contains(priorValue))
      continue;

    // Keep converted ops close to the original.
    rewriter.setInsertionPointAfter(op);

    std::optional<Value> maybeValue =
        llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
            .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto rootOp) {
              return buildMappedToValue(rootOp, valuesMap, rewriter,
                                        hoistedPerms);
            })
            .Default([&](Operation *op) {
              return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
            });

    if (!maybeValue.has_value())
      return false;

    valuesMap[priorValue] = maybeValue.value();
  }

  return true;
}

bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
    Operation *user, std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  return llvm::none_of(
      transposeInfo,
      [&validTransposes,
       user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
        const auto &[transposeOp, dependentOps] = info;
        return validTransposes.count(transposeOp) &&
               dependentOps.contains(user);
      });
}

// Dependencies are valid for an operation if none of them occur outside
// of the proper fan-in cones of the hoisted TransposeOp with the same perms
// that we can replace. Described in more detail within.
bool TosaReduceTransposes::dependenciesAreValid(
    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
    std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (Operation *op : dependentOps) {

    // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
    // This can be changed later if we find the memory impact is too high.
    if (llvm::isa<ConstOp>(op))
      continue;

    for (OpOperand &use : op->getUses()) {
      // We want the uses to be (1) contained in the dependentOps of other
      // validTransposes, or (2) directly used in a TransposeOp with the same
      // perms. For (2), it means the fan-in is a subset of our dependentOps,
      // so it is also a validTranspose that will eventually be replaced.
      Operation *user = use.getOwner();
      if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
        // We can later consider cases where transpose -> transpose
        // or reshape -> transpose, where the transposes are not necessarily
        // the same perms as the hoisted one, if implementing a more general
        // transform. These could be permitted.
        if (!llvm::equal(perms, otherTranspose.getPerms()))
          return false;
      } else if (userNotContainedInValidTransposeDependencies(
                     user, validTransposes, transposeInfo)) {
        return false;
      }
    }
  }

  return true;
}

// Gets the set of TransposeOp that we can replace without causing
// the old fan-in cones of any TransposeOp to remain "live", i.e., not being
// dead code. This is done by iterating the set until convergence, since
// if you are used outside your own fan-in cone, it's possible to be used
// in another fan-in cone of a TransposeOp that is being replaced -- unless
// we find that that one has a usage outside of it too.
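// For example, if transpose A's cone feeds an op that is also used inside
// transpose B's cone, A remains replaceable only while B does; invalidating
// B can in turn invalidate A, hence the fixed-point iteration below.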
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
    ArrayRef<int32_t> perms,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  // Initially, we assume they are all good to replace,
  // and we whittle them down based on our criteria.
  std::set<TransposeOp> ableToReplace;
  for (const auto &[transposeOp, _] : transposeInfo)
    ableToReplace.insert(transposeOp);

  bool gotRid;
  do {
    gotRid = false;
    for (const auto &[transposeOp, dependentOps] : transposeInfo) {
      // We don't care about it. Already invalidated.
      if (!ableToReplace.count(transposeOp))
        continue;

      // Check for validity.
      if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
                                transposeInfo)) {
        ableToReplace.erase(transposeOp);
        gotRid = true;
        break;
      }
    }

  } while (gotRid);

  return ableToReplace;
}

void TosaReduceTransposes::runOnOperation() {
  // We want to operate only within a single block.
  if (!getOperation().getRegion().hasOneBlock())
    return;

  IRRewriter rewriter(&getContext());
  // For each perms, maintain a mapping for converted ops to avoid duplication.
  DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
  // For each perms, we keep track of which TransposeOp are eligible
  // for replacement alongside their dependentOps.
  DenseMap<ArrayRef<int32_t>,
           std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
      permsToTransposeInfo;

  // Necessary for lifetime, since the DenseMap keeps a copy of the ArrayRef.
  // Use SmallVector for perms (the common case is <= 4), but std::vector for
  // the outer container, since there is no guarantee of smallness there.
  std::vector<SmallVector<int32_t>> collectedPerms;

  // This keeps track of the order across all eligible-for-replacement
  // TransposeOp and their perms, a necessity for the final replacements.
  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;

  // We want to reserve the space up front, since SmallVector stores some data
  // internally and the ArrayRef can reference that, which we don't want to
  // get invalidated.
  size_t expectedMaxPerms = 0;
  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
  collectedPerms.reserve(expectedMaxPerms);

  getOperation().walk([&](TransposeOp transposeOp) {
    SetVector<Operation *> dependentOps;
    collectedPerms.emplace_back();
    SmallVector<int32_t> &perms = collectedPerms.back();

    // Dynamic shapes are OK, but the incompatible ones will be rejected
    // later.
    auto input = transposeOp.getInput1();
    auto output = transposeOp.getOutput();

    // However, we don't support unranked tensors.
    if (!llvm::isa<RankedTensorType>(input.getType()) ||
        !llvm::isa<RankedTensorType>(output.getType()))
      return;

    llvm::append_range(perms, transposeOp.getPerms());

    // We let --canonicalize deal with identity transposes.
    if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
      return;

    // Can fail if some basic invariants needed for our conversions
    // are not met.
    if (!collectFanIn(input.getDefiningOp(), dependentOps))
      return;

    // We key valuesMap on the perms, since it's possible multiple hoisted
    // transposes with different perms converge on an op, which would result
    // in different transformations.
    DenseMap<Value, Value> &valuesMap = permsToValues[perms];

    // Attempt to perform the conversions and placements into IR
    // without turning inserted code "live". Also fills out valuesMap.
    // Fails if there is an intermediary we do not support.
    if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
      // Some additional operations may have been inserted, but will be
      // removed by dead code elimination.
      return;

    // This should not happen. If it does -- it's unexpected,
    // so we fail the pass.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    // It's possible the types are not compatible (because of dynamic shapes),
    // and in these cases, we want dynamic shapes resolved before running the
    // pass.
    if (output.getType() != valuesMap.at(input).getType())
      return;

    auto &transposeInfo = permsToTransposeInfo[perms];

    // In general, we might also want to introduce "newDependentOps"
    // if there are new usages that don't fall inside the original fan-ins
    // (like the TransposeOp we insert for ReshapeOp),
    // but in this case, that is specialized enough and overlaps
    // with another direct-use TransposeOp case we need to cover anyway.
    transposeInfo.emplace_back(transposeOp, dependentOps);

    // This is for the final replacement across all transposes.
    totalTransposeOrder.emplace(transposeOp, perms);
  });

  // We want to do a full fan-in analysis on a per-perms level,
  // since if we did it across multiple perms at once, and they shared
  // dependencies (due to a shared dependency on a Reshape), then we would
  // also get duplicate ops. Const is special-cased.
  std::set<TransposeOp> ableToReplace;
  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
    // Gives us back replacements that would never result in any duplicate
    // operations being inserted by us in the IR (i.e., our goal is only to
    // remove transposes, and not to create a "new chain" to do so, but to
    // replace the existing chains).
    // Ideally, --canonicalize is run before this pass, since it helps this
    // analysis by removing dead code to allow more potentially acceptable
    // transformations.
    auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
    ableToReplace.insert(goodReplacementsForPerms.begin(),
                         goodReplacementsForPerms.end());
  }

  // We want to do replacement across all transposes
  // in reverse order, due to invalidation of valuesMap mappings
  // if we did it otherwise.
  while (!totalTransposeOrder.empty()) {
    auto [transposeOp, perms] = totalTransposeOrder.top();
    totalTransposeOrder.pop();

    if (ableToReplace.count(transposeOp) == 0)
      continue;

    auto &valuesMap = permsToValues[perms];
    auto input = transposeOp.getInput1();

    // The purpose of this reverse iteration
    // is to avoid valuesMap invalidation. If it happens,
    // something is wrong.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    rewriter.replaceOp(transposeOp, valuesMap.at(input));
  }

  // We can remove all dead code by going in reverse.
  // This is because we would remove usages before we
  // see the users.
  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](Operation *op) {
        if (isOpTriviallyDead(op))
          rewriter.eraseOp(op);
      });
}

} // namespace