76 #include "llvm/ADT/TypeSwitch.h"
83 #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
84 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
97 struct TosaReduceTransposes final
98 :
public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
99 void runOnOperation()
override;
125 buildMappedToValue(TransposeOp transposeOp,
132 buildMappedToValue(ReshapeOp reshapeOp,
151 std::set<TransposeOp> getGoodReplacements(
157 bool userNotContainedInValidTransposeDependencies(
158 Operation *user, std::set<TransposeOp> &validTransposes,
164 bool dependenciesAreValid(
166 std::set<TransposeOp> &validTransposes,
175 std::optional<DenseElementsAttr>
179 std::optional<DenseElementsAttr>
182 RankedTensorType oldType = llvm::cast<RankedTensorType>(input.
getType());
183 RankedTensorType newType =
185 oldType.getElementType());
186 size_t rank = oldType.getRank();
190 if (rank <= 0 || oldType.getNumElements() <= 0) {
213 originalInputStrides[rank - 1] = 1;
215 for (int64_t i = rank - 2; i >= 0; i--)
216 originalInputStrides[i] =
217 originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
223 newInputStrides.reserve(rank);
224 for (int32_t v : perms)
225 newInputStrides.push_back(originalInputStrides[v]);
231 for (
size_t i = 0; i < rank; i++)
232 boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
235 resultArray.reserve(inputArray.size());
237 std::function<void(int64_t,
238 SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
239 processTransposeDim = [&](
auto accumulatedIndex,
auto it) {
240 if (it == boundsAndStrides.end()) {
241 resultArray.push_back(inputArray[accumulatedIndex]);
245 for (int64_t i = 0; i < it->first; i++) {
246 int64_t
j = accumulatedIndex + i * it->second;
247 processTransposeDim(
j, it + 1);
251 processTransposeDim(0, boundsAndStrides.begin());
259 bool TosaReduceTransposes::collectFanIn(
Operation *op,
265 if (!llvm::isa_and_present<tosa::TosaDialect>(op->
getDialect()))
269 if (collected.contains(op))
281 if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
282 !llvm::isa<tosa::ConstOp>(op)) {
284 if (!llvm::isa<tosa::MulOp>(op) &&
290 if (llvm::isa<tosa::MulOp>(op) && operand == op->
getOperand(2)) {
294 if (!collectFanIn(operand.getDefiningOp(), collected))
300 collected.insert(op);
309 if (perms1.size() != perms2.size())
311 int32_t n = perms1.size();
312 for (int32_t i = 0; i < n; i++)
313 if (perms2[perms1[i]] != i)
321 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
325 (!llvm::isa<tosa::MulOp>(op) &&
332 if (valuesMap.contains(v)) {
333 operands.push_back(valuesMap.at(v));
334 }
else if (llvm::isa<tosa::MulOp>(op) && v == op->
getOperand(2)) {
336 operands.push_back(v);
362 resultType.getElementType()),
367 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
370 if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
372 return transposeOp.getInput1();
375 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
378 auto reshapeOutput = reshapeOp.getOutput();
379 auto reshapeInputType =
380 llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
381 auto reshapeInputShape = reshapeInputType.getShape();
383 if (!reshapeInputType || reshapeInputShape.size() != 1)
385 auto reshapeOutputType =
386 llvm::cast<RankedTensorType>(reshapeOutput.getType());
393 auto shape = reshapeOutputType.getShape();
394 size_t ones = llvm::count(shape, 1);
396 if (ones != shape.size() - 1 &&
397 !(ones == shape.size() && reshapeInputShape[0] == 1))
408 auto foldedReshape = rewriter.
create<ReshapeOp>(
411 reshapeOutputType.getElementType()),
412 reshapeOp.getInput1(),
418 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
421 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
424 auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
425 if (!maybeNewDenseAttr.has_value())
427 auto newDenseAttr = maybeNewDenseAttr.value();
428 auto newConstOp = rewriter.
create<ConstOp>(
429 constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
433 bool TosaReduceTransposes::convertDependentOps(
445 if (valuesMap.contains(priorValue))
451 std::optional<Value> maybeValue =
453 .Case<TransposeOp, ReshapeOp, ConstOp>([&](
auto transposeOp) {
454 return buildMappedToValue(transposeOp, valuesMap, rewriter,
458 return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
461 if (!maybeValue.has_value())
464 valuesMap[priorValue] = maybeValue.value();
470 bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
471 Operation *user, std::set<TransposeOp> &validTransposes,
474 return llvm::none_of(
478 const auto &[transposeOp, dependentOps] = info;
479 return validTransposes.count(transposeOp) &&
480 dependentOps.contains(user);
487 bool TosaReduceTransposes::dependenciesAreValid(
489 std::set<TransposeOp> &validTransposes,
496 if (llvm::isa<ConstOp>(op))
506 if (
auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
511 if (!llvm::equal(perms, otherTranspose.getPerms()))
513 }
else if (userNotContainedInValidTransposeDependencies(
514 user, validTransposes, transposeInfo)) {
529 std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
535 std::set<TransposeOp> ableToReplace;
536 for (
const auto &[transposeOp, _] : transposeInfo)
537 ableToReplace.insert(transposeOp);
542 for (
const auto &[transposeOp, dependentOps] : transposeInfo) {
544 if (!ableToReplace.count(transposeOp))
548 if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
550 ableToReplace.
erase(transposeOp);
558 return ableToReplace;
561 void TosaReduceTransposes::runOnOperation() {
563 if (!getOperation().getRegion().hasOneBlock())
572 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
573 permsToTransposeInfo;
578 std::vector<SmallVector<int32_t>> collectedPerms;
582 std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
587 size_t expectedMaxPerms = 0;
588 getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
589 collectedPerms.reserve(expectedMaxPerms);
591 getOperation().walk([&](TransposeOp transposeOp) {
593 collectedPerms.emplace_back();
597 auto input = transposeOp.getInput1();
598 auto output = transposeOp.getOutput();
601 if (!llvm::isa<RankedTensorType>(input.
getType()) ||
602 !llvm::isa<RankedTensorType>(output.getType()))
605 llvm::for_each(transposeOp.getPerms(),
606 [&perms](
const auto i) { perms.emplace_back(i); });
609 if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
614 if (!collectFanIn(input.getDefiningOp(), dependentOps))
625 if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
632 if (!valuesMap.contains(input))
633 return signalPassFailure();
638 if (output.getType() != valuesMap.at(input).getType())
641 auto &transposeInfo = permsToTransposeInfo[perms];
648 transposeInfo.push_back({transposeOp, dependentOps});
651 totalTransposeOrder.push({transposeOp, perms});
658 std::set<TransposeOp> ableToReplace;
659 for (
auto &[perms, transposeInfo] : permsToTransposeInfo) {
667 auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
668 ableToReplace.insert(goodReplacementsForPerms.begin(),
669 goodReplacementsForPerms.end());
675 while (!totalTransposeOrder.empty()) {
676 auto [transposeOp, perms] = totalTransposeOrder.top();
677 totalTransposeOrder.pop();
679 if (ableToReplace.count(transposeOp) == 0)
682 auto &valuesMap = permsToValues[perms];
683 auto input = transposeOp.getInput1();
688 if (!valuesMap.contains(input))
689 return signalPassFailure();
691 rewriter.
replaceOp(transposeOp, valuesMap.at(input));
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped to 'newType'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
This class indicates that an op is tosa-elementwise (permits broadcasting, unlike Elementwise trait).
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
This iterator enumerates elements in "reverse" order.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.