MLIR 22.0.0git
ConvertToReplicatedConstantCompositePass.cpp
Go to the documentation of this file.
1//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert a splat composite spirv.Constant and
10// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
11// spirv.EXT.SpecConstantCompositeReplicate respectively.
12//
13//===----------------------------------------------------------------------===//
14
18
19namespace mlir::spirv {
20#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
21#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
22
23namespace {
24
25static Type getArrayElemType(Attribute attr) {
26 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
27 return typedAttr.getType();
28 }
29
30 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
31 return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size());
32 }
33
34 return nullptr;
35}
36
37static std::pair<Attribute, uint32_t>
38getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
39 auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
40 if (!compositeType)
41 return {nullptr, 1};
42
43 if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
44 return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
45 }
46
47 if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
48 if (llvm::all_equal(arrayAttr)) {
49 Attribute attr = arrayAttr[0];
50 uint32_t numElements = arrayAttr.size();
51
52 // Find the inner-most splat value for array of composites
53 auto [newAttr, newNumElements] =
54 getSplatAttrAndNumElements(attr, getArrayElemType(attr));
55 if (newAttr) {
56 attr = newAttr;
57 numElements *= newNumElements;
58 }
59 return {attr, numElements};
60 }
61 }
62
63 return {nullptr, 1};
64}
65
66struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
67 using Base::Base;
68
69 LogicalResult matchAndRewrite(spirv::ConstantOp op,
70 PatternRewriter &rewriter) const override {
71 auto [attr, numElements] =
72 getSplatAttrAndNumElements(op.getValue(), op.getType());
73 if (!attr)
74 return rewriter.notifyMatchFailure(op, "composite is not splat");
75
76 if (numElements == 1)
77 return rewriter.notifyMatchFailure(op,
78 "composite has only one constituent");
79
80 rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
81 op, op.getType(), attr);
82 return success();
83 }
84};
85
86struct SpecConstantCompositeOpConversion final
87 : OpRewritePattern<spirv::SpecConstantCompositeOp> {
88 using Base::Base;
89
90 LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
91 PatternRewriter &rewriter) const override {
92 auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
93 if (!compositeType)
94 return rewriter.notifyMatchFailure(op, "not a composite constant");
95
96 ArrayAttr constituents = op.getConstituents();
97 if (constituents.size() == 1)
98 return rewriter.notifyMatchFailure(op,
99 "composite has only one consituent");
100
101 if (!llvm::all_equal(constituents))
102 return rewriter.notifyMatchFailure(op, "composite is not splat");
103
104 auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
105 if (!splatConstituent)
106 return rewriter.notifyMatchFailure(
107 op, "expected flat symbol reference for splat constituent");
108
109 rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
110 op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
111
112 return success();
113 }
114};
115
116struct ConvertToReplicatedConstantCompositePass final
117 : spirv::impl::SPIRVReplicatedConstantCompositePassBase<
118 ConvertToReplicatedConstantCompositePass> {
119 void runOnOperation() override {
120 MLIRContext *context = &getContext();
121 RewritePatternSet patterns(context);
122 patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
123 context);
124 walkAndApplyPatterns(getOperation(), std::move(patterns));
125 }
126};
127
128} // namespace
129} // namespace mlir::spirv
return success()
ArrayAttr()
b getContext())
static ArrayType get(Type elementType, unsigned elementCount)
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.