MLIR 23.0.0git
XeGPUOps.cpp
Go to the documentation of this file.
1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Builders.h"
16
17#include "llvm/Support/Debug.h"
18
19#define DEBUG_TYPE "xegpu"
20
21using namespace mlir;
22using namespace mlir::xegpu;
23
24template <typename T>
25static std::string makeString(T array, bool breakline = false) {
26 std::string buf;
27 buf.clear();
28 llvm::raw_string_ostream os(buf);
29 os << "[";
30 for (size_t i = 1; i < array.size(); i++) {
31 os << array[i - 1] << ", ";
32 if (breakline)
33 os << "\n\t\t";
34 }
35 os << array.back() << "]";
36 return buf;
37}
38
41 if (auto ty = llvm::dyn_cast<ShapedType>(type))
42 shape = SmallVector<int64_t>(ty.getShape());
43 else
44 shape.push_back(1);
45 return shape;
46}
47
48static bool isReadHintOrNone(const CachePolicyAttr &attr) {
49 if (!attr)
50 return true;
51 auto kind = attr.getValue();
52 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
53 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
54}
55
56static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
57 if (!attr)
58 return true;
59 auto kind = attr.getValue();
60 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
61 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
62}
63
64static LogicalResult
66 VectorType valueTy, int64_t chunkSize,
68
69 auto maskVecTy = dyn_cast<VectorType>(maskTy);
70 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
71 if (!valueTy) {
72 if (chunkSize > 1)
73 return emitError() << "Expecting chunk size == 1 for scalar result";
74 if (maskVecTy || offsetsVecTy)
75 return emitError() << "Expecting scalar mask and offsets.";
76 else if (maskVecTy && offsetsVecTy)
77 return emitError() << "Expecting a vector type result.";
78 return success();
79 }
80
81 auto valueSize = valueTy.getNumElements();
82 // SIMT mode with scalar mask and offsets.
83 if (!maskVecTy && !offsetsVecTy) {
84 if (valueSize != chunkSize)
85 return emitError() << "value elements must match chunk size "
86 << chunkSize;
87 return success();
88 }
89 auto maskShape = getShapeOf(maskTy);
90 auto valueShape = getShapeOf(valueTy);
91
92 if (!maskVecTy)
93 return emitError() << "Expecting a vector type mask.";
94 int64_t maskSize = maskVecTy.getNumElements();
95
96 if (chunkSize > 1) {
97 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
98 return emitError() << "value elements must match chunk size "
99 << chunkSize;
100 } else {
101 if (valueSize != maskSize)
102 return emitError()
103 << "Mask should match value except the chunk size dim.";
104 }
105 llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
106 if (maskSize == 1)
107 return success();
108 if (chunkSize > 1)
109 expectedMaskShape.pop_back();
110 if (expectedMaskShape != maskShape)
111 return emitError() << "Mask should match value except the chunk size dim.";
112
113 return success();
114}
115
116LogicalResult
117IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
118 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
120
121 if (!dataTy) {
122 if (subgroup_block_io)
123 return emitError() << "subgroup_block_io "
124 "are only allowed when result is a VectorType.";
125 else
126 return success();
127 }
128
129 ArrayRef<int64_t> dataShape = dataTy.getShape();
130 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
131
132 SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
133 ArrayAttr strideAttr = mdescTy.getStrideAttr();
134 SmallVector<int64_t> strides;
135 for (Attribute attr : strideAttr.getValue()) {
136 strides.push_back(cast<IntegerAttr>(attr).getInt());
137 }
138 if (subgroup_block_io && layout) {
139 auto laneData = layout.getEffectiveLaneDataAsInt();
140 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
141 if (!laneData.empty()) {
142 bool isLaneDataContiguous =
143 std::all_of(laneData.begin(), std::prev(laneData.end()),
144 [](int x) { return x == 1; });
145 if (!isLaneDataContiguous)
146 return emitError() << "With subgroup_block_io, accessed data must be "
147 "contiguous and coalesced.";
148 for (size_t i = 0; i < laneData.size(); ++i) {
149 if (laneLayout[i] != blockShape[i])
150 return emitError() << "With subgroup_block_io, the block shape must "
151 "match the lane layout.";
152 if (laneLayout[i] != 1 && strides[i] != 1)
153 return emitError() << "With subgroup_block_io, the distributed "
154 "dimensions must be contiguous.";
155 }
156 }
157 }
158
159 if (layout && !layout.isDistributable(
160 SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
161 return emitError() << "Value shape is not distributable with the layout";
162
163 if (dataShape.size() == mdescShape.size()) {
164 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
165 [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
166 return emitError() << "data shape must not exceed mem_desc shape.";
167 }
168 // if the subgroup_block_io attribute is set, mdescTy must have block
169 // attribute
170 if (subgroup_block_io && !blockShape.size())
171 return emitError() << "mem_desc must have block attribute when "
172 "subgroup_block_io is set.";
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// XeGPU_CreateNdDescOp
178//===----------------------------------------------------------------------===//
179
180void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
181 Type tdesc, TypedValue<MemRefType> source) {
182 [[maybe_unused]] auto ty = source.getType();
183 assert(ty.hasStaticShape() && "expecting a memref with static shape");
184
185 build(builder, state, tdesc, source, ValueRange({}) /* empty dynamic shape */,
186 ValueRange({}) /* empty dynamic strides */,
187 DenseI64ArrayAttr({}) /* empty const shape*/,
188 DenseI64ArrayAttr({}) /* empty const strides*/);
189}
190
191void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
192 Type tdesc, Value source,
195 Type srcTy = source.getType();
196 assert((isa<IntegerType, MemRefType>(srcTy)) &&
197 "Source has to be either int or memref.");
198
199 llvm::SmallVector<Value> dynamicShape;
200 llvm::SmallVector<Value> dynamicStrides;
201
202 llvm::SmallVector<int64_t> staticShape;
203 llvm::SmallVector<int64_t> staticStrides;
204
205 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
206 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
207
208 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
209 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
210
211 if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
212 auto memrefShape = memrefTy.getShape();
213 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
214
215 // if shape and strides are from Memref, we don't need attributes for them
216 // to keep the IR print clean (only do so for full-static case, otherwise
217 // printer would fail trying to print empty array-attr).
218 if (staticShape == memrefShape && staticStrides == memrefStrides &&
219 dynamicShape.empty() && dynamicStrides.empty()) {
220 staticShapeAttr = DenseI64ArrayAttr();
221 staticStridesAttr = DenseI64ArrayAttr();
222 }
223 }
224
225 build(builder, state, tdesc, source, dynamicShape, dynamicStrides,
226 staticShapeAttr, staticStridesAttr);
227}
228
229LogicalResult CreateNdDescOp::verify() {
230 size_t rank = getMixedSizes().size();
231 bool invalidRank = rank != getMixedStrides().size();
232 bool invalidElemTy = false;
233
234 // Memory space of created TensorDesc should match with the source.
235 // Both source and TensorDesc are considered for global memory by default,
236 // if the memory scope attr is not specified. If source is an integer,
237 // it is considered as ptr to global memory.
238 auto srcMemorySpace = getSourceMemorySpace();
239 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
240 if (srcMemorySpace != tdescMemorySpace)
241 return emitOpError("Memory space mismatch.")
242 << " Source: " << srcMemorySpace
243 << ", TensorDesc: " << tdescMemorySpace;
244
245 // check source type matches the rank if it is a memref.
246 // It also should have the same ElementType as TensorDesc.
247 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
248 invalidElemTy |= memrefTy.getElementType() != getElementType();
249
250 if (llvm::isa<IntegerType>(getSourceType())) {
251 // strides and shape must present for integer source.
252 if (getMixedStrides().empty() || getMixedSizes().empty())
253 return emitOpError("expecting strides and shape to be present for "
254 "integer source.");
255 }
256
257 if (invalidRank)
258 return emitOpError(
259 "Expecting the rank of shape, strides, and source (if source "
260 "is a memref) should match with each other.");
261
262 // check result TensorDesc rank
263 if (getType().getRank() > (int64_t)rank)
264 return emitOpError("Expecting the TensorDesc rank is not greater than the "
265 "ranks of shape, strides or the memref source.");
266
267 if (invalidElemTy)
268 return emitOpError("TensorDesc should have the same element "
269 "type with the source if it is a memref.\n");
270
271 return success();
272}
273
274//===----------------------------------------------------------------------===//
275// XeGPU_PrefetchNdOp
276//===----------------------------------------------------------------------===//
277
278void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
279 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
280 xegpu::CachePolicyAttr l1_hint,
281 xegpu::CachePolicyAttr l2_hint,
282 xegpu::CachePolicyAttr l3_hint,
283 xegpu::DistributeLayoutAttr layout) {
284 SmallVector<Value> dynamicOffsets;
285 SmallVector<int64_t> staticOffsets;
286 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
287
288 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
289
290 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
291 l2_hint, l3_hint, /*anchor_layout=*/layout);
292}
293
294LogicalResult PrefetchNdOp::verify() {
295 auto tdescTy = getTensorDescType();
296
297 if (!isReadHintOrNone(getL1HintAttr()))
298 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
299
300 if (!isReadHintOrNone(getL2HintAttr()))
301 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
302
303 if (!isReadHintOrNone(getL3HintAttr()))
304 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
305
306 int64_t tDescRank = tdescTy.getRank();
307 int64_t offsetSize = getMixedOffsets().size();
308 if (offsetSize != tDescRank)
309 return emitOpError(
310 "Mismatched ranks between offsets and tensor descriptor");
311
312 if (auto layout = getAnchorLayout()) {
313 if (!layout.isDistributable(getShapeOf(tdescTy)))
314 return emitOpError(
315 "TensorDesc shape is not distributable with the layout");
316 }
317
318 return success();
319}
320
321//===----------------------------------------------------------------------===//
322// XeGPU_LoadNdOp
323//===----------------------------------------------------------------------===//
324
325void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
326 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
327 UnitAttr packed, DenseI64ArrayAttr transpose,
328 xegpu::CachePolicyAttr l1_hint,
329 xegpu::CachePolicyAttr l2_hint,
330 xegpu::CachePolicyAttr l3_hint,
331 xegpu::DistributeLayoutAttr layout) {
332 SmallVector<Value> dynamicOffsets;
333 SmallVector<int64_t> staticOffsets;
334 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
335
336 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
337
338 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
339 packed, transpose, l1_hint, l2_hint, l3_hint,
340 /*anchor_layout=*/layout);
341}
342
343LogicalResult LoadNdOp::verify() {
344 auto tdescTy = getTensorDescType();
345 auto valueTy = getType();
346
347 if (tdescTy.getRank() > 2)
348 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
349
350 if (!valueTy)
351 return emitOpError("Invalid result, it should be a VectorType.\n");
352
353 if (!isReadHintOrNone(getL1HintAttr()))
354 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
355
356 if (!isReadHintOrNone(getL2HintAttr()))
357 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
358
359 if (!isReadHintOrNone(getL3HintAttr()))
360 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
361
362 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
363 int valueElems = valueTy.getNumElements();
364
365 // If the result vector is 1D and has less elements than the tensor
366 // descriptor, it is supposed to be a SIMT op. The layout attribute in
367 // tensor_desc is not needed.
368 if (valueElems < tdescElems && valueTy.getRank() == 1) {
369 // SIMT mode doesn't need LayoutAttr.
370 if (tdescTy.getLayoutAttr())
371 return emitOpError()
372 << "TensorDesc doesn't need LayoutAttr for SIMT code";
373
374 // For SIMT code, the load is evenly distributed across all lanes in a
375 // subgroup. Since subgroup size is arch dependent, we only check even
376 // distribution here.
377 if (tdescElems % valueElems)
378 return emitOpError()
379 << "Result shape " << makeString(getShapeOf(valueTy))
380 << " is not a valid distribution for tensor descriptor "
381 << tdescTy;
382
383 return success();
384 }
385
386 // Check SIMD mode.
387 auto tdescShape = getShapeOf(tdescTy);
388 auto valueShape = getShapeOf(valueTy);
389
390 if (getTranspose()) {
391 auto trans = getTranspose().value();
392 // Make sure the transpose value is valid, and apply it
393 if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
394 tdescShape = applyPermutation(tdescShape, trans);
395 else
396 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
397 }
398
399 if (getPacked()) {
400 if (tdescTy.getRank() == 2) {
401 const int axis = 0;
402 auto vnni_factor = valueShape.back();
403 tdescShape[axis] /= vnni_factor;
404 tdescShape.push_back(vnni_factor);
405 } else {
406 mlir::emitWarning(getLoc())
407 << "Invalid Packed Attr. It is ignored (available for 2D "
408 "TensorDesc only).";
409 }
410 }
411
412 // Handle array_length. Two result shape conventions are accepted:
413 // * 3D shape: leading array_length dimension prepended, e.g. descriptor
414 // 16x16 with array_length=2 -> [2, 16, 16].
415 // * Stacked 2D shape: array blocks stacked along the non-FCD (first)
416 // dimension, e.g. descriptor 16x16 with array_length=2 -> [32, 16].
417 auto array_len = tdescTy.getArrayLength();
418 SmallVector<int64_t> stacked2DShape(tdescShape);
419 SmallVector<int64_t> threeDShape(tdescShape);
420 if (array_len > 1 && !tdescShape.empty()) {
421 stacked2DShape[0] *= array_len;
422 threeDShape.insert(threeDShape.begin(), array_len);
423 }
424
425 if (valueShape != stacked2DShape && valueShape != threeDShape)
426 return emitOpError() << "Result shape " << makeString(valueShape)
427 << " is not consistent with tensor descriptor "
428 << tdescTy;
429
430 int64_t tDescRank = tdescTy.getRank();
431 int64_t offsetSize = getMixedOffsets().size();
432 if (offsetSize != tDescRank)
433 return emitOpError(
434 "Mismatched ranks between offsets and tensor descriptor");
435
436 if (auto layout = getAnchorLayout()) {
437 if (!layout.isDistributable(getShapeOf(tdescTy)))
438 return emitOpError(
439 "TensorDesc shape is not distributable with the layout");
440 }
441
442 return success();
443}
444
445//===----------------------------------------------------------------------===//
446// XeGPU_StoreNdOp
447//===----------------------------------------------------------------------===//
448
449void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
450 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
451 xegpu::CachePolicyAttr l1_hint,
452 xegpu::CachePolicyAttr l2_hint,
453 xegpu::CachePolicyAttr l3_hint,
454 xegpu::DistributeLayoutAttr layout) {
455 SmallVector<Value> dynamicOffsets;
456 SmallVector<int64_t> staticOffsets;
457 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
458
459 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
460
461 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
462 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
463}
464
465LogicalResult StoreNdOp::verify() {
466 auto dstTy = getTensorDescType(); // Tile
467 auto valTy = getValueType(); // Vector
468
469 if (dstTy.getRank() > 2)
470 return emitOpError("Expects a 1D or 2D TensorDesc.\n");
471
472 if (!valTy)
473 return emitOpError("Expecting a VectorType result.\n");
474
475 if (!isWriteHintOrNone(getL1HintAttr()))
476 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
477
478 if (!isWriteHintOrNone(getL2HintAttr()))
479 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
480
481 if (!isWriteHintOrNone(getL3HintAttr()))
482 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
483
484 auto array_len = dstTy.getArrayLength();
485 if (array_len > 1)
486 return emitOpError("array length is not supported by store_nd.\n");
487
488 auto tdescElems = dstTy.getNumElements();
489 auto valueElems = valTy.getNumElements();
490
491 // Similar to LoadNdOp, if the value vector is 1D and has less elements than
492 // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
493 // in tensor_desc is not needed.
494 if (valTy.getRank() == 1 && valueElems < tdescElems) {
495 // SIMT mode doesn't need LayoutAttr.
496 if (dstTy.getLayoutAttr())
497 return emitOpError()
498 << "TensorDesc doesn't need LayoutAttr for SIMT code";
499
500 if (tdescElems % valueElems)
501 return emitOpError()
502 << "Value shape " << makeString(getShapeOf(valTy))
503 << " is not a valid distribution for tensor descriptor " << dstTy;
504
505 return success();
506 }
507
508 // SIMD code should have the same shape as the tensor descriptor.
509 auto tdescShape = getShapeOf(dstTy);
510 auto valueShape = getShapeOf(valTy);
511 if (tdescShape != valueShape)
512 return emitOpError() << "Value shape " << makeString(valueShape)
513 << " is not consistent with tensor descriptor "
514 << dstTy;
515
516 int64_t tDescRank = dstTy.getRank();
517 int64_t offsetSize = getMixedOffsets().size();
518 if (offsetSize != tDescRank)
519 return emitOpError(
520 "Mismatched ranks between offsets and tensor descriptor");
521
522 if (auto layout = getAnchorLayout()) {
523 if (!layout.isDistributable(tdescShape))
524 return emitOpError(
525 "TensorDesc shape is not distributable with the layout");
526 }
527
528 return success();
529}
530
531//===----------------------------------------------------------------------===//
532// XeGPU_PrefetchOp
533//===----------------------------------------------------------------------===//
534LogicalResult PrefetchOp::verify() {
535 if (!isReadHintOrNone(getL1HintAttr()))
536 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
537
538 if (!isReadHintOrNone(getL2HintAttr()))
539 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
540
541 if (!isReadHintOrNone(getL3HintAttr()))
542 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
543
544 auto srcTy = getSourceType();
545 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
546 return emitOpError("offset_align_byte is required with integer source.");
547
548 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
549 return emitOpError("offset_align_byte only allowed with integer source.");
550
551 if (auto layout = getAnchorLayout()) {
552 // get the offset operand and its shape
553 auto offsetsTy = getOffsets().getType();
554 if (llvm::isa<VectorType>(offsetsTy) &&
555 !layout.isDistributable(getShapeOf(offsetsTy)))
556 return emitOpError("offset shape is not distributable with the layout");
557 }
558
559 return success();
560}
561
562//===----------------------------------------------------------------------===//
563// XeGPU_LoadGatherOp
564//===----------------------------------------------------------------------===//
565LogicalResult LoadGatherOp::verify() {
566 auto maskTy = getMaskType();
567 auto valueTy = getValueType();
568
569 if (!isReadHintOrNone(getL1HintAttr()))
570 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
571
572 if (!isReadHintOrNone(getL2HintAttr()))
573 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
574
575 if (!isReadHintOrNone(getL3HintAttr()))
576 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
577
578 auto srcTy = getSourceType();
579 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
580 auto memTy = dyn_cast<MemRefType>(srcTy);
581
582 if (memTy && (getElementType() != memTy.getElementType()))
583 return emitError() << "Value should have the same element type as MemRef.";
584
585 if (auto layout = getAnchorLayout()) {
586 if (!layout.isDistributable(getShapeOf(valueTy)))
587 return emitOpError("Value shape is not distributable with the layout");
588 }
589
590 auto offsetsTy = getOffsets().getType();
591 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
592 [&]() { return emitOpError(); });
593}
594
595void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
596 Type valueType, Value source,
597 ArrayRef<OpFoldResult> offsets, Value mask,
598 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
599 xegpu::CachePolicyAttr l2_hint,
600 xegpu::CachePolicyAttr l3_hint) {
601 auto loc = source.getLoc();
602 int64_t size = static_cast<int64_t>(offsets.size());
603 auto type = VectorType::get(size, builder.getIndexType());
604 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
605 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
606
607 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
608 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
609}
610
611void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
612 Type valueType, Value source,
613 ArrayRef<OpFoldResult> offsets, Value mask,
614 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
615 xegpu::CachePolicyAttr l2_hint,
616 xegpu::CachePolicyAttr l3_hint,
617 DistributeLayoutAttr layout) {
618 auto loc = source.getLoc();
619 int64_t size = static_cast<int64_t>(offsets.size());
620 auto type = VectorType::get(size, builder.getIndexType());
621 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
622 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
623
624 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
625 l2_hint, l3_hint, layout);
626}
627
628//===----------------------------------------------------------------------===//
629// XeGPU_StoreScatterOp
630//===----------------------------------------------------------------------===//
631LogicalResult StoreScatterOp::verify() {
632 auto maskTy = getMaskType();
633 auto valueTy = getValueType();
634
635 if (!isWriteHintOrNone(getL1HintAttr()))
636 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
637
638 if (!isWriteHintOrNone(getL2HintAttr()))
639 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
640
641 if (!isWriteHintOrNone(getL3HintAttr()))
642 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
643
644 auto destTy = getDestType();
645 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
646 auto memTy = dyn_cast<MemRefType>(destTy);
647
648 if (memTy && (getElementType() != memTy.getElementType()))
649 return emitError() << "Value should have the same element type as MemRef.";
650
651 if (auto layout = getAnchorLayout()) {
652 if (!layout.isDistributable(getShapeOf(valueTy)))
653 return emitOpError("Value shape is not distributable with the layout");
654 }
655
656 auto offsetsTy = getOffsets().getType();
657 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
658 [&]() { return emitOpError(); });
659}
660
661void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
662 Value value, Value dest,
663 ArrayRef<OpFoldResult> offsets, Value mask,
664 IntegerAttr chunk_size,
665 xegpu::CachePolicyAttr l1_hint,
666 xegpu::CachePolicyAttr l2_hint,
667 xegpu::CachePolicyAttr l3_hint) {
668 auto loc = dest.getLoc();
669 int64_t size = static_cast<int64_t>(offsets.size());
670 auto type = VectorType::get(size, builder.getIndexType());
671 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
672 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
673
674 // Call the correct builder overload that does not expect result types.
675 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
676 l3_hint, /*anchor_layout=*/nullptr);
677}
678
679void StoreScatterOp::build(
680 OpBuilder &builder, OperationState &state, Value value, Value dest,
681 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
682 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
683 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
684 auto loc = dest.getLoc();
685 int64_t size = static_cast<int64_t>(offsets.size());
686 auto type = VectorType::get(size, builder.getIndexType());
687 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
688 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
689
690 // Call the correct builder overload that does not expect result types.
691 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
692 l3_hint, layout);
693}
694
695//===----------------------------------------------------------------------===//
696// DPAS Common Verification Helpers
697//===----------------------------------------------------------------------===//
698
699// Helper to verify layout distributability for a value
700static LogicalResult
702 std::optional<DistributeLayoutAttr> layout,
703 ArrayRef<int64_t> shape, StringRef operandName) {
704 if (layout && !layout->isDistributable(
705 SmallVector<int64_t>(shape.begin(), shape.end())))
706 return op->emitOpError(operandName)
707 << " shape is not distributable with the layout";
708 return success();
709}
710
711// Helper to verify M, N, K dimensions match between A, B, and result matrices
712static LogicalResult verifyDpasDimensions(Operation *op,
713 ArrayRef<int64_t> aShape,
714 ArrayRef<int64_t> bShape,
715 ArrayRef<int64_t> resShape) {
716
717 auto aRank = aShape.size();
718 auto bRank = bShape.size();
719 auto resRank = resShape.size();
720 if (aRank == 1 && bRank == 1 && resRank == 1)
721 return success();
722
723 // Validate A and B are 2D
724 if (aRank != 2)
725 return op->emitOpError("A operand must be a 2D vector.");
726 if (bRank < 2 || bRank > 3)
727 return op->emitOpError("B operand must be a 2D or 3D vector.");
728 if (resRank != 2)
729 return op->emitOpError("Result must be a 2D vector.");
730
731 // Calculate effective K dimension for B (handle 3D packed case)
732 int64_t bK = bRank == 3 ? bShape[0] * bShape[2] : bShape[0];
733
734 // Verify K dimension match between A and B
735 if (bK != aShape[1])
736 return op->emitOpError("K-dimension mismatch: A has K=")
737 << aShape[1] << " but B has K=" << bK << ".";
738
739 // Verify M dimension match between A and result
740 if (aShape[0] != resShape[0])
741 return op->emitOpError("M-dimension mismatch: A has M=")
742 << aShape[0] << " but result has M=" << resShape[0] << ".";
743
744 // Verify N dimension match between B and result
745 if (bShape[1] != resShape[1])
746 return op->emitOpError("N-dimension mismatch: B has N=")
747 << bShape[1] << " but result has N=" << resShape[1] << ".";
748
749 return success();
750}
751
752// Helper to verify accumulator matches result type
753static LogicalResult verifyDpasAccumulator(Operation *op, Type accType,
754 Type resultType) {
755 if (accType != resultType)
756 return op->emitOpError("Accumulator type must match result type.");
757 return success();
758}
759
760//===----------------------------------------------------------------------===//
761// XeGPU_DpasOp
762//===----------------------------------------------------------------------===//
763LogicalResult DpasOp::verify() {
764 auto lhsShape = getLhsType().getShape();
765 auto rhsShape = getRhsType().getShape();
766 auto resShape = getResultType().getShape();
767
768 // Verify layout distributability
769 if (failed(
770 verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
771 return failure();
772 if (failed(verifyLayoutDistributable(*this, getLayoutA(), lhsShape, "A")))
773 return failure();
774 if (failed(verifyLayoutDistributable(*this, getLayoutB(), rhsShape, "B")))
775 return failure();
776
777 // Verify accumulator if present
778 if (getAcc() &&
779 failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
780 return failure();
781
782 return verifyDpasDimensions(*this, lhsShape, rhsShape, resShape);
783}
784
785//===----------------------------------------------------------------------===//
786// XeGPU_ConvertLayoutOp
787//===----------------------------------------------------------------------===//
788LogicalResult ConvertLayoutOp::verify() {
789 auto srcLayout = getInputLayout();
790 auto resLayout = getTargetLayout();
791 if (!srcLayout)
792 return emitOpError("expected input layout.");
793 if (!resLayout)
794 return emitOpError("expected target layout.");
795
796 // both input and target layouts should be WgLayout or SgLayout at the same
797 // time.
798 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
799 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
800 return emitOpError("expected input layout and target layout be WgLayout or "
801 "SgLayout at the same time.");
802
803 Type srcType = getSource().getType();
804 if (llvm::isa<VectorType>(srcType)) {
805 SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
806 if (!srcLayout.isDistributable(shape))
807 return emitOpError(
808 "invalid input layout, data cannot be evenly distributed.");
809
810 if (!resLayout.isDistributable(shape))
811 return emitOpError(
812 "invalid target layout, data cannot be evenly distributed.");
813 }
814 return mlir::success();
815}
816
817//===----------------------------------------------------------------------===//
818// XeGPU_LoadMatrixOp
819//===----------------------------------------------------------------------===//
820void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
823 DistributeLayoutAttr layout) {
824 llvm::SmallVector<Value> dynamicOffsets;
825 llvm::SmallVector<int64_t> staticOffsets;
826 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
827 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
828 // Call the generated builder with all parameters (including optional ones as
829 // nullptr/empty)
830 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
831 /*subgroup_block_io=*/nullptr, layout);
832}
833
834LogicalResult LoadMatrixOp::verify() {
835
836 auto resTy = dyn_cast<VectorType>(getRes().getType());
837 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
838 MemDescType mdescTy = getMemDesc().getType();
839
840 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
841 getLayoutAttr(), [&]() { return emitError(); });
842}
843
844//===----------------------------------------------------------------------===//
845// XeGPU_StoreMatrixOp
846//===----------------------------------------------------------------------===//
// Convenience builder for xegpu.store_matrix that takes its offsets as
// OpFoldResults. dispatchIndexOpFoldResults splits `offsets` into dynamic
// Values and static int64_t entries; the static entries are packed into a
// DenseI64ArrayAttr. The stored value `data`, `memDesc`, both offset lists
// and `layout` are forwarded to the generated builder, with the optional
// subgroup_block_io unit attribute left unset (nullptr).
// NOTE(review): how dynamic positions are encoded in the static list is
// defined by dispatchIndexOpFoldResults (presumably ShapedType::kDynamic) —
// confirm against that helper.
847void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
850 DistributeLayoutAttr layout) {
851 llvm::SmallVector<Value> dynamicOffsets;
852 llvm::SmallVector<int64_t> staticOffsets;
853 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
854 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
855 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
856 /*subgroup_block_io=*/nullptr, layout);
857}
858
859LogicalResult StoreMatrixOp::verify() {
860
861 auto dataTy = dyn_cast<VectorType>(getData().getType());
862 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
863 MemDescType mdescTy = getMemDesc().getType();
864 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
865 getLayoutAttr(), [&]() { return emitError(); });
866}
867
868//===----------------------------------------------------------------------===//
869// XeGPU_TruncfOp
870//===----------------------------------------------------------------------===//
871
872LogicalResult TruncfOp::verify() {
873 auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
874 auto resultVecType = dyn_cast<VectorType>(getResult().getType());
875
876 if (sourceVecType.getElementTypeBitWidth() <=
877 resultVecType.getElementTypeBitWidth())
878 return emitOpError("input type must be wider than result type.");
879
880 return success();
881}
882
883//===----------------------------------------------------------------------===//
884// XeGPU_DpasMxOp
885//===----------------------------------------------------------------------===//
886
887LogicalResult DpasMxOp::verify() {
888 auto aShape = getAType().getShape();
889 auto bShape = getBType().getShape();
890 auto resShape = getResultType().getShape();
891
892 // Verify layout distributability for A, B, and result
893 if (failed(
894 verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
895 return failure();
896 if (failed(verifyLayoutDistributable(*this, getLayoutA(), aShape, "A")))
897 return failure();
898 if (failed(verifyLayoutDistributable(*this, getLayoutB(), bShape, "B")))
899 return failure();
900
901 // Verify accumulator if present
902 if (getAcc() &&
903 failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
904 return failure();
905
906 // Verify M, N, K dimensions
907 if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
908 return failure();
909
910 // Validate scale_a if present
911 if (getScaleA()) {
912 auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
913 // Only validate if scale is a vector (scalars are always valid)
914 if (scaleAVecType && scaleAVecType.getRank() > 1) {
915 auto scaleAShape = scaleAVecType.getShape();
916
917 if (scaleAVecType.getRank() != 2)
918 return emitOpError("Scale A must be a 2D vector when not a scalar.");
919
920 // Verify layout distributability for scale_a
921 if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
922 scaleAShape, "ScaleA")))
923 return failure();
924
925 // Validate M dimension: scale_a[0] must match a[0]
926 if (scaleAShape[0] != aShape[0])
927 return emitOpError("Scale A M dimension [")
928 << scaleAShape[0] << "] must match A M dimension [" << aShape[0]
929 << "].";
930 }
931 }
932
933 // Validate scale_b if present
934 if (getScaleB()) {
935 auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
936 // Only validate if scale is a vector (scalars are always valid)
937 if (scaleBVecType && scaleBVecType.getRank() > 1) {
938 auto scaleBShape = scaleBVecType.getShape();
939
940 if (scaleBVecType.getRank() != 2)
941 return emitOpError("Scale B must be a 2D vector when not a scalar.");
942
943 // Verify layout distributability for scale_b
944 if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
945 scaleBShape, "ScaleB")))
946 return failure();
947
948 // Validate N dimension: scale_b[1] must match b[1]
949 if (scaleBShape[1] != bShape[1])
950 return emitOpError("Scale B N dimension [")
951 << scaleBShape[1] << "] must match B N dimension [" << bShape[1]
952 << "].";
953 }
954 }
955
956 // Validate scale K dimension compatibility if both scales are present and
957 // vectors
958 if (getScaleA() && getScaleB()) {
959 auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
960 auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
961
962 if (scaleAVecType && scaleBVecType && scaleAVecType.getRank() > 1 &&
963 scaleBVecType.getRank() > 1) {
964 auto scaleAShape = scaleAVecType.getShape();
965 auto scaleBShape = scaleBVecType.getShape();
966
967 // Validate scale K dimension compatibility: scale_a[1] must match
968 // scale_b[0]
969 if (scaleAShape[1] != scaleBShape[0])
970 return emitOpError("Scale K dimension mismatch: scale_a has K=")
971 << scaleAShape[1] << " but scale_b has K=" << scaleBShape[0]
972 << ".";
973 }
974 }
975
976 return success();
977}
978
979namespace mlir {
980#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
981} // namespace mlir
982#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
983#define GET_OP_CLASSES
984#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
ArrayAttr()
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:791
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< int64_t > getShapeOf(Type type)
Definition XeGPUOps.cpp:39
static LogicalResult verifyDpasAccumulator(Operation *op, Type accType, Type resultType)
Definition XeGPUOps.cpp:753
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:117
static std::string makeString(T array, bool breakline=false)
Definition XeGPUOps.cpp:25
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:56
static bool isReadHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:48
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:65
static LogicalResult verifyDpasDimensions(Operation *op, ArrayRef< int64_t > aShape, ArrayRef< int64_t > bShape, ArrayRef< int64_t > resShape)
Definition XeGPUOps.cpp:712
static LogicalResult verifyLayoutDistributable(Operation *op, std::optional< DistributeLayoutAttr > layout, ArrayRef< int64_t > shape, StringRef operandName)
Definition XeGPUOps.cpp:701
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IndexType getIndexType()
Definition Builders.cpp:55
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This represents an operation in an abstracted form, suitable for use with the builder APIs.