MLIR 23.0.0git
ArithToArmSME.cpp
Go to the documentation of this file.
1//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/Pass/Pass.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19#include "mlir/Conversion/Passes.h.inc"
20} // namespace mlir
21
22#define DEBUG_TYPE "arith-to-arm-sme"
23
24using namespace mlir;
25
26//===----------------------------------------------------------------------===//
27// Conversion helpers
28//===----------------------------------------------------------------------===//
29
30/// Returns true if 'val' is a splat of zero, false otherwise.
31static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32 if (llvm::isa<FloatType>(elemType))
33 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34 if (llvm::isa<IntegerType>(elemType))
35 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36 return false;
37}
38
39namespace {
40
41//===----------------------------------------------------------------------===//
42// ConstantOp
43//===----------------------------------------------------------------------===//
44
45/// Conversion pattern for dense arith.constant.
46struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
47 using Base::Base;
48
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50 PatternRewriter &rewriter) const final {
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
52 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53 return failure();
54
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
57 return failure();
58
59 auto tileElementType = tileType.getElementType();
60
61 // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62 if (isSplatZero(tileElementType, denseAttr)) {
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64 return success();
65 }
66
67 // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
68 // ops that broadcast the constant to each tile slice.
69 auto loc = constantOp.getLoc();
70
71 // To fill a tile with a constant, we create a 1-D splat of the constant,
72 // then move that into each tile slice (the largest unit we can set at once,
73 // outside of operations like the outerproduct).
74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75 auto denseAttr1D = DenseElementsAttr::get(
76 tileSliceType, denseAttr.getSplatValue<Attribute>());
77 auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
78
79 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
80 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81 Value currentTile) {
82 // Create 'arm_sme.insert_tile_slice' to write vector to tile
83 // slice.
84 auto nextTile = arm_sme::InsertTileSliceOp::create(
85 b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86 return nextTile.getResult();
87 };
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
91
92 return success();
93 }
94};
95
96} // namespace
97
98//===----------------------------------------------------------------------===//
99// Pattern population
100//===----------------------------------------------------------------------===//
101
103 RewritePatternSet &patterns) {
104 patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
105}
106
107//===----------------------------------------------------------------------===//
108// Pass definition
109//===----------------------------------------------------------------------===//
110
111namespace {
112struct ArithToArmSMEConversionPass final
113 : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116
117 void runOnOperation() override {
118 RewritePatternSet patterns(&getContext());
120 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
121 return signalPassFailure();
122 }
123};
124} // namespace
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and each iteration yields a new tile.
Definition Utils.cpp:86
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.