MLIR 22.0.0git
ArithToArmSME.cpp
Go to the documentation of this file.
1//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/Pass/Pass.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19#include "mlir/Conversion/Passes.h.inc"
20} // namespace mlir
21
22#define DEBUG_TYPE "arith-to-arm-sme"
23
24using namespace mlir;
25
26//===----------------------------------------------------------------------===//
27// Conversion helpers
28//===----------------------------------------------------------------------===//
29
30/// Returns true if 'val' is a splat of zero, false otherwise.
31static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32 if (llvm::isa<FloatType>(elemType))
33 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34 if (llvm::isa<IntegerType>(elemType))
35 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36 return false;
37}
38
39namespace {
40
41//===----------------------------------------------------------------------===//
42// ConstantOp
43//===----------------------------------------------------------------------===//
44
45/// Conversion pattern for dense arith.constant.
46struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
47 using Base::Base;
48
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50 PatternRewriter &rewriter) const final {
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
52 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53 return failure();
54
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
57 return failure();
58
59 auto tileElementType = tileType.getElementType();
60
61 // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62 if (isSplatZero(tileElementType, denseAttr)) {
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64 return success();
65 }
66
67 // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
68 // ops that broadcast the constant to each tile slice.
69 auto loc = constantOp.getLoc();
70
71 // To fill a tile with a constant, we create a 1-D splat of the constant,
72 // then move that into each tile slice (the largest unit we can set at once,
73 // outside of operations like the outerproduct).
74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75 auto denseAttr1D = DenseElementsAttr::get(
76 tileSliceType, denseAttr.getSplatValue<Attribute>());
77 auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
78
79 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
80 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81 Value currentTile) {
82 // Create 'arm_sme.insert_tile_slice' to write vector to tile
83 // slice.
84 auto nextTile = arm_sme::InsertTileSliceOp::create(
85 b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86 return nextTile.getResult();
87 };
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
91
92 return success();
93 }
94};
95
96} // namespace
97
98//===----------------------------------------------------------------------===//
99// Pattern population
100//===----------------------------------------------------------------------===//
101
104 patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
105}
106
107//===----------------------------------------------------------------------===//
108// Pass definition
109//===----------------------------------------------------------------------===//
110
111namespace {
112struct ArithToArmSMEConversionPass final
113 : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116
117 void runOnOperation() override {
120 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
121 return signalPassFailure();
122 }
123};
124} // namespace
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition Utils.cpp:89
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...