//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the XeGPU dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>

using namespace mlir;

/// Convert an ArrayRef<ValueRange> into a SmallVector<Value>.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
  // This only works for a subgroup-level layout, which has only lane_layout
  // and lane_data, and is used to distribute SIMD code into SIMT code.
  if (!layout || !layout.isSgLayout())
    return failure();

  SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
  SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
  auto tdescShape = tdescTy.getShape();
  auto elementType = tdescTy.getElementType();

  // Compute sgSize by multiplying the elements of laneLayout,
  // e.g., for a 2D layout, sgSize = laneLayout[0] * laneLayout[1];
  // for a 1D layout, sgSize = laneLayout[0].
  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
                                std::multiplies<int64_t>());

  // Case 1: regular loads/stores.
  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
  if (scatterAttr) {
    auto chunkSize = scatterAttr.getChunkSize().getInt();
    // Verify that the first dimension of the tensor descriptor shape is
    // distributable.
    assert(tdescShape[0] == laneLayout[0] &&
           "tensor descriptor shape is not distributable");
    return VectorType::get({chunkSize}, elementType);
  }

  // Case 2: block loads/stores.
  // Check if the tensor descriptor shape is distributable.
  int64_t tensorSize = 1;
  for (auto [tdescDim, laneDim, laneDataDim] :
       llvm::zip_equal(tdescShape, laneLayout, laneData)) {
    assert((tdescDim % (laneDim * laneDataDim) == 0) &&
           "tensor descriptor shape is not distributable");
    tensorSize *= tdescDim;
  }
  // tensorSize must be adjusted for array_length.
  tensorSize *= tdescTy.getArrayLength();

  return VectorType::get({tensorSize / sgSize}, elementType);
}
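
// Worked example (illustrative, hypothetical types): for a block descriptor
// tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// with array_length = 1, sgSize = 1 * 16 = 16 and tensorSize = 8 * 16 = 128,
// so each lane receives vector<8xf16> (128 / 16). For a scattered descriptor
// with lane_layout = [16, 1] and chunk_size = 8, the result is vector<8xf16>.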

FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(VectorType originalType,
                                      xegpu::LayoutAttr layout) {
  int64_t rank = originalType.getRank();
  // Distributed vector type is only supported for 1D, 2D and 3D vectors.
  if (rank < 1 || rank > 3)
    return failure();
  ArrayRef<int64_t> shape = originalType.getShape();
  // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
  // of the 3D vector.
  int arrayLength = 1;
  if (rank == 3) {
    arrayLength = shape[0];
    shape = shape.drop_front();
  }
  auto helperTdescTy = xegpu::TensorDescType::get(
      shape, originalType.getElementType(), arrayLength,
      /*boundary_check=*/true,
      /*memory_space=*/xegpu::MemorySpace::Global, layout);
  return xegpu::getDistributedVectorType(helperTdescTy);
}
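
// Worked example (illustrative): distributing vector<2x8x16xf16> with
// lane_layout = [1, 16] and lane_data = [1, 1] treats the leading dimension
// as array_length = 2; tensorSize = 2 * 8 * 16 = 256 and sgSize = 16, so the
// distributed type is vector<16xf16>.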

std::string xegpu::getLayoutName(const OpOperand &operand) {
  const StringRef prefix("layout_operand_");
  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
  return llvm::formatv("{0}{1}", prefix, idx).str();
}

std::string xegpu::getLayoutName(const OpResult result) {
  const StringRef prefix = "layout_result_";
  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
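
// For example, the layout of the third operand of an op is attached under the
// attribute name "layout_operand_2", and the layout of its first result under
// "layout_result_0".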

xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
  if (!value)
    return nullptr;

  if (auto tdescTy =
          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
    return tdescTy.getLayoutAttr();

  if (auto result = dyn_cast<OpResult>(value)) {
    Operation *defOp = result.getDefiningOp();
    assert(defOp && "result must have a defining op");

    // For LoadNdOp, the layout is stored in the tensor descriptor.
    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
      return getLayoutAttr(loadNd.getTensorDesc());

    std::string layoutName = getLayoutName(result);
    if (defOp->hasAttr(layoutName))
      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  }

  if (auto arg = dyn_cast<BlockArgument>(value)) {
    auto parentOp = arg.getOwner()->getParentOp();
    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
      return getLayoutAttr(tiedInit->get());
    }
  }

  return nullptr;
}
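
// The lookup above proceeds in order: the layout of a TensorDescType value,
// the tensor descriptor of a defining LoadNdOp, the "layout_result_N"
// attribute on the defining op, and finally the tied loop init value for a
// loop-region block argument; nullptr is returned if none applies.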

xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
  Operation *op = opr.getOwner();
  std::string layoutName = xegpu::getLayoutName(opr);
  if (op->hasAttr(layoutName))
    return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  return getLayoutAttr(opr.get());
}

template <typename T, typename>
void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
  Operation *owner = operandOrResult.getOwner();
  std::string name = xegpu::getLayoutName(operandOrResult);
  if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
    owner->setAttr(name, layout);
}

// Explicit instantiation for OpResult.
template void
xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
                                     const mlir::xegpu::LayoutAttr layout);

// Explicit instantiation for OpOperand.
template void
xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
                                      const mlir::xegpu::LayoutAttr layout);

void xegpu::setLayoutAttrs(Operation *op,
                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
  op->walk([&](Operation *nestOp) {
    for (OpOperand &opr : nestOp->getOpOperands()) {
      auto layout = getLayoutImpl(opr.get());
      setLayoutAttr(opr, layout);
    }
    for (OpResult result : nestOp->getOpResults()) {
      auto layout = getLayoutImpl(result);
      setLayoutAttr(result, layout);
    }
  });
}
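
// Usage sketch (hypothetical callback): propagate pre-computed layouts onto
// every nested operand and result, e.g.
//   xegpu::setLayoutAttrs(funcOp,
//                         [&](Value v) { return analysis.getLayout(v); });
// where `funcOp` and `analysis` stand in for a caller-provided op and layout
// analysis.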

SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                        Value value, ArrayRef<int64_t> shape) {
  auto vecTy = dyn_cast<VectorType>(value.getType());
  if (!vecTy)
    return {value};

  ArrayRef<int64_t> srcShape = vecTy.getShape();
  if (!computeShapeRatio(srcShape, shape))
    return {value};

  SmallVector<Value> result;
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result.push_back(builder.create<vector::ExtractStridedSliceOp>(
        loc, value, offsets, shape, staticStrides));
  }

  return result;
}
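
// Worked example (illustrative): extracting shape [16, 32] tiles from a
// vector<32x64xf32> value yields four vector<16x32xf32> slices, taken at
// offsets (0, 0), (0, 32), (16, 0), and (16, 32).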

Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                             ValueRange values,
                                             ArrayRef<int64_t> shape) {
  VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
  assert(llvm::all_of(values.getTypes(),
                      [&](Type type) { return type == inputTy; }) &&
         "values must be of the same VectorType");

  Type elemTy = inputTy.getElementType();
  ArrayRef<int64_t> tileShape = inputTy.getShape();

  VectorType resultTy = VectorType::get(shape, elemTy);
  auto zeroAttr = builder.getZeroAttr(elemTy);
  Value result = builder.create<arith::ConstantOp>(
      loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));

  for (auto [src, offsets] :
       llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result = builder.create<vector::InsertStridedSliceOp>(
        loc, src, result, offsets, staticStrides);
  }
  return result;
}
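
// Worked example (illustrative): this is the inverse of the extraction above;
// four vector<16x32xf32> values assembled with shape [32, 64] produce a single
// vector<32x64xf32>, built by inserting each tile into a zero-initialized
// constant at the corresponding tile offsets.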

void xegpu::doSCFStructuralTypeConversionWithTensorType(
    Operation *op, TypeConverter converter) {
  MLIRContext *context = op->getContext();

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
        .getResult(0);
  };

  { // Convert VectorType to RankedTensorType for SCF structural ops.
    TypeConverter converter;
    converter.addConversion([](Type type) -> Type { return type; });
    converter.addConversion([](VectorType type) -> Type {
      return RankedTensorType::get(type.getShape(), type.getElementType());
    });
    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization(materializeCast);

    mlir::ConversionTarget target(*context);
    target.addLegalOp<UnrealizedConversionCastOp>();

    mlir::RewritePatternSet patterns(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }

  { // Propagate the layout attribute to RankedTensorType by inspecting
    // UnrealizedConversionCastOps that cast from VectorType to
    // RankedTensorType.
    op->walk([](UnrealizedConversionCastOp castOp) {
      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
        return WalkResult::skip();

      Value input = castOp.getInputs()[0];
      Value result = castOp.getResults()[0];
      auto inputTy = dyn_cast<VectorType>(input.getType());
      auto resultTy = dyn_cast<RankedTensorType>(result.getType());

      // Only look at ops casting from VectorType to RankedTensorType.
      if (!inputTy || !resultTy)
        return WalkResult::skip();

      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
      if (!layout)
        return WalkResult::skip();

      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
      result.setType(newTy);

      // Update the region iter args if the user is a LoopLike op.
      for (OpOperand &use : result.getUses()) {
        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
          arg.setType(newTy);
        }
        // scf.while has two regions; the block arguments of the "after"
        // region are not exposed by LoopLikeOpInterface.
        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
          unsigned idx = use.getOperandNumber();
          BlockArgument arg = whileOp.getAfterArguments()[idx];
          arg.setType(newTy);
        }
      }
      return WalkResult::advance();
    });

    // Use the yield op as an anchor to update the result types of its parent
    // op.
    op->walk([](scf::YieldOp yieldOp) {
      Operation *parentOp = yieldOp->getParentOp();
      for (OpResult r : parentOp->getOpResults()) {
        unsigned idx = r.getResultNumber();
        Type resultTy = r.getType();
        Type yieldTy = yieldOp.getResults()[idx].getType();
        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
          r.setType(yieldTy);
      }
    });
  }

  { // Perform the conversion from RankedTensorType to VectorType based on the
    // LayoutAttr.

    // Handle the UnrealizedConversionCastOp introduced by the first step.
    // For VectorType-to-RankedTensorType casts, it simply forwards the inputs.
    // For RankedTensorType-to-VectorType casts, it updates the inputs with the
    // values from the adaptor.
    class UnrealizedConversionCastOpPattern
        : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
      using OpConversionPattern<
          mlir::UnrealizedConversionCastOp>::OpConversionPattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::UnrealizedConversionCastOp op,
                      OneToNOpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        auto inputs = op.getOperands();
        auto outputs = op.getOutputs();

        if (inputs.size() != 1 || outputs.size() != 1)
          return failure();

        auto inputTy = inputs[0].getType();
        auto outputTy = outputs[0].getType();

        if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
          rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
          return success();
        }

        if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
          SmallVector<Value> values = flattenValues(adaptor.getInputs());
          auto newOp = rewriter.create<UnrealizedConversionCastOp>(
              op.getLoc(), outputTy, values);
          rewriter.replaceOp(op, newOp);
          return success();
        }
        return failure();
      }
    };

    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
                                           ValueRange inputs, Location loc) {
      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
          .getResults();
    });

    mlir::ConversionTarget target(*context);
    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
        [](UnrealizedConversionCastOp op) {
          auto isTensorTy = [](Type type) {
            return isa<RankedTensorType>(type);
          };
          return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
                 llvm::none_of(op->getResultTypes(), isTensorTy);
        });

    mlir::RewritePatternSet patterns(context);
    patterns.insert<UnrealizedConversionCastOpPattern>(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }
}
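
// Usage sketch (hypothetical caller): the converter passed in is used by the
// final stage to turn the RankedTensorType values back into (distributed)
// VectorType values, e.g.
//   TypeConverter converter;
//   converter.addConversion([](Type t) { return t; });
//   // ... add RankedTensorType -> VectorType conversions based on LayoutAttr.
//   xegpu::doSCFStructuralTypeConversionWithTensorType(funcOp, converter);
// where `funcOp` stands in for the operation whose nested SCF ops should be
// converted.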