//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

// Applies the permutation `trans` to `shape` in place.
static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}
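
// Worked example (illustrative): with trans = {1, 0} and shape = [8, 16],
// the loop assigns shape[0] = old[1] = 16 and shape[1] = old[0] = 8, so the
// shape becomes [16, 8] in place.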

// Pretty-prints an integer array as "[a, b, c]", optionally breaking lines
// between elements.
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  if (!array.empty())
    os << array.back();
  os << "]";
  return buf;
}
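
// Illustrative output: makeString(SmallVector<int64_t>{2, 16, 8}) produces
// "[2, 16, 8]"; with the empty-array guard above, an empty array produces
// "[]".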

// Returns the shape of a shaped type, or {1} for a scalar type.
static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

// Returns the rank of a shaped type, or 0 for a scalar type.
static int64_t getRankOf(Value val) {
  auto type = val.getType();
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    return ty.getRank();
  return 0;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy, UnitAttr transposeAttr,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSize();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  if (tdescShape[0] != maskShape[0])
    return emitError()
           << "dim-0 of the Mask and TensorDesc should be the same.";

  // A valid shape for the SIMT case: a rank-1 value with chunkSize elements.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (transposeAttr)
      return emitError() << "doesn't need TransposeAttr for SIMT code";
    return success();
  }

  // SIMD case: a rank-2 value must be marked as transposed.
  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
    if (!transposeAttr)
      return emitError() << "rank-2 tensor has to be transposed.";
    transpose({1, 0}, tdescShape);
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
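
// Shape examples for the checks above (illustrative, assuming 16 SIMD lanes
// and chunk_size = 8):
//   - SIMT: tdesc shape [16, 8], mask shape [16], value shape [8]; a rank-1
//     value with exactly chunk_size elements is accepted as a distribution.
//   - SIMD: tdesc shape [16, 8], mask shape [16], value shape [8, 16] with
//     transposeAttr set; the tdesc shape is permuted to [8, 16] before the
//     comparison against the value shape.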

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        staticOffsets /* const offsets */, {} /* empty const shape */,
        {} /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType>(srcTy) || isa<MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides come from the memref type, we don't need
    // attributes for them, to keep the IR print clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // The memory space of the created TensorDesc should match that of the
  // source. Both default to global memory when the memory-space attribute is
  // not specified. An integer source is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check that the source rank matches if it is a memref. It should also
  // have the same element type as the TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // Mismatches among shape, strides, and offsets are already handled by
  // OffsetSizeAndStrideOpInterface, so they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // Check the result TensorDesc rank.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank is up to 2 and not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
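
// Rank-rule example (illustrative): for a source memref<24x32xf16> with
// offsets [0, 0], rank is 2, so a result TensorDesc of rank 1 or 2 (e.g.
// 8x16xf16) is accepted, while a rank-3 TensorDesc would be rejected.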

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, this is assumed to be a SIMT op. The layout attribute in
  // the tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need a LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since the subgroup size is arch-dependent, only even
    // distribution is checked here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1) {
    tdescShape.insert(tdescShape.begin(), array_len);
  }

  if (tdescShape != valueShape) {
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;
  }

  return success();
}
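
// Packed-mode example (illustrative): loading f16 with the packed attribute
// and vnni_factor = 2, a 16x16 tensor_desc is expected to produce a result
// of shape [8, 16, 2]: dim-0 is divided by the factor and the factor is
// appended as a trailing dimension. With array_length = 2, the expected
// shape additionally gains a leading 2, i.e. [2, 8, 16, 2].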

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp: if the value vector is 1D and has fewer elements
  // than the tensor descriptor, this is assumed to be a SIMT op. The layout
  // attribute in the tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need a LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems) {
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    }
    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape) {
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;
  }

  return success();
}
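
// Distribution example (illustrative): an 8x16 tensor_desc of f32 holds 128
// elements, so a rank-1 vector<8xf32> is a valid SIMT value (128 % 8 == 0),
// while a vector<7xf32> would be rejected as an uneven distribution.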

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNdOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
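
// Builder-usage sketch (illustrative; `b`, `loc`, `tdescTy`, and `src` are
// assumed to exist in the caller):
//   auto desc = b.create<xegpu::CreateDescOp>(
//       loc, tdescTy, src, llvm::ArrayRef<int64_t>({0, 8, 16}));
// The int64_t overload wraps the constants as index OpFoldResults, and the
// OpFoldResult overload materializes them into a single index vector via
// vector::FromElementsOp.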

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc should match that of the
  // source. Both default to global memory when the memory-space attribute is
  // not specified. An integer source is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check the total access size.
  auto chunkSize = tdescTy.getChunkSize();
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports a chunk size of
    // 1. For 32-bit data, the hardware can support a larger chunk size, so
    // 8-bit/16-bit data can be bitcast to 32-bit data for better performance.
    // This optimization requires the total size to be 32-bit aligned.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // Each access is up to 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");

  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
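
// Access-size arithmetic for the checks above (illustrative): for f16
// (elemBits = 16), chunk_size = 2 gives 32 bits per lane and passes the
// alignment check, while chunk_size = 3 gives 48 bits and is rejected. The
// total-size bound compares elemBits * numElements against 512 * 8 bits;
// e.g. 16 lanes with chunk_size = 8 of f32 is 128 * 32 = 4096 bits, exactly
// the 512-byte limit.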

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The full semantic check is skipped here because the architecture
  // information is not available; users need to ensure correctness.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
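
// SIMD shape example (illustrative): for f16 operands the packing factor is
// 32 / 16 = 2, so lhs vector<8x16xf16>, packed rhs vector<8x16x2xf16>, and
// result vector<8x16xf32> verify: bK = 8 * 2 = 16 matches lhsShape[1], and
// M = 8 and N = 16 match the result dims.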

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::verify() {
  auto srcMap = getSrcMapAttr();
  auto resMap = getResMapAttr();
  if (!srcMap)
    return emitOpError("expected srcMap.");
  if (!resMap)
    return emitOpError("expected resMap.");

  if (srcMap == resMap)
    return emitOpError("expected different srcMap and resMap.");

  // srcMap and resMap must both be WgLayout, or both be SgLayout.
  if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
      (!srcMap.isSgLayout() || !resMap.isSgLayout()))
    return emitOpError(
        "expected srcMap and resMap be WgLayout or SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
    return emitOpError("invalid srcMap, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
    return emitOpError("invalid resMap, data cannot be evenly distributed.");

  return mlir::success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>