75#include "llvm/ADT/TypeSwitch.h"
81#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
82#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
95struct TosaReduceTransposes final
97 void runOnOperation()
override;
104 DenseMap<Value, Value> &valuesMap,
105 IRRewriter &rewriter,
106 ArrayRef<int32_t> hoistedPerms);
110 bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
111 ArrayRef<int32_t> perms2);
116 buildMappedToValue(Operation *op,
const DenseMap<Value, Value> &valuesMap,
117 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
123 buildMappedToValue(TransposeOp transposeOp,
124 const DenseMap<Value, Value> &valuesMap,
125 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
130 buildMappedToValue(ReshapeOp reshapeOp,
131 const DenseMap<Value, Value> &valuesMap,
132 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
142 buildMappedToValue(ConstOp constOp,
const DenseMap<Value, Value> &valuesMap,
143 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
149 std::set<TransposeOp> getGoodReplacements(
150 ArrayRef<int32_t> perms,
155 bool userNotContainedInValidTransposeDependencies(
156 Operation *user, std::set<TransposeOp> &validTransposes,
162 bool dependenciesAreValid(
164 std::set<TransposeOp> &validTransposes,
173 std::optional<DenseElementsAttr>
174 transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
177std::optional<DenseElementsAttr>
180 RankedTensorType oldType = llvm::cast<RankedTensorType>(input.
getType());
181 ArrayRef<int64_t> oldShape = oldType.getShape();
182 int64_t rank = oldType.getRank();
186 if (rank <= 0 || oldType.getNumElements() <= 0) {
192 RankedTensorType newType =
193 RankedTensorType::get(newShape, oldType.getElementType());
200 if (!rawData.data()) {
220 size_t elementSize = oldType.getElementTypeBitWidth() / 8;
221 int64_t numElements = oldType.getNumElements();
223 SmallVector<char> outputBuffer(numElements * elementSize);
224 const char *inputPtr = rawData.data();
225 char *outputPtr = outputBuffer.data();
227 auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> {
228 int64_t rank = shape.size();
229 SmallVector<int64_t> strides(rank);
230 strides[rank - 1] = 1;
231 for (int64_t i = rank - 2; i >= 0; --i) {
232 strides[i] = strides[i + 1] * shape[i + 1];
238 SmallVector<int64_t> inputStrides = calculateStrides(oldShape);
239 SmallVector<int64_t> outputStrides = calculateStrides(newShape);
241 auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
242 int64_t tempDestIndex = destLinearIndex;
243 int64_t sourceLinearIndex = 0;
249 for (
auto j : llvm::seq<int64_t>(rank)) {
250 int64_t destCoord = tempDestIndex / outputStrides[j];
251 tempDestIndex %= outputStrides[j];
252 sourceLinearIndex += destCoord * inputStrides[perms[j]];
255 return sourceLinearIndex;
258 for (
auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
259 int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);
263 std::memcpy(outputPtr + destLinearIndex * elementSize,
264 inputPtr + sourceLinearIndex * elementSize, elementSize);
273bool TosaReduceTransposes::collectFanIn(Operation *op,
279 if (!llvm::isa_and_present<tosa::TosaDialect>(op->
getDialect()))
283 if (collected.contains(op))
295 if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
296 !llvm::isa<tosa::ConstOp>(op)) {
298 if (!llvm::isa<tosa::MulOp>(op) &&
299 !op->
hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
304 if (llvm::isa<tosa::MulOp>(op) && operand == op->
getOperand(2)) {
308 if (!collectFanIn(operand.getDefiningOp(), collected))
314 collected.insert(op);
321bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
322 ArrayRef<int32_t> perms2) {
323 if (perms1.size() != perms2.size())
325 int32_t n = perms1.size();
326 for (int32_t i = 0; i < n; i++)
327 if (perms2[perms1[i]] != i)
335std::optional<Value> TosaReduceTransposes::buildMappedToValue(
336 Operation *op,
const DenseMap<Value, Value> &valuesMap,
337 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
339 (!llvm::isa<tosa::MulOp>(op) &&
340 !op->
hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
344 SmallVector<Value, 3> operands;
346 if (valuesMap.contains(v)) {
347 operands.push_back(valuesMap.at(v));
348 }
else if (llvm::isa<tosa::MulOp>(op) && v == op->
getOperand(2)) {
350 operands.push_back(v);
374 RankedTensorType::get(
376 resultType.getElementType()),
381std::optional<Value> TosaReduceTransposes::buildMappedToValue(
382 TransposeOp transposeOp,
const DenseMap<Value, Value> &valuesMap,
383 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
384 if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
386 return transposeOp.getInput1();
389std::optional<Value> TosaReduceTransposes::buildMappedToValue(
390 ReshapeOp reshapeOp,
const DenseMap<Value, Value> &valuesMap,
391 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
392 auto reshapeOutput = reshapeOp.getOutput();
393 auto reshapeInputType =
394 llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
395 auto reshapeInputShape = reshapeInputType.getShape();
397 if (!reshapeInputType || reshapeInputShape.size() != 1)
399 auto reshapeOutputType =
400 llvm::cast<RankedTensorType>(reshapeOutput.getType());
407 auto shape = reshapeOutputType.getShape();
408 size_t ones = llvm::count(shape, 1);
410 if (ones != shape.size() - 1 &&
411 !(ones == shape.size() && reshapeInputShape[0] == 1))
415 llvm::SmallVector<int64_t> newShape;
421 ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
422 auto foldedReshape = ReshapeOp::create(
423 rewriter, reshapeOp.getLoc(),
425 reshapeOutputType.getElementType()),
426 reshapeOp.getInput1(),
429 return foldedReshape->getResult(0);
432std::optional<Value> TosaReduceTransposes::buildMappedToValue(
433 ConstOp constOp,
const DenseMap<Value, Value> &valuesMap,
434 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
435 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
438 auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
439 if (!maybeNewDenseAttr.has_value())
441 auto newDenseAttr = maybeNewDenseAttr.value();
442 auto newConstOp = ConstOp::create(rewriter, constOp.getLoc(),
443 newDenseAttr.getType(), newDenseAttr);
444 return newConstOp->getResult(0);
447bool TosaReduceTransposes::convertDependentOps(
449 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
451 for (Operation *op : dependentOps) {
459 if (valuesMap.contains(priorValue))
465 std::optional<Value> maybeValue =
466 llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
467 .Case<TransposeOp, ReshapeOp, ConstOp>([&](
auto transposeOp) {
468 return buildMappedToValue(transposeOp, valuesMap, rewriter,
471 .Default([&](Operation *op) {
472 return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
475 if (!maybeValue.has_value())
478 valuesMap[priorValue] = maybeValue.value();
484bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
485 Operation *user, std::set<TransposeOp> &validTransposes,
488 return llvm::none_of(
492 const auto &[transposeOp, dependentOps] = info;
493 return validTransposes.count(transposeOp) &&
494 dependentOps.contains(user);
501bool TosaReduceTransposes::dependenciesAreValid(
503 std::set<TransposeOp> &validTransposes,
506 for (Operation *op : dependentOps) {
510 if (llvm::isa<ConstOp>(op))
513 for (OpOperand &use : op->
getUses()) {
519 Operation *user = use.getOwner();
520 if (
auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
525 if (!llvm::equal(perms, otherTranspose.getPerms()))
527 }
else if (userNotContainedInValidTransposeDependencies(
528 user, validTransposes, transposeInfo)) {
543std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
544 ArrayRef<int32_t> perms,
549 std::set<TransposeOp> ableToReplace;
550 for (
const auto &[transposeOp, _] : transposeInfo)
551 ableToReplace.insert(transposeOp);
556 for (
const auto &[transposeOp, dependentOps] : transposeInfo) {
558 if (!ableToReplace.count(transposeOp))
562 if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
564 ableToReplace.erase(transposeOp);
572 return ableToReplace;
575void TosaReduceTransposes::runOnOperation() {
577 if (!getOperation().getRegion().hasOneBlock())
586 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
587 permsToTransposeInfo;
592 std::vector<SmallVector<int32_t>> collectedPerms;
596 std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
601 size_t expectedMaxPerms = 0;
602 getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
603 collectedPerms.reserve(expectedMaxPerms);
605 getOperation().walk([&](TransposeOp transposeOp) {
607 collectedPerms.emplace_back();
608 SmallVector<int32_t> &perms = collectedPerms.back();
611 auto input = transposeOp.getInput1();
612 auto output = transposeOp.getOutput();
615 if (!llvm::isa<RankedTensorType>(input.
getType()) ||
616 !llvm::isa<RankedTensorType>(output.getType()))
619 llvm::append_range(perms, transposeOp.getPerms());
622 if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
627 if (!collectFanIn(input.getDefiningOp(), dependentOps))
633 DenseMap<Value, Value> &valuesMap = permsToValues[perms];
638 if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
645 if (!valuesMap.contains(input))
651 if (output.getType() != valuesMap.at(input).getType())
654 auto &transposeInfo = permsToTransposeInfo[perms];
661 transposeInfo.emplace_back(transposeOp, dependentOps);
664 totalTransposeOrder.emplace(transposeOp, perms);
671 std::set<TransposeOp> ableToReplace;
672 for (
auto &[perms, transposeInfo] : permsToTransposeInfo) {
680 auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
681 ableToReplace.insert(goodReplacementsForPerms.begin(),
682 goodReplacementsForPerms.end());
688 while (!totalTransposeOrder.empty()) {
689 auto [transposeOp, perms] = totalTransposeOrder.top();
690 totalTransposeOrder.pop();
692 if (ableToReplace.count(transposeOp) == 0)
695 auto &valuesMap = permsToValues[perms];
696 auto input = transposeOp.getInput1();
701 if (!valuesMap.contains(input))
702 return signalPassFailure();
704 rewriter.
replaceOp(transposeOp, valuesMap.at(input));
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped to 'newType'. The new type must have the same total number of elements as well as element type.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
unsigned getNumResults()
Return the number of results held by this operation.
void signalPassFailure()
Signal that some invariant was broken when running.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Type getType() const
Return the type of this value.
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This iterator enumerates elements in "reverse" order.