15 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "xegpu"
// NOTE(review): this file is a Doxygen "source view" extraction — the leading
// numbers on each line are the original source line numbers, statements are
// re-wrapped mid-token, and many intermediate lines are elided. The fragments
// below will not compile as-is; they are preserved verbatim for reference.
//
// Fragment of: static void transpose(llvm::ArrayRef<int64_t> trans,
//                                    SmallVector<int64_t> &shape)
// Permutes `shape` in place using the index vector `trans` (reads from a
// saved copy `old` whose declaration is elided in this view).
25 for (
size_t i = 0; i < trans.size(); i++)
26 shape[i] = old[trans[i]];
// Fragment of: static std::string makeString(T array, bool breakline = false)
// Renders `array` as a bracketed, comma-separated string, e.g. "[2, 4]".
// NOTE(review): the declaration of `buf`, the opening "[" emission, the
// breakline handling, and the return statement are elided in this view.
30 static std::string
makeString(T array,
bool breakline =
false) {
33 llvm::raw_string_ostream os(buf);
// All elements but the last get a trailing ", " separator.
35 for (
size_t i = 1; i < array.size(); i++) {
36 os << array[i - 1] <<
", ";
// The last element closes the bracket.
40 os << array.back() <<
"]";
// Fragment of: static SmallVector<int64_t> getShapeOf(Type type)
// Returns the shape when `type` is shaped; the taken branch and the
// non-shaped fallback are elided in this view.
46 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
// Fragment of: static int64_t getRankOf(Value val)
// Same dyn_cast pattern on the value's type; branch bodies elided.
55 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
// static bool isReadHintOrNone(const CachePolicyAttr &attr)
// Read-path cache hints: CACHED, UNCACHED, STREAMING, READ_INVALIDATE.
// (The null-attr early-return implied by the "OrNone" name is elided.)
63 auto kind = attr.getValue();
64 return kind == CachePolicy::CACHED ||
kind == CachePolicy::UNCACHED ||
65 kind == CachePolicy::STREAMING ||
kind == CachePolicy::READ_INVALIDATE;
// static bool isWriteHintOrNone(const CachePolicyAttr &attr)
// Write-path cache hints: CACHED, UNCACHED, WRITE_BACK, WRITE_THROUGH.
71 auto kind = attr.getValue();
72 return kind == CachePolicy::CACHED ||
kind == CachePolicy::UNCACHED ||
73 kind == CachePolicy::WRITE_BACK ||
kind == CachePolicy::WRITE_THROUGH;
// Fragment of: static LogicalResult isValidGatherScatterParams(
//     Type maskTy, VectorType valueTy, TensorDescType tdescTy,
//     UnitAttr transposeAttr, function_ref<InFlightDiagnostic()> emitError)
// Shared validation for gather/scatter ops. Checks, in order: the TensorDesc
// is scattered, the value is a vector, element types match, mask/tdesc dim-0
// agree, then either the SIMT (rank-1, chunk-sized value) or SIMD (rank-2,
// transposed) shape rules. Several lines are elided in this view.
78 TensorDescType tdescTy, UnitAttr transposeAttr,
// Gather/scatter only makes sense on a scattered descriptor.
81 if (!tdescTy.isScattered())
82 return emitError() <<
"Expects a scattered TensorDesc.";
85 return emitError() <<
"Expecting a vector type result.";
90 auto chunkSize = tdescTy.getChunkSize();
92 if (valueTy.getElementType() != tdescTy.getElementType())
94 <<
"Value should have the same element type as TensorDesc.";
// The mask covers one lane per dim-0 entry of the descriptor.
96 if (tdescShape[0] != maskShape[0])
98 <<
"dim-0 of the Mask and TensorDesc should be the same.";
// SIMT case: a rank-1 value of exactly chunkSize elements; layout and
// transpose attributes are meaningless here and are rejected.
101 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
102 if (tdescTy.getLayoutAttr())
103 return emitError() <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
105 return emitError() <<
"doesn't need TransposeAttr for SIMT code";
// SIMD case: rank-2 access must be transposed (elided condition checks
// transposeAttr — TODO confirm against full source).
109 if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
111 return emitError() <<
"rank-2 tensor has to be transposed.";
// Final consistency check between (possibly adjusted) shapes.
115 if (tdescShape != valueShape)
117 <<
" is neither a valid distribution for SIMT nor "
118 "consistent with the tensor descriptor for SIMD "
// Fragments of three CreateNdDescOp::build overloads (intermediate lines,
// including the OpFoldResult dispatching, are elided in this view).
//
// Overload 1 (static-shape memref source): asserts the memref shape is
// static and offsets match its rank, then forwards to the primary builder.
// `ty` is only used inside the assert, hence [[maybe_unused]].
129 [[maybe_unused]]
auto ty = source.getType();
130 assert(ty.hasStaticShape() && offsets.size() == (
size_t)ty.getRank());
136 build(builder, state, tdesc, source, dynamicOffsets ,
// Overload 2 (memref source with explicit shape/strides/offsets): the three
// lists must be non-empty and of equal length; dynamic/static components are
// split (elided) and packed into DenseI64ArrayAttrs for the primary builder.
143 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
144 Type tdesc, TypedValue<MemRefType> source,
148 assert(shape.size() && offsets.size() && strides.size() &&
149 shape.size() == strides.size() && shape.size() == offsets.size());
162 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
163 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
164 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
166 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
167 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
// Overload 3 (integer/pointer source): identical structure to overload 2.
170 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
171 Type tdesc, TypedValue<IntegerType> source,
175 assert(shape.size() && offsets.size() && strides.size() &&
176 shape.size() == strides.size() && shape.size() == offsets.size());
189 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
190 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
191 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
193 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
194 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
// Fragment of CreateNdDescOp::verify (several condition/return lines elided).
// Validates: source/descriptor memory spaces agree; the ranks of shape,
// strides, offsets and (memref) source agree; the TensorDesc rank is at most
// 2 and not greater than the source rank; element types match for memref
// sources; and the descriptor is non-scattered.
198 auto rank = (int64_t)getMixedOffsets().size();
199 bool invalidRank =
false;
200 bool invalidElemTy =
false;
// Memory-space check: descriptor attr cast to unsigned for comparison.
206 auto srcMemorySpace = getSourceMemorySpace();
207 auto tdescMemorySpace =
static_cast<unsigned>(
getType().getMemorySpace());
208 if (srcMemorySpace != tdescMemorySpace)
209 return emitOpError(
"Memory space mismatch.")
210 <<
" Source: " << srcMemorySpace
211 <<
", TensorDesc: " << tdescMemorySpace;
// Memref-specific rank check (guard on memrefTy being non-null is elided).
215 auto memrefTy = dyn_cast<MemRefType>(getSourceType());
217 invalidRank |= (memrefTy.getRank() != rank);
226 "Expecting the rank of shape, strides, offsets, and source (if source "
227 "is a memref) should match with each other.");
// TensorDesc rank must be <= 2 and <= the operand rank.
230 invalidRank = (
getType().getRank() > 2 ||
getType().getRank() > rank);
234 "Expecting the TensorDesc rank is up to 2 and not greater than the "
235 "ranks of shape, strides, offsets or the memref source.");
238 return emitOpError(
"TensorDesc should have the same element "
239 "type with the source if it is a memref.\n");
// create_nd_desc produces a non-scattered (block) descriptor.
242 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
// Fragment of a verify() for a block (nd) op — presumably PrefetchNdOp given
// the surrounding original line numbers; TODO confirm against full source.
// Rejects scattered descriptors and validates the three cache-hint attrs
// (the isReadHintOrNone guards on the if-conditions are elided).
251 auto tdescTy = getTensorDescType();
252 if (tdescTy.isScattered())
253 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
256 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
259 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
262 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
// Fragment of a block-load verify() — presumably LoadNdOp given the checks
// (transpose, packed/vnni, array_length); TODO confirm against full source.
271 auto tdescTy = getTensorDescType();
// Block loads support only 1D/2D non-scattered descriptors.
274 if (tdescTy.getRank() > 2)
275 return emitOpError(
"Expecting a 1D/2D TensorDesc.\n");
277 if (tdescTy.isScattered())
278 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
281 return emitOpError(
"Invalid result, it should be a VectorType.\n");
// Cache-hint validation (guards elided).
284 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
287 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
290 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
// Element-count bookkeeping: descriptor covers NumElements * ArrayLength.
292 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
293 int valueElems = valueTy.getNumElements();
// SIMT path: a rank-1 result smaller than the descriptor is a per-lane
// distribution; no LayoutAttr allowed, and the split must be even.
298 if (valueElems < tdescElems && valueTy.getRank() == 1) {
300 if (tdescTy.getLayoutAttr())
302 <<
"TensorDesc doesn't need LayoutAttr for SIMT code"
307 if (tdescElems % valueElems)
310 <<
" is not a valid distribution for tensor descriptor "
// Optional transpose attr: must be a valid permutation of the tdesc rank.
320 if (getTranspose()) {
321 auto trans = getTranspose().value();
324 bool valid = std::all_of(trans.begin(), trans.end(), [&](
int t) {
325 return t >= 0 && t < tdescTy.getRank();
// Packed (vnni) attr on a 2D descriptor: fold the vnni factor out of the
// innermost dim and append it as a trailing dim of the expected shape.
335 if (tdescTy.getRank() == 2) {
337 auto vnni_factor = valueShape.back();
338 tdescShape[axis] /= vnni_factor;
339 tdescShape.push_back(vnni_factor);
342 <<
"Invalid Packed Attr. It is ignored (available for 2D "
// array_length > 1 prepends an extra leading dim (guard elided).
347 auto array_len = tdescTy.getArrayLength();
349 tdescShape.insert(tdescShape.begin(), array_len);
// Final shape-consistency diagnostic.
352 if (tdescShape != valueShape) {
353 return emitOpError() <<
"Result shape " <<
makeString(valueShape)
354 <<
" is not consistent with tensor descriptor "
// Fragment of a block-store verify() — presumably StoreNdOp given the
// "store_nd" diagnostic text below; several lines elided in this view.
365 auto dstTy = getTensorDescType();
366 auto valTy = getValueType();
// Same structural constraints as the block load: 1D/2D, non-scattered.
368 if (dstTy.getRank() > 2)
369 return emitOpError(
"Expecting a 1D/2D TensorDesc.\n");
371 if (dstTy.isScattered())
372 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
375 return emitOpError(
"Expecting a VectorType result.\n");
// Cache-hint validation (write-hint guards elided).
378 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
381 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
384 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
// Unlike loads, stores do not support array_length > 1 (guard elided).
386 auto array_len = dstTy.getArrayLength();
388 return emitOpError(
"array length is not supported by store_nd.\n");
390 auto tdescElems = dstTy.getNumElements();
391 auto valueElems = valTy.getNumElements();
// SIMT path: rank-1 value smaller than the descriptor must be an even
// per-lane distribution with no LayoutAttr on the descriptor.
396 if (valTy.getRank() == 1 && valueElems < tdescElems) {
398 if (dstTy.getLayoutAttr())
400 <<
"TensorDesc doesn't need LayoutAttr for SIMT code"
402 if (tdescElems % valueElems) {
405 <<
" is not a valid distribution for tensor descriptor " << dstTy;
// SIMD path: exact shape match required.
413 if (tdescShape != valueShape) {
414 return emitOpError() <<
"Value shape " <<
makeString(valueShape)
415 <<
" is not consistent with tensor descriptor "
// Fragment of a verify() that rejects scattered descriptors and requires one
// offset per descriptor dimension — presumably UpdateNdOffsetOp; TODO confirm.
426 auto ty = getTensorDescType();
427 if (ty.isScattered())
428 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
// Offsets are per-dimension, so their count must equal the rank.
431 if (ty.getRank() != (int64_t)getNumOffsets()) {
432 return emitOpError(
"Invalid number of offsets.");
// Fragments of two CreateDescOp::build overloads.
// Overload 1: materializes the int64 offset list as a constant vector via
// vector::FromElementsOp, then delegates to the primary builder.
// (Construction of `type`/`values` from `size` is elided in this view.)
441 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
442 TensorDescType TensorDesc, Value source,
444 auto loc = source.getLoc();
445 int64_t size =
static_cast<int64_t
>(offsets.size());
448 auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
449 build(builder, state, TensorDesc, source, offset);
// Overload 2: forwards OpFoldResult offsets (conversion elided) to the
// overload above.
452 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
453 TensorDescType TensorDesc, Value source,
456 build(builder, state, TensorDesc, source, ofrs);
// Fragment of CreateDescOp::verify. Validates: 1D memref/pointer source,
// scattered descriptor, matching memory spaces, 32-bit-aligned per-lane
// access for chunked reads, a 512-byte total-access limit, and that the
// descriptor shape equals [num_offsets (, chunk_size)].
460 auto tdescTy = getTensorDescType();
464 "Expecting the source is a 1D memref or pointer (uint64_t).");
466 if (!tdescTy.isScattered())
467 return emitOpError(
"Expects a scattered TensorDesc.\n");
// Memory-space agreement between source and descriptor.
473 auto srcMemorySpace = getSourceMemorySpace();
474 auto tdescMemorySpace =
static_cast<unsigned>(tdescTy.getMemorySpace());
475 if (srcMemorySpace != tdescMemorySpace)
476 return emitOpError(
"Memory space mismatch.")
477 <<
" Source: " << srcMemorySpace
478 <<
", TensorDesc: " << tdescMemorySpace;
// With chunk_size > 1 each lane reads chunkSize elements back-to-back; that
// per-lane access must be a multiple of 32 bits.
481 auto chunkSize = tdescTy.getChunkSize();
482 auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
483 auto bitsPerLane = elemBits * chunkSize;
484 if (chunkSize > 1 && bitsPerLane % 32) {
491 "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
// Hardware limit: 512 bytes (512 * 8 bits) total per access.
494 auto lscConstraints = 512 * 8;
495 if (elemBits * tdescTy.getNumElements() > lscConstraints)
496 return emitOpError(
"total access size (simd_lanes * chunk_size * "
497 "sizeof(elemTy)) is upto 512 bytes.");
// Expected shape: [num_offsets] plus a trailing chunk dim when chunked
// (the chunkSize > 1 guard before push_back is elided).
499 SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
501 shape.push_back(chunkSize);
504 if (shape != tdescShape)
505 return emitOpError(
"Incorrect TensorDesc shape. ")
506 <<
"Expected is " <<
makeString(shape) <<
"\n";
// Fragment of a scattered-op verify() — presumably PrefetchOp; TODO confirm.
// Requires a scattered descriptor and valid cache hints (guards elided).
515 auto tdescTy = getTensorDescType();
516 if (!tdescTy.isScattered())
517 return emitOpError(
"Expects a scattered TensorDesc.\n");
520 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
523 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
526 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
// Fragment of a gather/scatter verify() — presumably LoadGatherOp; TODO
// confirm. Validates cache hints, then delegates the structural checks to
// isValidGatherScatterParams, passing emitOpError as the diagnostic factory.
535 auto tdescTy = getTensorDescType();
536 auto maskTy = getMaskType();
537 auto valueTy = getValueType();
540 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
543 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
546 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
550 [&]() {
return emitOpError(); });
// Fragment of the mirrored store-side verify() — presumably StoreScatterOp;
// TODO confirm. Same structure as the gather verify above: hint checks
// (write-hint guards elided) then isValidGatherScatterParams delegation.
557 auto tdescTy = getTensorDescType();
558 auto maskTy = getMaskType();
559 auto valueTy = getValueType();
562 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
565 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
568 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
572 [&]() {
return emitOpError(); });
// Fragments of two UpdateOffsetOp::build overloads.
// Overload 1: takes int64 offsets; recovers the TensorDescType from the
// operand, materializes the offsets as a constant vector (type/values
// construction elided), and delegates to the primary builder.
578 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
581 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.
getType());
582 assert(tdescTy &&
"Expecting the source is a TensorDescType value.");
583 auto loc = tensorDesc.
getLoc();
584 int64_t size =
static_cast<int64_t
>(offsets.size());
587 auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
588 build(builder, state, tdescTy, tensorDesc, offset);
// Overload 2: forwards OpFoldResult offsets (conversion elided).
591 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
594 build(builder, state, tensorDesc, ofrs);
// Fragment of a matrix-multiply (DPAS) verify() — presumably DpasOp given
// the M/N/K-dimension diagnostics; TODO confirm against full source.
601 int64_t lhsRank = getLhsType().getRank();
602 int64_t rhsRank = getRhsType().getRank();
603 int64_t resRank = getResultType().getRank();
604 auto lhsShape = getLhsType().getShape();
605 auto rhsShape = getRhsType().getShape();
606 auto resShape = getResultType().getShape();
// Optional accumulator must match the result type exactly.
608 if (getAcc() && getAcc().
getType() != getResultType())
609 return emitOpError(
"Expecting the acc type to be the same as result.");
// SIMT (all rank-1) case: the B operand's bit width must divide into
// 32-bit units, i.e. numElems must be a multiple of 32/elemBits.
614 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
615 auto numElems = getRhsType().getNumElements();
616 auto elemTy = getRhsType().getElementType();
617 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
618 if (numElems % factor != 0)
619 return emitOpError(
"Expecting B operand to be a multiple of 32 bits.");
// SIMD case: lhs/result are 2D; rhs is 2D or 3D (packed/vnni layout).
624 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
626 "expecting lhs and result to be a 2D vector, and rhs to be either "
627 "2D or 3D (packed) vector.");
// For packed rhs the K extent is split across dims 0 and 2.
628 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
629 if (bK != lhsShape[1])
630 return emitOpError(
"K-dimension mismatch.");
631 if (lhsShape[0] != resShape[0])
632 return emitOpError(
"M-dimension mismatch.");
633 if (rhsShape[1] != resShape[1])
634 return emitOpError(
"N-dimension mismatch.");
// Fragment of a layout-conversion verify() — presumably ConvertLayoutOp.
// Checks both layout maps are present, distinct, at the same level
// (workgroup or subgroup), and that the source shape distributes evenly
// under each map.
643 auto srcMap = getSrcMapAttr();
644 auto resMap = getResMapAttr();
646 return emitOpError(
"expected srcMap.");
648 return emitOpError(
"expected resMap.");
// A no-op conversion is rejected as ill-formed.
650 if (srcMap == resMap)
651 return emitOpError(
"expected different srcMap and resMap.");
// Both maps must be WgLayout, or both SgLayout — no cross-level mixes.
654 if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
655 (!srcMap.isSgLayout() || !resMap.isSgLayout()))
657 "expected srcMap and resMap be WgLayout or SgLayout at the same time.");
659 auto shape = getSource().getType().getShape();
660 if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
661 return emitOpError(
"invalid srcMap, data cannot be evenly distributed.");
663 if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
664 return emitOpError(
"invalid resMap, data cannot be evenly distributed.");
666 return mlir::success();
672 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
673 #define GET_OP_CLASSES
674 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
union mlir::linalg::@1192::ArityGroupAndKind::Kind kind
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This class represents a diagnostic that is inflight and set to be reported.
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static std::string makeString(T array, bool breakline=false)
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref< InFlightDiagnostic()> emitError)
static int64_t getRankOf(Value val)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
static SmallVector< int64_t > getShapeOf(Type type)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.