MLIR  22.0.0git
XeGPUVectorLinearize.cpp
Go to the documentation of this file.
1 //===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/Pass/Pass.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/DebugLog.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #include <optional>
24 
25 namespace mlir {
26 namespace xegpu {
27 #define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
28 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
29 } // namespace xegpu
30 } // namespace mlir
31 
32 #define DEBUG_TYPE "xegpu-vector-linearize"
33 
34 using namespace mlir;
35 
36 namespace {
// XeGPU vector-linearization pass (file purpose: "Linearizes n-D vectors to
// 1-D vectors"). runOnOperation() runs three phases in order. Each phase uses
// its own scope and its own pattern set, and the pass stops with
// signalPassFailure() as soon as any phase reports failure:
//   1. Greedy lowering of vector.broadcast / vector.gather (progressive
//      lowering) and of vector.transpose (Shuffle16x16 strategy).
//   2. Greedy unrolling of rank >= 2 vector.load / vector.store into slices of
//      shape 1 x ... x 1 x (last dim).
//   3. A partial dialect conversion that applies the vector linearization
//      patterns with a TypeConverter / ConversionTarget pair.
// NOTE(review): this text is a documentation-page extraction of the source.
// The leading numbers are the original line numbers, and original lines
// 42-45, 48, 57 and 101 are missing. As a result, the declarations of the
// `patterns` sets used in phases 1 and 2, and some populate* calls, are not
// visible here. Consult the real file before editing.
37 struct XeGPUVectorLinearizePass final
38  : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
39  void runOnOperation() override {
40  // Phase 1: vector.broadcast and vector.gather require progressive lowering.
41  {
46  // vector.transpose lowering.
47  // Shuffle16x16 falls back to Shuffle1D for non-16x16 sizes.
49  patterns, vector::VectorTransposeLowering::Shuffle16x16);
// A failure result from the greedy driver is treated as fatal for the pass.
50  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
51  return signalPassFailure();
52  }
53 
54  // Phase 2: unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1))
55  // slices of <1x1x...x1xdk>.
56  {
58  vector::UnrollVectorOptions vectorOptions;
// Native-shape callback for the unroll patterns. Only vector.load and
// vector.store whose vector type has rank >= 2 get a native shape. Any other
// op, or a lower-rank vector, yields std::nullopt (not handled).
// Example: vector<2x3x8xf32> -> native shape [1, 1, 8], i.e. 2*3 = 6 slices.
59  vectorOptions.setNativeShapeFn(
60  [](Operation *op) -> std::optional<SmallVector<int64_t>> {
// Returns the op's vector type, or a null VectorType when the op is neither
// a vector.load nor a vector.store.
61  auto extractVectorType = [](Operation *op) -> VectorType {
62  if (auto loadOp = dyn_cast<vector::LoadOp>(op))
63  return loadOp.getVectorType();
64  if (auto storeOp = dyn_cast<vector::StoreOp>(op))
65  return storeOp.getVectorType();
66  return nullptr;
67  };
68 
69  VectorType vecType = extractVectorType(op);
70  if (!vecType)
71  return std::nullopt;
72 
73  // Only handle rank >= 2 so we actually unroll something.
74  int64_t rank = vecType.getRank();
75  if (rank < 2)
76  return std::nullopt;
77 
78  ArrayRef<int64_t> shape = vecType.getShape();
79  // Produce native shape: 1 x 1 x ... x (original last dim).
80  SmallVector<int64_t> native(rank, 1);
81  native.back() = shape.back();
82  return native;
83  });
84  vector::populateVectorUnrollPatterns(patterns, vectorOptions);
85  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
86  LDBG() << "Unroll failed.";
87  return signalPassFailure();
88  }
89  }
90 
91  // Phase 3: use vector linearization patterns (partial dialect conversion).
92  {
93  MLIRContext &context = getContext();
94  TypeConverter converter;
95  RewritePatternSet patterns(&context);
96  ConversionTarget target(context);
// populateForVectorLinearize receives both the converter and the target;
// presumably it installs the n-D -> 1-D type conversion and the matching
// legality rules. TODO confirm against its declaration.
97  vector::populateForVectorLinearize(converter, target);
98  vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
99  vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
100  patterns);
// NOTE(review): original line 101 is missing from this listing. The dangling
// `target);` below closes another populate* call, presumably
// scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
// target) -- confirm against the real source.
102  target);
// Partial conversion: a failure result is logged via LDBG() and fails the
// pass.
103  if (failed(applyPartialConversion(getOperation(), target,
104  std::move(patterns)))) {
105  LDBG() << "Linearization failed.";
106  return signalPassFailure();
107  }
108  }
109  }
110 };
111 } // namespace
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Type conversion class.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)