16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "xegpu"
26 for (
size_t i = 0; i < trans.size(); i++)
27 shape[i] = old[trans[i]];
31 static std::string
makeString(T array,
bool breakline =
false) {
34 llvm::raw_string_ostream os(buf);
36 for (
size_t i = 1; i < array.size(); i++) {
37 os << array[i - 1] <<
", ";
41 os << array.back() <<
"]";
47 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
56 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED ||
kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING ||
kind == CachePolicy::READ_INVALIDATE;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED ||
kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK ||
kind == CachePolicy::WRITE_THROUGH;
79 TensorDescType tdescTy, UnitAttr transposeAttr,
82 if (!tdescTy.isScattered())
83 return emitError() <<
"Expects a scattered TensorDesc.";
86 return emitError() <<
"Expecting a vector type result.";
91 auto chunkSize = tdescTy.getChunkSize();
93 if (valueTy.getElementType() != tdescTy.getElementType())
95 <<
"Value should have the same element type as TensorDesc.";
97 if (tdescShape[0] != maskShape[0])
99 <<
"dim-0 of the Mask and TensorDesc should be the same.";
102 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
103 if (tdescTy.getLayoutAttr())
104 return emitError() <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
106 return emitError() <<
"doesn't need TransposeAttr for SIMT code";
110 if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
112 return emitError() <<
"rank-2 tensor has to be transposed.";
116 if (tdescShape != valueShape)
118 <<
" is neither a valid distribution for SIMT nor "
119 "consistent with the tensor descriptor for SIMD "
130 [[maybe_unused]]
auto ty = source.getType();
131 assert(ty.hasStaticShape() && offsets.size() == (
size_t)ty.getRank());
137 build(builder, state, tdesc, source, dynamicOffsets ,
144 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
145 Type tdesc, Value source,
149 assert(shape.size() && offsets.size() && strides.size() &&
150 shape.size() == strides.size() && shape.size() == offsets.size());
152 Type srcTy = source.getType();
153 assert(isa<IntegerType>(srcTy) ||
154 isa<MemRefType>(srcTy) &&
"Source has to be either int or memref.");
168 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
169 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
170 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
172 if (
auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
173 auto memrefShape = memrefTy.getShape();
174 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
178 if (staticShape == memrefShape && staticStrides == memrefStrides) {
184 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
185 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
189 auto rank = (int64_t)getMixedOffsets().size();
190 bool invalidRank =
false;
191 bool invalidElemTy =
false;
197 auto srcMemorySpace = getSourceMemorySpace();
198 auto tdescMemorySpace =
static_cast<unsigned>(
getType().getMemorySpace());
199 if (srcMemorySpace != tdescMemorySpace)
200 return emitOpError(
"Memory space mismatch.")
201 <<
" Source: " << srcMemorySpace
202 <<
", TensorDesc: " << tdescMemorySpace;
206 auto memrefTy = dyn_cast<MemRefType>(getSourceType());
208 invalidRank |= (memrefTy.getRank() != rank);
217 "Expecting the rank of shape, strides, offsets, and source (if source "
218 "is a memref) should match with each other.");
221 invalidRank = (
getType().getRank() > 2 ||
getType().getRank() > rank);
225 "Expecting the TensorDesc rank is up to 2 and not greater than the "
226 "ranks of shape, strides, offsets or the memref source.");
229 return emitOpError(
"TensorDesc should have the same element "
230 "type with the source if it is a memref.\n");
233 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
242 auto tdescTy = getTensorDescType();
243 if (tdescTy.isScattered())
244 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
247 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
250 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
253 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
262 auto tdescTy = getTensorDescType();
265 if (tdescTy.getRank() > 2)
266 return emitOpError(
"Expecting a 1D/2D TensorDesc.\n");
268 if (tdescTy.isScattered())
269 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
272 return emitOpError(
"Invalid result, it should be a VectorType.\n");
275 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
278 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
281 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
283 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
284 int valueElems = valueTy.getNumElements();
289 if (valueElems < tdescElems && valueTy.getRank() == 1) {
291 if (tdescTy.getLayoutAttr())
293 <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
298 if (tdescElems % valueElems)
301 <<
" is not a valid distribution for tensor descriptor "
311 if (getTranspose()) {
312 auto trans = getTranspose().value();
315 bool valid = std::all_of(trans.begin(), trans.end(), [&](
int t) {
316 return t >= 0 && t < tdescTy.getRank();
326 if (tdescTy.getRank() == 2) {
328 auto vnni_factor = valueShape.back();
329 tdescShape[axis] /= vnni_factor;
330 tdescShape.push_back(vnni_factor);
333 <<
"Invalid Packed Attr. It is ignored (available for 2D "
338 auto array_len = tdescTy.getArrayLength();
340 tdescShape.insert(tdescShape.begin(), array_len);
343 if (tdescShape != valueShape) {
344 return emitOpError() <<
"Result shape " <<
makeString(valueShape)
345 <<
" is not consistent with tensor descriptor "
356 auto dstTy = getTensorDescType();
357 auto valTy = getValueType();
359 if (dstTy.getRank() > 2)
360 return emitOpError(
"Expecting a 1D/2D TensorDesc.\n");
362 if (dstTy.isScattered())
363 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
366 return emitOpError(
"Expecting a VectorType result.\n");
369 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
372 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
375 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
377 auto array_len = dstTy.getArrayLength();
379 return emitOpError(
"array length is not supported by store_nd.\n");
381 auto tdescElems = dstTy.getNumElements();
382 auto valueElems = valTy.getNumElements();
387 if (valTy.getRank() == 1 && valueElems < tdescElems) {
389 if (dstTy.getLayoutAttr())
391 <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
393 if (tdescElems % valueElems) {
396 <<
" is not a valid distribution for tensor descriptor " << dstTy;
404 if (tdescShape != valueShape) {
405 return emitOpError() <<
"Value shape " <<
makeString(valueShape)
406 <<
" is not consistent with tensor descriptor "
417 auto ty = getTensorDescType();
418 if (ty.isScattered())
419 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
422 if (ty.getRank() != (int64_t)getNumOffsets()) {
423 return emitOpError(
"Invalid number of offsets.");
432 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
433 TensorDescType TensorDesc, Value source,
435 auto loc = source.getLoc();
436 int64_t size =
static_cast<int64_t
>(offsets.size());
439 auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
440 build(builder, state, TensorDesc, source, offset);
443 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
444 TensorDescType TensorDesc, Value source,
447 build(builder, state, TensorDesc, source, ofrs);
451 auto tdescTy = getTensorDescType();
455 "Expecting the source is a 1D memref or pointer (uint64_t).");
457 if (!tdescTy.isScattered())
458 return emitOpError(
"Expects a scattered TensorDesc.\n");
464 auto srcMemorySpace = getSourceMemorySpace();
465 auto tdescMemorySpace =
static_cast<unsigned>(tdescTy.getMemorySpace());
466 if (srcMemorySpace != tdescMemorySpace)
467 return emitOpError(
"Memory space mismatch.")
468 <<
" Source: " << srcMemorySpace
469 <<
", TensorDesc: " << tdescMemorySpace;
472 auto chunkSize = tdescTy.getChunkSize();
473 auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
474 auto bitsPerLane = elemBits * chunkSize;
475 if (chunkSize > 1 && bitsPerLane % 32) {
482 "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
485 auto lscConstraints = 512 * 8;
486 if (elemBits * tdescTy.getNumElements() > lscConstraints)
487 return emitOpError(
"total access size (simd_lanes * chunk_size * "
488 "sizeof(elemTy)) is upto 512 bytes.");
490 SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
492 shape.push_back(chunkSize);
495 if (shape != tdescShape)
496 return emitOpError(
"Incorrect TensorDesc shape. ")
497 <<
"Expected is " <<
makeString(shape) <<
"\n";
506 auto tdescTy = getTensorDescType();
507 if (!tdescTy.isScattered())
508 return emitOpError(
"Expects a scattered TensorDesc.\n");
511 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
514 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
517 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
526 auto tdescTy = getTensorDescType();
527 auto maskTy = getMaskType();
528 auto valueTy = getValueType();
531 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
534 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
537 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
541 [&]() {
return emitOpError(); });
548 auto tdescTy = getTensorDescType();
549 auto maskTy = getMaskType();
550 auto valueTy = getValueType();
553 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
556 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
559 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
563 [&]() {
return emitOpError(); });
569 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
572 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.
getType());
573 assert(tdescTy &&
"Expecting the source is a TensorDescType value.");
574 auto loc = tensorDesc.
getLoc();
575 int64_t size =
static_cast<int64_t
>(offsets.size());
578 auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
579 build(builder, state, tdescTy, tensorDesc, offset);
582 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
585 build(builder, state, tensorDesc, ofrs);
592 int64_t lhsRank = getLhsType().getRank();
593 int64_t rhsRank = getRhsType().getRank();
594 int64_t resRank = getResultType().getRank();
595 auto lhsShape = getLhsType().getShape();
596 auto rhsShape = getRhsType().getShape();
597 auto resShape = getResultType().getShape();
599 if (getAcc() && getAcc().
getType() != getResultType())
600 return emitOpError(
"Expecting the acc type to be the same as result.");
605 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
606 auto numElems = getRhsType().getNumElements();
607 auto elemTy = getRhsType().getElementType();
608 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
609 if (numElems % factor != 0)
610 return emitOpError(
"Expecting B operand to be a multiple of 32 bits.");
615 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
617 "expecting lhs and result to be a 2D vector, and rhs to be either "
618 "2D or 3D (packed) vector.");
619 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
620 if (bK != lhsShape[1])
621 return emitOpError(
"K-dimension mismatch.");
622 if (lhsShape[0] != resShape[0])
623 return emitOpError(
"M-dimension mismatch.");
624 if (rhsShape[1] != resShape[1])
625 return emitOpError(
"N-dimension mismatch.");
634 auto srcMap = getSrcMapAttr();
635 auto resMap = getResMapAttr();
637 return emitOpError(
"expected srcMap.");
639 return emitOpError(
"expected resMap.");
641 if (srcMap == resMap)
642 return emitOpError(
"expected different srcMap and resMap.");
645 if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
646 (!srcMap.isSgLayout() || !resMap.isSgLayout()))
648 "expected srcMap and resMap be WgLayout or SgLayout at the same time.");
650 auto shape = getSource().getType().getShape();
651 if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
652 return emitOpError(
"invalid srcMap, data cannot be evenly distributed.");
654 if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
655 return emitOpError(
"invalid resMap, data cannot be evenly distributed.");
657 return mlir::success();
663 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
664 #define GET_OP_CLASSES
665 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
union mlir::linalg::@1197::ArityGroupAndKind::Kind kind
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
@ Type
An inlay hint that is for a type annotation.
static std::string makeString(T array, bool breakline=false)
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref< InFlightDiagnostic()> emitError)
static int64_t getRankOf(Value val)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
static SmallVector< int64_t > getShapeOf(Type type)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.