20 #define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
21 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
/// Returns the type described by a constituent attribute of a composite
/// constant:
///  - a TypedAttr yields its own type;
///  - an ArrayAttr yields an ArrayType whose element type is computed
///    recursively from the array's first element and whose length is the
///    array's size.
/// NOTE(review): the result for any other attribute kind comes from lines
/// outside this view — confirm the fallback before relying on it.
25 static Type getArrayElemType(Attribute attr) {
26 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
27 return typedAttr.getType();
// Only the first element is inspected, so this presumably assumes all
// elements share one type. NOTE(review): indexing [0] on an empty ArrayAttr
// is out of bounds — confirm callers never pass an empty array.
30 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
31 return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size());
/// Tries to reduce a composite constant value to a single repeated ("splat")
/// attribute.
///
/// Returns the repeated attribute paired with how many times it occurs:
///  - SplatElementsAttr: the splat value and the attribute's element count;
///  - ArrayAttr whose elements are all equal: the first element and the array
///    size. The search then recurses into that element and multiplies the
///    count by the count the recursion reports.
///
/// `valueType` is cast to spirv::CompositeType with dyn_cast_or_null.
/// NOTE(review): the handling of a null/non-composite type, the guard around
/// the recursive update, and the non-splat return value are all on lines
/// outside this view — confirm against the full source.
37 static std::pair<Attribute, uint32_t>
38 getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
39 auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
// Dense splat: the elements attribute already records a single value.
43 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
44 return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
// Array constant: splat only if every element is identical.
// NOTE(review): llvm::all_equal returns true for an empty range, so an empty
// ArrayAttr would reach arrayAttr[0] below — confirm this cannot happen.
47 if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
48 if (llvm::all_equal(arrayAttr)) {
49 Attribute attr = arrayAttr[0];
50 uint32_t numElements = arrayAttr.size();
// Recurse into the repeated element so that nested uniform arrays can be
// reduced further; the recursion's element type comes from getArrayElemType.
53 auto [newAttr, newNumElements] =
54 getSplatAttrAndNumElements(attr, getArrayElemType(attr));
// Total replication count is the product of the counts at each nesting level.
57 numElements *= newNumElements;
59 return {attr, numElements};
/// Rewrite pattern: replaces a spirv::ConstantOp whose value is a splat
/// composite with a spirv::EXTConstantCompositeReplicateOp of the same result
/// type that carries only the repeated attribute.
///
/// The match fails with "composite is not splat" or "composite has only one
/// constituent". NOTE(review): the conditions guarding those two returns
/// (presumably a null splat attribute and a replication count of 1) are on
/// lines outside this view — confirm.
66 struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
69 LogicalResult matchAndRewrite(spirv::ConstantOp op,
70 PatternRewriter &rewriter)
const override {
// Find the repeated value and how many times it occurs in the constant.
71 auto [attr, numElements] =
72 getSplatAttrAndNumElements(op.getValue(), op.getType());
74 return rewriter.notifyMatchFailure(op,
"composite is not splat");
77 return rewriter.notifyMatchFailure(op,
78 "composite has only one constituent");
// Keep the original result type; only the repeated attribute is stored.
80 rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
81 op, op.getType(), attr);
/// Rewrite pattern: replaces a spirv::SpecConstantCompositeOp whose
/// constituents are all the same symbol reference with a
/// spirv::EXTSpecConstantCompositeReplicateOp. The replacement keeps the
/// original type and symbol name and stores the single repeated constituent.
86 struct SpecConstantCompositeOpConversion final
87 : OpRewritePattern<spirv::SpecConstantCompositeOp> {
90 LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
91 PatternRewriter &rewriter)
const override {
// Only composite result types are handled. NOTE(review): the guard on
// compositeType (presumably a null check) is on a line outside this view.
92 auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
94 return rewriter.notifyMatchFailure(op,
"not a composite constant");
96 ArrayAttr constituents = op.getConstituents();
// Reject single-constituent composites.
// NOTE(review): the runtime message below misspells "constituent"
// ("consituent"); fixing it requires a code change.
97 if (constituents.size() == 1)
98 return rewriter.notifyMatchFailure(op,
99 "composite has only one consituent");
// Every constituent must be identical for the composite to be a splat.
// NOTE(review): llvm::all_equal is true for an empty range, so an empty
// constituent list would reach constituents[0] below — confirm the op
// verifier rules this out.
101 if (!llvm::all_equal(constituents))
102 return rewriter.notifyMatchFailure(op,
"composite is not splat");
// The repeated constituent must be a FlatSymbolRefAttr.
104 auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
105 if (!splatConstituent)
106 return rewriter.notifyMatchFailure(
107 op,
"expected flat symbol reference for splat constituent");
109 rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
110 op,
TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
116 struct ConvertToReplicatedConstantCompositePass final
117 : spirv::impl::SPIRVReplicatedConstantCompositePassBase<
118 ConvertToReplicatedConstantCompositePass> {
119 void runOnOperation()
override {
121 RewritePatternSet
patterns(context);
122 patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
static MLIRContext * getContext(OpFoldResult val)
static ArrayType get(Type elementType, unsigned elementCount)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...