//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}
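
// For example, transpose({1, 0}, shape) applied to shape = {8, 16} rewrites it
// in place to {16, 8}.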

template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  buf.clear();
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static int64_t getRankOf(Value val) {
  auto type = val.getType();
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    return ty.getRank();
  return 0;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy, UnitAttr transposeAttr,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSize();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  if (tdescShape[0] != maskShape[0])
    return emitError()
           << "dim-0 of the Mask and TensorDesc should be the same.";

  // A valid shape for the SIMT case.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (transposeAttr)
      return emitError() << "doesn't need TransposeAttr for SIMT code";
    return success();
  }

  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
    if (!transposeAttr)
      return emitError() << "rank-2 tensor has to be transposed.";
    transpose({1, 0}, tdescShape);
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
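
// Illustrative shapes for the checks above (a sketch, not exhaustive): with a
// scattered tensor_desc<16x8xf32> (chunk_size = 8) and a mask of vector<16xi1>,
// SIMD code must carry the transpose attribute and use a vector<8x16xf32>
// value (the transposed TensorDesc shape, with dim-0 of the TensorDesc
// matching dim-0 of the mask), while SIMT code uses a rank-1 vector<8xf32>
// value holding chunk_size elements per lane.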

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        staticOffsets /* const offsets */, {} /* empty const shape */,
        {} /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // The memory space of the created TensorDesc should match that of the
  // source. Both default to global memory if the memory scope attribute is
  // not specified. If the source is an integer, it is treated as a pointer
  // to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check that the source type matches the rank if it is a memref.
  // It should also have the same element type as the TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // Mismatches among shape, strides, and offsets are already handled by
  // OffsetSizeAndStrideOpInterface, so they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) to match with each other.");

  // Check the rank of the result TensorDesc.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank to be up to 2 and not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
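
// A minimal usage sketch (IR syntax approximate): creating a 2D block
// descriptor over a statically shaped memref, where the offset rank matches
// the source rank and the element types agree:
//
//   %0 = memref.alloc() : memref<24x32xf32>
//   %1 = xegpu.create_nd_tdesc %0[0, 0]
//        : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>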

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since the subgroup size is architecture dependent, we only
    // check for even distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1) {
    tdescShape.insert(tdescShape.begin(), array_len);
  }

  if (tdescShape != valueShape) {
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;
  }

  return success();
}
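
// Worked example of the packed (VNNI) adjustment above: loading from a
// tensor_desc<16x16xf16> with the packed attribute and a result shape of
// {8, 16, 2} gives vnni_factor = valueShape.back() = 2, so tdescShape becomes
// {16 / 2, 16, 2} = {8, 16, 2}, matching the result.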

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp, if the value vector is 1D and has fewer elements
  // than the tensor descriptor, it is supposed to be a SIMT op. The layout
  // attribute in tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems) {
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    }
    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape) {
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;
  }

  return success();
}
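
// Worked example of the SIMT distribution check above: storing a rank-1
// vector<8xf32> to a tensor_desc<8x16xf32> (128 elements) is accepted since
// 128 % 8 == 0, which would correspond to a 16-lane subgroup; the actual
// subgroup size is architecture dependent and is not verified here.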

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source to be a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc should match that of the
  // source. Both default to global memory if the memory scope attribute is
  // not specified. If the source is an integer, it is treated as a pointer
  // to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check the total access size.
  auto chunkSize = tdescTy.getChunkSize();
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports a chunk size of
    // 1. For 32-bit data, the hardware supports larger chunk sizes, so
    // 8-bit/16-bit data can be bitcast to 32-bit data for better performance.
    // This optimization requires the total size to be 32-bit aligned.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // Each access is up to 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");

  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected " << makeString(shape) << "\n";

  return success();
}
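
// Worked example of the access-size checks above: for f16 elements
// (elemBits = 16), chunk_size = 2 gives bitsPerLane = 32, which is 32-bit
// aligned and accepted, while chunk_size = 3 gives 48 bits and is rejected.
// For the total-size limit, 16 lanes with chunk_size = 8 of f32 access
// 16 * 8 * 4 = 512 bytes, exactly the per-access maximum.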

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The semantic check is skipped here due to the lack of architecture
  // information; users need to ensure correctness.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
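
// Shape examples for the SIMD checks above: lhs vector<8x16xf16>,
// rhs vector<16x16xf16>, and result vector<8x16xf32> satisfy M = 8, K = 16,
// N = 16. With a packed (VNNI) rhs vector<8x16x2xf16>, bK = 8 * 2 = 16 still
// matches the K dimension of lhs.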

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::verify() {
  auto srcMap = getSrcMapAttr();
  auto resMap = getResMapAttr();
  if (!srcMap)
    return emitOpError("expected srcMap.");
  if (!resMap)
    return emitOpError("expected resMap.");

  if (srcMap == resMap)
    return emitOpError("expected different srcMap and resMap.");

  // srcMap and resMap must both be WgLayout or both be SgLayout.
  if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
      (!srcMap.isSgLayout() || !resMap.isSgLayout()))
    return emitOpError(
        "expected srcMap and resMap to be WgLayout or SgLayout at the same "
        "time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
    return emitOpError("invalid srcMap, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
    return emitOpError("invalid resMap, data cannot be evenly distributed.");

  return mlir::success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>