MLIR 22.0.0git
XeGPUVectorLinearize.cpp
Go to the documentation of this file.
1//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/Pass/Pass.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/raw_ostream.h"
22
23#include <optional>
24
25namespace mlir {
26namespace xegpu {
27#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
28#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
29} // namespace xegpu
30} // namespace mlir
31
32#define DEBUG_TYPE "xegpu-vector-linearize"
33
34using namespace mlir;
35
36namespace {
// Pass that rewrites n-D vector code toward 1-D vectors in three sequential
// phases, each driven by its own pattern set:
//   1. Greedy lowering of vector.broadcast, vector.gather and
//      vector.transpose (transpose via the Shuffle16x16 strategy).
//   2. Greedy unrolling of rank >= 2 vector.load / vector.store into slices
//      that keep only the innermost dimension (all outer dimensions are 1).
//   3. Partial dialect conversion using the vector linearization patterns.
// If any phase fails, the pass signals failure and returns immediately, so
// later phases do not run.
37struct XeGPUVectorLinearizePass final
38 : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
39 void runOnOperation() override {
40 // vector.broadcast and vector.gather require progressive lowering
41 {
46 // vector.transpose lowering
47 // Shuffle16x16 falls back to Shuffle1D for non-16x16 sizes.
49 patterns, vector::VectorTransposeLowering::Shuffle16x16);
50 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
51 return signalPassFailure();
52 }
53
54 // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
55 // <1x1x...x1xdk>.
56 {
58 vector::UnrollVectorOptions vectorOptions;
// Native-shape callback for the unroll patterns. Only vector.load and
// vector.store of rank >= 2 get a native shape: 1 x ... x 1 x <last dim>.
// Every other op (and rank < 2 load/store) yields std::nullopt.
// NOTE(review): std::nullopt is expected to make the unroll patterns leave
// the op unchanged.
59 vectorOptions.setNativeShapeFn(
60 [](Operation *op) -> std::optional<SmallVector<int64_t>> {
// Returns the vector type accessed by a vector.load / vector.store, or a
// null VectorType for any other op.
61 auto extractVectorType = [](Operation *op) -> VectorType {
62 if (auto loadOp = dyn_cast<vector::LoadOp>(op))
63 return loadOp.getVectorType();
64 if (auto storeOp = dyn_cast<vector::StoreOp>(op))
65 return storeOp.getVectorType();
66 return nullptr;
67 };
68
69 VectorType vecType = extractVectorType(op);
70 if (!vecType)
71 return std::nullopt;
72
73 // Only handle rank >= 2 so we actually unroll something.
74 int64_t rank = vecType.getRank();
75 if (rank < 2)
76 return std::nullopt;
77
78 ArrayRef<int64_t> shape = vecType.getShape();
79 // Produce native shape: 1 x 1 x ... x (original last dim).
80 SmallVector<int64_t> native(rank, 1);
81 native.back() = shape.back();
82 return native;
83 });
84 vector::populateVectorUnrollPatterns(patterns, vectorOptions);
85 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
86 LDBG() << "Unroll failed.";
87 return signalPassFailure();
88 }
89 }
90
91 // Linearize the remaining n-D vector ops to 1-D via partial conversion.
92 {
93 MLIRContext &context = getContext();
94 TypeConverter converter;
96 ConversionTarget target(context);
// Configure the type converter and legality for linearization, then add
// the base and shuffle-like-op linearization patterns.
97 vector::populateForVectorLinearize(converter, target);
98 vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
99 vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
100 patterns);
102 target);
103 if (failed(applyPartialConversion(getOperation(), target,
104 std::move(patterns)))) {
105 LDBG() << "Linearization failed.";
106 return signalPassFailure();
107 }
108 }
109 }
110};
111} // namespace
b getContext())
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)